diff --git a/giskard/cli.py b/giskard/cli.py index e91735a..c46f44c 100644 --- a/giskard/cli.py +++ b/giskard/cli.py @@ -1,7 +1,7 @@ import asyncio import logging import os - +from urllib.parse import urlparse import click import lockfile import psutil @@ -38,20 +38,23 @@ def worker() -> None: def start_stop_options(fn): fn = click.option( - "--host", "-h", type=STRING, default='localhost', help="Remote Giskard host address to connect to" + "--host", "-h", type=STRING, default='http://localhost:19000', help="Remote Giskard host address to connect to" )(fn) fn = click.option( - "--server", "-s", "is_server", is_flag=True, default=False, + "--server", "-s", "server_instance", type=STRING, help="Server mode. Used by Giskard embedded ML Worker" )(fn) fn = click.option( "--port", "-p", type=INT, - default=40051, help="Remote Giskard port accepting external ML Worker connections", )(fn) + fn = click.option( + "--silent", "is_silent", is_flag=True, default=False, + help="If true, this option will not ask you to update your git test repository" + )(fn) fn = click.option( "--verbose", "-v", @@ -75,7 +78,10 @@ def start_stop_options(fn): default=False, help="Should ML Worker be started as a Daemon in a background", ) -def start_command(host, port, is_server, is_daemon): +@click.option( + '--token', prompt="Please enter an API Access Token" +) +def start_command(token, host, port, server_instance, is_daemon, is_silent): """\b Start ML Worker. @@ -86,17 +92,17 @@ def start_command(host, port, is_server, is_daemon): - client: ML Worker acts as a client and should connect to a running Giskard instance by specifying this instance's host and port. 
""" - - _start_command(is_server, host, port, is_daemon) + _start_command(token, server_instance, is_silent, host, port, is_daemon) -def _start_command(is_server, host, port, is_daemon): +def _start_command(token, server_instance, is_silent, host, port, is_daemon): + host_name = urlparse(host).hostname start_msg = "Starting ML Worker" - start_msg += " server" if is_server else " client" + start_msg += " server" if server_instance is not None else " client" if is_daemon: start_msg += " daemon" logger.info(start_msg) - pid_file_path = create_pid_file_path(is_server, host, port) + pid_file_path = create_pid_file_path(server_instance, host_name, port) pid_file = PIDLockFile(pid_file_path) remove_stale_pid_file(pid_file) try: @@ -104,17 +110,17 @@ def _start_command(is_server, host, port, is_daemon): if is_daemon: # Releasing the lock because it will be re-acquired by a daemon process pid_file.release() - run_daemon(is_server, host, port) + run_daemon(server_instance, host_name, port) else: loop = asyncio.new_event_loop() - loop.create_task(start_ml_worker(is_server, host, port)) + loop.create_task(start_ml_worker(server_instance, is_silent, host, port, token)) loop.run_forever() except KeyboardInterrupt: logger.info("Exiting") except lockfile.AlreadyLocked: existing_pid = read_pid_from_pidfile(pid_file_path) logger.warning( - f"Another ML Worker {_ml_worker_description(is_server, host, port)} " + f"Another ML Worker {_ml_worker_description(server_instance, host, port)} " f"is already running with PID: {existing_pid}. " f"Not starting a new one." 
@worker.command("restart", help="Restart ML Worker")
@start_stop_options
@click.option(
    '--token', prompt="Please enter an API Access Token"
)
def restart_command(token, server_instance, is_silent, host, port):
    """Stop the matching ML Worker, if one is running, then start it again as a daemon.

    Args:
        token: API access token forwarded to ``_start_command``; prompted for
            when not given on the command line, as in the ``start`` command.
        server_instance: Value of ``--server``; ``None`` means client mode.
        is_silent: Value of the ``--silent`` flag, forwarded unchanged.
        host: Value of ``--host`` as typed by the user (a URL by default).
        port: Value of ``--port``; may be ``None``.
    """
    # The PID file is keyed on the bare hostname: both `stop_command` and
    # `_start_command` reduce the URL with `urlparse(...).hostname` before
    # building the PID file path, so the lookup here must do the same.
    host_name = urlparse(host).hostname
    _find_and_stop(server_instance, host_name, port)
    # `_start_command` parses the URL itself, so it receives the raw `host`.
    # `token` must be the first positional argument; omitting it shifts every
    # other argument by one and fails with a TypeError.
    _start_command(token, server_instance, is_silent, host, port, is_daemon=True)
{_ml_worker_description(is_server, host, port)}") + logger.info(f"Stopped ML Worker {_ml_worker_description(server_instance, host, port)}") else: - logger.info(f"ML Worker {_ml_worker_description(is_server, host, port)} is not running") + logger.info(f"ML Worker {_ml_worker_description(server_instance, host, port)} is not running") remove_existing_pidfile(pid_file_path) diff --git a/giskard/ml_worker/ml_worker.py b/giskard/ml_worker/ml_worker.py index 0b4c572..632e946 100644 --- a/giskard/ml_worker/ml_worker.py +++ b/giskard/ml_worker/ml_worker.py @@ -1,9 +1,15 @@ import asyncio import logging +import os +from pathlib import Path +import requests import grpc -from giskard.settings import settings +from urllib.parse import urlparse + +from giskard.settings import settings, expand_env_var +from giskard.ml_worker.testing.git_testing_repository import clone_git_testing_repository logger = logging.getLogger(__name__) @@ -30,16 +36,31 @@ async def _start_grpc_server(is_server=False): return server, port -async def start_ml_worker(is_server=False, remote_host=None, remote_port=None): +async def start_ml_worker(server_instance=None, is_silent=False,remote_host=None, remote_port=None, token=None): from giskard.ml_worker.bridge.ml_worker_bridge import MLWorkerBridge - + url = f'{remote_host}/api/v2/settings/general' + host_name = urlparse(remote_host).hostname + res = requests.get(url, headers={"Authorization": f"Bearer {token}"}) + if res.status_code == 401: + raise Exception("Invalid API Token") + if res.status_code != 200: + raise Exception("Failed to connect to Giskard Instance") + res_json = res.json() + if remote_port is None: + remote_port = res_json['externalMlWorkerEntrypointPort'] + if server_instance is None: + clone_git_testing_repository(res_json['generalSettings']['instanceId'], is_silent, remote_host) + else: + instance_path = Path(expand_env_var(settings.home)) / server_instance + os.makedirs(instance_path, exist_ok=True) + return tasks = [] - server, 
def _count_commits_behind(repo, branch) -> int:
    """Return the number of commits ``branch`` is behind its upstream.

    Runs ``git rev-list --left-right --count <branch>...<branch>@{u}``, whose
    output is two tab-separated counters; the second one is used as the
    "behind" count. If git rejects the command, the function pulls from
    ``origin`` instead and reports ``0``.
    """
    try:
        commits_diff = repo.git.rev_list('--left-right', '--count', f'{branch}...{branch}@{{u}}')
    except GitCommandError:
        # NOTE(review): presumably raised when the branch has no upstream
        # configured - confirm. Pull as a fallback and report nothing left to
        # fetch, so callers always receive an int rather than None.
        repo.remotes.origin.pull()
        return 0
    _, num_behind = commits_diff.split('\t')
    return int(num_behind)


def clone_git_testing_repository(instance_id: int, is_silent: bool, remote_host: str):
    """Clone the git testing repository of a Giskard instance, or refresh an existing clone.

    The working copy lives in ``<settings.home>/<instance_id>/project``.

    * If that directory does not exist, ``<remote_host>/repository.git`` is
      cloned into it.
    * Otherwise ``origin`` is fetched and the number of commits the active
      branch is behind its upstream is computed. In silent mode that number
      is only logged; in interactive mode the user is asked whether to pull
      when the branch is behind.

    Args:
        instance_id: Identifier used as the per-instance directory name.
        is_silent: When true, never prompt; only log the branch status.
        remote_host: Base URL the ``/repository.git`` path is appended to.
    """
    instance_path = Path(expand_env_var(settings.home)) / str(instance_id)
    instance_path.mkdir(parents=True, exist_ok=True)
    repo_path = instance_path / 'project'

    if not repo_path.is_dir():
        Repo.clone_from(f'{remote_host}/repository.git', to_path=str(repo_path))
        logger.info(f'Git testing repo cloned in {repo_path} You can now add some tests')
        return

    repo = Repo(str(repo_path))
    repo.remotes.origin.fetch()
    actual_branch = repo.active_branch
    num_behind = _count_commits_behind(repo, actual_branch)

    if is_silent:
        logger.info(
            f'You are currently on branch {actual_branch} and you are {num_behind} commits behind.')
        return

    if num_behind > 0:
        # The two sentences are separate literals; the trailing space keeps
        # the prompt from reading "behind.Do you want to pull?".
        validation = click.confirm(
            f'You are currently on branch {actual_branch} and you are {num_behind} commits behind. '
            f'Do you want to pull?')
        if validation:
            repo.remotes.origin.pull()
    else:
        logger.info(f'You are currently on branch {actual_branch} up to date')