diff --git a/giskard/commands/cli_worker.py b/giskard/commands/cli_worker.py index 235b1ec64b..f36105a06c 100644 --- a/giskard/commands/cli_worker.py +++ b/giskard/commands/cli_worker.py @@ -109,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_w ) api_key = initialize_api_key(api_key, is_server) hf_token = initialize_hf_token(hf_token, is_server) - _start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers) + _start_command(is_server, url, api_key, is_daemon, hf_token, int(nb_workers) if nb_workers is not None else None) def initialize_api_key(api_key, is_server): diff --git a/giskard/ml_worker/websocket/listener.py b/giskard/ml_worker/websocket/listener.py index fae0b9c4cd..b373231d51 100644 --- a/giskard/ml_worker/websocket/listener.py +++ b/giskard/ml_worker/websocket/listener.py @@ -9,7 +9,7 @@ import tempfile import time import traceback -from concurrent.futures import Future +from concurrent.futures import CancelledError, Future from dataclasses import dataclass from pathlib import Path @@ -69,7 +69,7 @@ def __init__(self, worker: MLWorker): self.is_remote = worker.is_remote_worker() -def websocket_log_actor(ml_worker: MLWorkerInfo, req: dict, *args, **kwargs): +def websocket_log_actor(ml_worker: MLWorkerInfo, req: Dict, *args, **kwargs): param = req["param"] if "param" in req.keys() else {} action = req["action"] if "action" in req.keys() else "" logger.info(f"ML Worker {ml_worker.id} performing {action} params: {param}") @@ -86,6 +86,9 @@ def handle_result(future: Union[Future, Callable[..., websocket.WorkerReply]]): try: info: websocket.WorkerReply = future.result() if isinstance(future, Future) else future() + except CancelledError: + info: websocket.WorkerReply = websocket.Empty() + logger.warning("Task for %s has timed out and been cancelled", action.name) except Exception as e: info: websocket.WorkerReply = websocket.ErrorReply( error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc() @@ -168,7 +171,7 @@ def parse_and_execute( ) -def dispatch_action(callback, ml_worker, action, req, execute_in_pool): +def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=None): # Parse the response ID rep_id = req["id"] if "id" in req.keys() else None # Parse the param @@ -199,7 +202,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool): result_handler = wrapped_handle_result(action, ml_worker, start, rep_id) # If execution should be done in a pool if execute_in_pool: - logger.debug("Submitting for action %s '%s' into the pool with %s", action.name, callback.__name__, params) + logger.debug("Submitting for action %s '%s' into the pool", action.name, callback.__name__) future = call_in_pool( parse_and_execute, callback=callback, @@ -207,6 +210,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool): params=params, ml_worker=MLWorkerInfo(ml_worker), client_params=client_params, + timeout=timeout, ) future.add_done_callback(result_handler) log_pool_stats() @@ -223,7 +227,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool): ) -def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True): +def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True, timeout: Optional[float] = None): """ Register a function as an actor to an action from WebSocket connection """ @@ -234,7 +238,7 @@ def websocket_actor_callback(callback: callable): logger.debug(f'Registered "{callback.__name__}" for ML Worker "{action.name}"') def wrapped_callback(ml_worker: MLWorker, req: dict, *args, **kwargs): - dispatch_action(callback, ml_worker, action, req, execute_in_pool) + dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout) WEBSOCKET_ACTORS[action.name] = wrapped_callback @@ -657,7 +661,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg: return params -@websocket_actor(MLWorkerAction.getPush) +@websocket_actor(MLWorkerAction.getPush, timeout=30) def get_push( client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs ) -> websocket.GetPushResponse: diff --git a/giskard/settings.py b/giskard/settings.py index 13e61dcf51..f3cdc1154a 100644 --- a/giskard/settings.py +++ b/giskard/settings.py @@ -34,6 +34,7 @@ class Settings(BaseModel): loglevel: str = "INFO" cache_dir: str = "cache" disable_analytics: bool = False + min_workers: int = 2 class Config: env_prefix = "GSK_" diff --git a/giskard/utils/__init__.py b/giskard/utils/__init__.py index 36f5c8d9c2..310deac1e7 100644 --- a/giskard/utils/__init__.py +++ b/giskard/utils/__init__.py @@ -1,8 +1,12 @@ import logging -from asyncio import Future -from concurrent.futures import ProcessPoolExecutor +import os +import signal +from concurrent.futures import CancelledError, Future, InvalidStateError, ProcessPoolExecutor from functools import wraps -from threading import Thread +from threading import Lock, Thread +from time import sleep, time + +from giskard.settings import settings LOGGER = logging.getLogger(__name__) @@ -21,19 +25,23 @@ class WorkerPool: def __init__(self): self.pool = None + self.nb_cancellable = 0 + self.max_workers = 0 - def start(self, *args, **kwargs): + def start(self, max_workers: int = None): if self.pool is not None: return - LOGGER.info("Starting worker pool...") - self.pool = ProcessPoolExecutor(*args, **kwargs) + self.max_workers = max(max_workers, settings.min_workers) if max_workers is not None else os.cpu_count() + LOGGER.info("Starting worker pool with %s workers...", self.max_workers) + self.pool = ProcessPoolExecutor(max_workers=self.max_workers) LOGGER.info("Pool is started") def shutdown(self, wait=True, cancel_futures=False): if self.pool is None: return self.pool.shutdown(wait=wait, cancel_futures=cancel_futures) - self.pool = None + with NB_CANCELLABLE_WORKER_LOCK: + self.nb_cancellable = 0 def submit(self, *args, **kwargs) -> Future: if self.pool is None: @@ -50,9 +58,10 @@ def log_stats(self): LOGGER.debug("Pool is not yet started") return LOGGER.debug( - "Pool is currently having :\n - %s pending items\n - %s workers", + "Pool is currently having :\n - %s pending items\n - %s workers\n - %s cancellable tasks", len(self.pool._pending_work_items), len(self.pool._processes), + self.nb_cancellable, ) @@ -83,18 +92,6 @@ def shutdown_pool(): POOL.shutdown(wait=True, cancel_futures=True) -def call_in_pool(fn, *args, **kwargs): - """Submit the function call with args and kwargs inside the process pool - - Args: - fn (function): the function to call - - Returns: - Future: the promise of the results - """ - return POOL.submit(fn, *args, **kwargs) - - def pooled(fn): """Decorator to make a function be called inside the pool. @@ -109,6 +106,64 @@ def wrapper(*args, **kwargs): return wrapper +NB_CANCELLABLE_WORKER_LOCK = Lock() + + +@threaded +def start_killer(timeout: float, future: Future, pid: int, executor: ProcessPoolExecutor): + start = time() + # Try to get the result in proper time + while (time() - start) < timeout: + # future.result(timeout=timeout) => Not working with WSL and python 3.10, switchin to something safer + LOGGER.debug("Sleeping for pid %s", pid) + sleep(1) + if future.done(): + executor.shutdown(wait=True, cancel_futures=False) + with NB_CANCELLABLE_WORKER_LOCK: + POOL.nb_cancellable -= 1 + return + LOGGER.warning("Thread gonna kill pid %s", pid) + # Manually setting exception, to allow customisation + # TODO(Bazire): See if we need a custom error to handle that properly + try: + future.set_exception(CancelledError("Background task was taking too much time and was cancelled")) + except InvalidStateError: + pass + # Shutting down an executor is actually not stopping the running processes + executor.shutdown(wait=False, cancel_futures=False) + # Kill the process running by targeting its pid + os.kill(pid, signal.SIGINT) + # Let's clean up the executor also + # Also, does not matter to call shutdown several times + executor.shutdown(wait=True, cancel_futures=False) + with NB_CANCELLABLE_WORKER_LOCK: + POOL.nb_cancellable -= 1 + LOGGER.debug("Executor has been successfully shutdown") + log_pool_stats() + + +def call_in_pool(fn, *args, timeout=None, **kwargs): + """Submit the function call with args and kwargs inside the process pool + + Args: + fn (function): the function to call + + Returns: + Future: the promise of the results + """ + if timeout is None: + return POOL.submit(fn, *args, **kwargs) + # Create independant process pool + # If we kill a running process, it breaking the Process pool, making it unusable + one_shot_executor = ProcessPoolExecutor(max_workers=1) + pid = one_shot_executor.submit(os.getpid).result(timeout=5) + future = one_shot_executor.submit(fn, *args, **kwargs) + start_killer(timeout, future, pid, one_shot_executor) + with NB_CANCELLABLE_WORKER_LOCK: + POOL.nb_cancellable += 1 + return future + + def fullname(o): klass = o.__class__ module = klass.__module__