Skip to content
This repository was archived by the owner on May 7, 2023. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions giskard/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
import os

from urllib.parse import urlparse
import click
import lockfile
import psutil
Expand Down Expand Up @@ -38,7 +38,7 @@ def worker() -> None:

def start_stop_options(fn):
fn = click.option(
"--host", "-h", type=STRING, default='localhost', help="Remote Giskard host address to connect to"
"--host", "-h", type=STRING, default='http://localhost:19000', help="Remote Giskard host address to connect to"
)(fn)

fn = click.option(
Expand All @@ -49,9 +49,12 @@ def start_stop_options(fn):
"--port",
"-p",
type=INT,
default=40051,
help="Remote Giskard port accepting external ML Worker connections",
)(fn)
fn = click.option(
"--silent", "is_silent", is_flag=True, default=False,
help="If true, this option will not ask you to update your git test repository"
)(fn)
fn = click.option(
"--verbose",
"-v",
Expand All @@ -75,7 +78,7 @@ def start_stop_options(fn):
default=False,
help="Should ML Worker be started as a Daemon in a background",
)
def start_command(host, port, is_server, is_daemon):
def start_command(host, port, is_server, is_daemon, is_silent):
"""\b
Start ML Worker.

Expand All @@ -86,28 +89,29 @@ def start_command(host, port, is_server, is_daemon):
- client: ML Worker acts as a client and should connect to a running Giskard instance
by specifying this instance's host and port.
"""

_start_command(is_server, host, port, is_daemon)
_start_command(is_server, is_silent, host, port, is_daemon)


def _start_command(is_server, host, port, is_daemon):
def _start_command(is_server, is_silent, host, port, is_daemon):
token = click.prompt("Please enter an API Access Token")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of prompting this value directly it's better to declare it as another cli option with a prompt property
like so
https://click.palletsprojects.com/en/8.1.x/options/#prompting

host_name = urlparse(host).hostname
start_msg = "Starting ML Worker"
start_msg += " server" if is_server else " client"
if is_daemon:
start_msg += " daemon"
logger.info(start_msg)
pid_file_path = create_pid_file_path(is_server, host, port)
pid_file_path = create_pid_file_path(is_server, host_name, port)
pid_file = PIDLockFile(pid_file_path)
remove_stale_pid_file(pid_file)
try:
pid_file.acquire()
if is_daemon:
# Releasing the lock because it will be re-acquired by a daemon process
pid_file.release()
run_daemon(is_server, host, port)
run_daemon(is_server, host_name, port)
else:
loop = asyncio.new_event_loop()
loop.create_task(start_ml_worker(is_server, host, port))
loop.create_task(start_ml_worker(is_server, is_silent, host, port, token))
loop.run_forever()
except KeyboardInterrupt:
logger.info("Exiting")
Expand All @@ -133,6 +137,7 @@ def _ml_worker_description(is_server, host, port):
"--all", "-a", "stop_all", is_flag=True, default=False, help="Stop all running ML Workers"
)
def stop_command(is_server, host, port, stop_all):
host = urlparse(host).hostname
import re

if stop_all:
Expand All @@ -146,9 +151,9 @@ def stop_command(is_server, host, port, stop_all):

@worker.command("restart", help="Restart ML Worker")
@start_stop_options
def restart_command(is_server, host, port):
def restart_command(is_server, is_silent, host, port):
_find_and_stop(is_server, host, port)
_start_command(is_server, host, port, is_daemon=True)
_start_command(is_server, is_silent, host, port, is_daemon=True)


def _stop_pid_fname(pid_fname):
Expand Down
19 changes: 15 additions & 4 deletions giskard/ml_worker/ml_worker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import logging

import requests
import grpc

from urllib.parse import urlparse

from giskard.settings import settings
from giskard.ml_worker.testing.git_testing_repository import clone_git_testing_repository

logger = logging.getLogger(__name__)

Expand All @@ -30,16 +33,24 @@ async def _start_grpc_server(is_server=False):
return server, port


async def start_ml_worker(is_server=False, remote_host=None, remote_port=None):
async def start_ml_worker(is_server=False, is_silent=False,remote_host=None, remote_port=None, token=None):
from giskard.ml_worker.bridge.ml_worker_bridge import MLWorkerBridge

url = f'{remote_host}/api/v2/settings/general'
host_name = urlparse(remote_host).hostname
res = requests.get(url, headers={"Authorization": f"Bearer {token}"})
if res.status_code == 401:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this check we have to verify that the status code is actually 200 and raise an exception if it's not like "Failed to connect to Giskard instance"

raise Exception("Wrong Token") # Not shure of what exception
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Invalid API Token" seems like a good option

res_json = res.json()
if remote_port is None:
remote_port = res_json['externalMlWorkerEntrypointPort']
clone_git_testing_repository(res_json['generalSettings']['instanceId'], is_silent)
tasks = []
server, grpc_server_port = await _start_grpc_server(is_server)
if not is_server:
logger.info(
"Remote server host and port are specified, connecting as an external ML Worker"
)
tunnel = MLWorkerBridge(grpc_server_port, remote_host, remote_port)
tunnel = MLWorkerBridge(grpc_server_port, host_name, remote_port)
tasks.append(asyncio.create_task(tunnel.start()))

tasks.append(asyncio.create_task(server.wait_for_termination()))
Expand Down
36 changes: 36 additions & 0 deletions giskard/ml_worker/testing/git_testing_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
import logging
import click
from git import Repo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't forget to add git to pyproject.toml


from giskard.settings import settings
logger = logging.getLogger(__name__)


def clone_git_testing_repository(instance_id: int, is_silent: bool):
instance_path = os.path.expanduser(f'{settings.home}/{str(instance_id)}')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead do:

Path(expand_env_var(settings.home)) / str(instance_id)

in this case it it won't rely on unix forward slash

os.makedirs(instance_path, exist_ok=True)
repo_path = f'{instance_path}/project'
if not os.path.isdir(repo_path):
Repo.clone_from('http://localhost:3000/repository.git', to_path=f'{instance_path}/project')
logger.info(f'Git testing repo cloned in {repo_path} You can now add some tests')
else:
repo = Repo(repo_path)
repo.remotes.origin.fetch()
actual_branch = repo.active_branch
commits_behind = repo.iter_commits(f'{actual_branch}...origin/main')
nb_commit_behind = sum(1 for _ in commits_behind)
if is_silent:
logger.info(f'You are currently on branch {actual_branch} and you are {nb_commit_behind} commits behind the origin/main.')
else:
if nb_commit_behind > 0:
validation = click.confirm(f'You are currently on branch {actual_branch} and you are {nb_commit_behind} commits behind the origin/main. Do you want to pull from origin/main?')
if validation:
repo.remotes.origin.pull('main')
else:
logger.info(f'You are currently on branch {actual_branch} up to date')