Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions giskard/ml_worker/ml_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from giskard.ml_worker.stomp.constants import HeaderType
from giskard.ml_worker.stomp.parsing import Frame, StompFrame
from giskard.ml_worker.websocket.action import ActionPayload, ConfigPayload, MLWorkerAction
from giskard.ml_worker.websocket.listener import WEBSOCKET_ACTORS, MLWorkerInfo
from giskard.ml_worker.websocket.listener import WEBSOCKET_ACTORS
from giskard.ml_worker.websocket.utils import fragment_message
from giskard.registry.registry import load_plugins
from giskard.settings import settings
Expand Down Expand Up @@ -63,8 +63,6 @@ def __init__(
self._ws_max_reply_payload_size = MAX_STOMP_ML_WORKER_REPLY_SIZE
self._api_key = api_key
self._hf_token = hf_token
# TODO(Bazire): Cleanup this
self._worker_info = MLWorkerInfo(id=self._worker_name)

async def config_handler(self, frame: Frame) -> List[Frame]:
payload = ConfigPayload.parse_raw(frame.body)
Expand All @@ -88,7 +86,7 @@ async def action_handler(self, frame: Frame) -> List[Frame]:
self.stop()

payload: Optional[Union[str, Frame]] = await WEBSOCKET_ACTORS[data.action.name](
data, client_params, self._worker_info
data, client_params, self._worker_name
)
# If no rep_id
if payload is None:
Expand All @@ -103,7 +101,7 @@ async def action_handler(self, frame: Frame) -> List[Frame]:
"mlworker:websocket:action:reply",
{
"name": data.action.name,
"worker": self._worker_info.id,
"worker": self._worker_name,
"language": "PYTHON",
"frag_len": self._ws_max_reply_payload_size,
"frag_count": frag_count,
Expand Down
30 changes: 12 additions & 18 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from collections import defaultdict
from concurrent.futures import CancelledError, Future
from copy import copy
from dataclasses import dataclass
from pathlib import Path
from uuid import UUID

Expand Down Expand Up @@ -61,20 +60,15 @@
MAX_STOMP_ML_WORKER_REPLY_SIZE = 1500


@dataclass
class MLWorkerInfo:
id: str


def websocket_log_actor(ml_worker: MLWorkerInfo, req: ActionPayload, *args, **kwargs):
logger.info("ML Worker %s performing %s params: %s", {ml_worker.id}, {req.action}, {req.param})
def websocket_log_actor(worker_id: str, req: ActionPayload, *args, **kwargs):
logger.info("ML Worker %s performing %s params: %s", {worker_id}, {req.action}, {req.param})


WEBSOCKET_ACTORS = dict((action.name, websocket_log_actor) for action in MLWorkerAction)


def wrapped_handle_result(
action: MLWorkerAction, start: float, job_id: Optional[UUID], worker_info: MLWorkerInfo, ignore_timeout: bool
action: MLWorkerAction, start: float, job_id: Optional[UUID], worker_id: str, ignore_timeout: bool
):
async def handle_result(future: Union[Future, Callable[..., websocket.WorkerReply]]):
info = None # Needs to be defined in case of cancellation
Expand Down Expand Up @@ -110,7 +104,7 @@ async def handle_result(future: Union[Future, Callable[..., websocket.WorkerRepl
"mlworker:websocket:action",
{
"name": action.name,
"worker": worker_info.id,
"worker": worker_id,
"language": "PYTHON",
"type": "ERROR" if isinstance(info, websocket.ErrorReply) else "SUCCESS",
"action_time": time.process_time() - start,
Expand All @@ -132,12 +126,12 @@ def parse_and_execute(
callback: Callable,
action: MLWorkerAction,
params,
ml_worker: MLWorkerInfo,
worker_id: str,
client_params: Dict[str, str],
) -> websocket.WorkerReply:
action_params = parse_action_param(action, params)
return callback(
ml_worker=ml_worker,
worker_id=worker_id,
client=GiskardClient(**client_params),
action=action.name,
params=action_params,
Expand All @@ -149,7 +143,7 @@ async def dispatch_action(
action: MLWorkerAction,
req: ActionPayload,
client_params: Dict[str, Any],
worker_info: MLWorkerInfo,
worker_id: str,
execute_in_pool: bool,
timeout: Optional[float] = None,
ignore_timeout=False,
Expand All @@ -164,19 +158,19 @@ async def dispatch_action(
"mlworker:websocket:action:type",
{
"name": action.name,
"worker": worker_info.id,
"worker": worker_id,
"language": "PYTHON",
},
)
start = time.process_time()

result_handler = wrapped_handle_result(action, start, job_id, worker_info, ignore_timeout=ignore_timeout)
result_handler = wrapped_handle_result(action, start, job_id, worker_id, ignore_timeout=ignore_timeout)
# If execution should be done in a pool
kwargs = {
"callback": callback,
"action": action,
"params": params,
"ml_worker": worker_info,
"worker_id": worker_id,
"client_params": client_params,
}
if execute_in_pool and settings.use_pool:
Expand Down Expand Up @@ -226,7 +220,7 @@ def on_abort(params: websocket.AbortParams, *args, **kwargs):


@websocket_actor(MLWorkerAction.getInfo, execute_in_pool=False)
def on_ml_worker_get_info(ml_worker: MLWorkerInfo, params: GetInfoParam, *args, **kwargs) -> websocket.GetInfo:
def on_ml_worker_get_info(worker_id: str, params: GetInfoParam, *args, **kwargs) -> websocket.GetInfo:
logger.info("Collecting ML Worker info from WebSocket")
# TODO(Bazire): seems to be deprecated https://setuptools.pypa.io/en/latest/pkg_resources.html#workingset-objects
installed_packages = {p.project_name: p.version for p in pkg_resources.working_set} if params.list_packages else {}
Expand All @@ -246,7 +240,7 @@ def on_ml_worker_get_info(ml_worker: MLWorkerInfo, params: GetInfoParam, *args,
interpreter=sys.executable,
interpreterVersion=platform.python_version(),
installedPackages=installed_packages,
kernelName=ml_worker.id,
kernelName=worker_id,
)


Expand Down
4 changes: 2 additions & 2 deletions tests/communications/test_websocket_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_websocket_actor_get_info():

# External worker, without packages
remote_ml_worker_info = listener.on_ml_worker_get_info(
ml_worker=listener.MLWorkerInfo(ml_worker._worker_name),
worker_id=ml_worker._worker_name,
params=without_package_params,
)
assert isinstance(remote_ml_worker_info, websocket.GetInfo)
Expand All @@ -51,7 +51,7 @@ def test_websocket_actor_get_info():

# External worker, with packages
remote_ml_worker_info = listener.on_ml_worker_get_info(
ml_worker=listener.MLWorkerInfo(ml_worker._worker_name),
worker_id=ml_worker._worker_name,
params=with_package_params,
)
assert isinstance(remote_ml_worker_info, websocket.GetInfo)
Expand Down