-
Notifications
You must be signed in to change notification settings - Fork 1
git repository #36
base: main
Are you sure you want to change the base?
git repository #36
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,12 @@ | ||
| import asyncio | ||
| import logging | ||
|
|
||
| import requests | ||
| import grpc | ||
|
|
||
| from urllib.parse import urlparse | ||
|
|
||
| from giskard.settings import settings | ||
| from giskard.ml_worker.testing.git_testing_repository import clone_git_testing_repository | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -30,16 +33,24 @@ async def _start_grpc_server(is_server=False): | |
| return server, port | ||
|
|
||
|
|
||
| async def start_ml_worker(is_server=False, remote_host=None, remote_port=None): | ||
| async def start_ml_worker(is_server=False, is_silent=False,remote_host=None, remote_port=None, token=None): | ||
| from giskard.ml_worker.bridge.ml_worker_bridge import MLWorkerBridge | ||
|
|
||
| url = f'{remote_host}/api/v2/settings/general' | ||
| host_name = urlparse(remote_host).hostname | ||
| res = requests.get(url, headers={"Authorization": f"Bearer {token}"}) | ||
| if res.status_code == 401: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. After this check we have to verify that the status code is actually 200 and raise an exception if it's not like "Failed to connect to Giskard instance" |
||
| raise Exception("Wrong Token") # Not shure of what exception | ||
|
||
| res_json = res.json() | ||
| if remote_port is None: | ||
| remote_port = res_json['externalMlWorkerEntrypointPort'] | ||
| clone_git_testing_repository(res_json['generalSettings']['instanceId'], is_silent) | ||
| tasks = [] | ||
| server, grpc_server_port = await _start_grpc_server(is_server) | ||
| if not is_server: | ||
| logger.info( | ||
| "Remote server host and port are specified, connecting as an external ML Worker" | ||
| ) | ||
| tunnel = MLWorkerBridge(grpc_server_port, remote_host, remote_port) | ||
| tunnel = MLWorkerBridge(grpc_server_port, host_name, remote_port) | ||
| tasks.append(asyncio.create_task(tunnel.start())) | ||
|
|
||
| tasks.append(asyncio.create_task(server.wait_for_termination())) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| import os | ||
| import logging | ||
| import click | ||
| from git import Repo | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't forget to add |
||
|
|
||
| from giskard.settings import settings | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def clone_git_testing_repository(instance_id: int, is_silent: bool): | ||
| instance_path = os.path.expanduser(f'{settings.home}/{str(instance_id)}') | ||
|
||
| os.makedirs(instance_path, exist_ok=True) | ||
| repo_path = f'{instance_path}/project' | ||
| if not os.path.isdir(repo_path): | ||
| Repo.clone_from('http://localhost:3000/repository.git', to_path=f'{instance_path}/project') | ||
| logger.info(f'Git testing repo cloned in {repo_path} You can now add some tests') | ||
| else: | ||
| repo = Repo(repo_path) | ||
| repo.remotes.origin.fetch() | ||
| actual_branch = repo.active_branch | ||
| commits_behind = repo.iter_commits(f'{actual_branch}...origin/main') | ||
| nb_commit_behind = sum(1 for _ in commits_behind) | ||
| if is_silent: | ||
| logger.info(f'You are currently on branch {actual_branch} and you are {nb_commit_behind} commits behind the origin/main.') | ||
| else: | ||
| if nb_commit_behind > 0: | ||
| validation = click.confirm(f'You are currently on branch {actual_branch} and you are {nb_commit_behind} commits behind the origin/main. Do you want to pull from origin/main?') | ||
| if validation: | ||
| repo.remotes.origin.pull('main') | ||
| else: | ||
| logger.info(f'You are currently on branch {actual_branch} up to date') | ||
|
|
||
|
|
||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of prompting this value directly it's better to declare it as another cli option with a
promptpropertylike so
https://click.palletsprojects.com/en/8.1.x/options/#prompting