Merged
2 changes: 1 addition & 1 deletion giskard/commands/cli_worker.py
@@ -109,7 +109,7 @@ def start_command(url: AnyHttpUrl, is_server, api_key, is_daemon, hf_token, nb_w
)
api_key = initialize_api_key(api_key, is_server)
hf_token = initialize_hf_token(hf_token, is_server)
_start_command(is_server, url, api_key, is_daemon, hf_token, nb_workers)
_start_command(is_server, url, api_key, is_daemon, hf_token, int(nb_workers) if nb_workers is not None else None)


def initialize_api_key(api_key, is_server):
18 changes: 11 additions & 7 deletions giskard/ml_worker/websocket/listener.py
@@ -9,7 +9,7 @@
import tempfile
import time
import traceback
from concurrent.futures import Future
from concurrent.futures import CancelledError, Future
from dataclasses import dataclass
from pathlib import Path

@@ -69,7 +69,7 @@ def __init__(self, worker: MLWorker):
self.is_remote = worker.is_remote_worker()


def websocket_log_actor(ml_worker: MLWorkerInfo, req: dict, *args, **kwargs):
def websocket_log_actor(ml_worker: MLWorkerInfo, req: Dict, *args, **kwargs):
param = req["param"] if "param" in req.keys() else {}
action = req["action"] if "action" in req.keys() else ""
logger.info(f"ML Worker {ml_worker.id} performing {action} params: {param}")
@@ -86,6 +86,9 @@ def handle_result(future: Union[Future, Callable[..., websocket.WorkerReply]]):

try:
info: websocket.WorkerReply = future.result() if isinstance(future, Future) else future()
except CancelledError:
info: websocket.WorkerReply = websocket.Empty()
logger.warning("Task for %s has timed out and been cancelled", action.name)
except Exception as e:
info: websocket.WorkerReply = websocket.ErrorReply(
error_str=str(e), error_type=type(e).__name__, detail=traceback.format_exc()
@@ -168,7 +171,7 @@ def parse_and_execute(
)


def dispatch_action(callback, ml_worker, action, req, execute_in_pool):
def dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout=None):
# Parse the response ID
rep_id = req["id"] if "id" in req.keys() else None
# Parse the param
@@ -199,14 +202,15 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool):
result_handler = wrapped_handle_result(action, ml_worker, start, rep_id)
# If execution should be done in a pool
if execute_in_pool:
logger.debug("Submitting for action %s '%s' into the pool with %s", action.name, callback.__name__, params)
logger.debug("Submitting for action %s '%s' into the pool", action.name, callback.__name__)
future = call_in_pool(
parse_and_execute,
callback=callback,
action=action,
params=params,
ml_worker=MLWorkerInfo(ml_worker),
client_params=client_params,
timeout=timeout,
)
future.add_done_callback(result_handler)
log_pool_stats()
@@ -223,7 +227,7 @@ def dispatch_action(callback, ml_worker, action, req, execute_in_pool):
)


def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True):
def websocket_actor(action: MLWorkerAction, execute_in_pool: bool = True, timeout: Optional[float] = None):
"""
Register a function as an actor to an action from WebSocket connection
"""
@@ -234,7 +238,7 @@ def websocket_actor_callback(callback: callable):
logger.debug(f'Registered "{callback.__name__}" for ML Worker "{action.name}"')

def wrapped_callback(ml_worker: MLWorker, req: dict, *args, **kwargs):
dispatch_action(callback, ml_worker, action, req, execute_in_pool)
dispatch_action(callback, ml_worker, action, req, execute_in_pool, timeout)

WEBSOCKET_ACTORS[action.name] = wrapped_callback

@@ -657,7 +661,7 @@ def echo(params: websocket.EchoMsg, *args, **kwargs) -> websocket.EchoMsg:
return params


@websocket_actor(MLWorkerAction.getPush)
@websocket_actor(MLWorkerAction.getPush, timeout=30)
def get_push(
client: Optional[GiskardClient], params: websocket.GetPushParam, *args, **kwargs
) -> websocket.GetPushResponse:
1 change: 1 addition & 0 deletions giskard/settings.py
@@ -34,6 +34,7 @@ class Settings(BaseModel):
loglevel: str = "INFO"
cache_dir: str = "cache"
disable_analytics: bool = False
min_workers: int = 2

class Config:
env_prefix = "GSK_"
95 changes: 75 additions & 20 deletions giskard/utils/__init__.py
@@ -1,8 +1,12 @@
import logging
from asyncio import Future
from concurrent.futures import ProcessPoolExecutor
import os
import signal
from concurrent.futures import CancelledError, Future, InvalidStateError, ProcessPoolExecutor
from functools import wraps
from threading import Thread
from threading import Lock, Thread
from time import sleep, time

from giskard.settings import settings

LOGGER = logging.getLogger(__name__)

@@ -21,19 +25,23 @@ class WorkerPool:

def __init__(self):
self.pool = None
self.nb_cancellable = 0
self.max_workers = 0

def start(self, *args, **kwargs):
def start(self, max_workers: int = None):
if self.pool is not None:
return
LOGGER.info("Starting worker pool...")
self.pool = ProcessPoolExecutor(*args, **kwargs)
self.max_workers = max(max_workers, settings.min_workers) if max_workers is not None else os.cpu_count()
LOGGER.info("Starting worker pool with %s workers...", self.max_workers)
self.pool = ProcessPoolExecutor(max_workers=self.max_workers)
LOGGER.info("Pool is started")

def shutdown(self, wait=True, cancel_futures=False):
if self.pool is None:
return
self.pool.shutdown(wait=wait, cancel_futures=cancel_futures)
self.pool = None
with NB_CANCELLABLE_WORKER_LOCK:
self.nb_cancellable = 0

def submit(self, *args, **kwargs) -> Future:
if self.pool is None:
@@ -50,9 +58,10 @@ def log_stats(self):
LOGGER.debug("Pool is not yet started")
return
LOGGER.debug(
"Pool is currently having :\n - %s pending items\n - %s workers",
"Pool is currently having :\n - %s pending items\n - %s workers\n - %s cancellable tasks",
len(self.pool._pending_work_items),
len(self.pool._processes),
self.nb_cancellable,
)


@@ -83,18 +92,6 @@ def shutdown_pool():
POOL.shutdown(wait=True, cancel_futures=True)


def call_in_pool(fn, *args, **kwargs):
"""Submit the function call with args and kwargs inside the process pool

Args:
fn (function): the function to call

Returns:
Future: the promise of the results
"""
return POOL.submit(fn, *args, **kwargs)


def pooled(fn):
"""Decorator to make a function be called inside the pool.

@@ -109,6 +106,64 @@ def wrapper(*args, **kwargs):
return wrapper


NB_CANCELLABLE_WORKER_LOCK = Lock()


@threaded
def start_killer(timeout: float, future: Future, pid: int, executor: ProcessPoolExecutor):
start = time()
# Try to get the result in proper time
while (time() - start) < timeout:
# future.result(timeout=timeout) => Not working with WSL and python 3.10, switching to something safer
LOGGER.debug("Sleeping for pid %s", pid)
sleep(1)
if future.done():
executor.shutdown(wait=True, cancel_futures=False)
with NB_CANCELLABLE_WORKER_LOCK:
POOL.nb_cancellable -= 1
return
LOGGER.warning("Thread gonna kill pid %s", pid)
# Manually setting exception, to allow customisation
# TODO(Bazire): See if we need a custom error to handle that properly
try:
future.set_exception(CancelledError("Background task was taking too much time and was cancelled"))
except InvalidStateError:
pass
# Shutting down an executor does not actually stop the running processes
executor.shutdown(wait=False, cancel_futures=False)
# Kill the process running by targeting its pid
os.kill(pid, signal.SIGINT)
# Let's clean up the executor also
# Also, does not matter to call shutdown several times
executor.shutdown(wait=True, cancel_futures=False)
with NB_CANCELLABLE_WORKER_LOCK:
POOL.nb_cancellable -= 1
LOGGER.debug("Executor has been successfully shutdown")
log_pool_stats()


def call_in_pool(fn, *args, timeout=None, **kwargs):
"""Submit the function call with args and kwargs inside the process pool

Args:
fn (function): the function to call

Returns:
Future: the promise of the results
"""
if timeout is None:
return POOL.submit(fn, *args, **kwargs)
# Create an independent process pool
# If we kill a running process, it breaks the process pool, making it unusable
one_shot_executor = ProcessPoolExecutor(max_workers=1)
pid = one_shot_executor.submit(os.getpid).result(timeout=5)
[Review thread on this line]

Contributor:
I don't get it. The comment above says ".result(timeout=timeout) => Not working with WSL and python 3.10", yet here it is expected to work?

Member Author (@Hartorn):
I should remove the timeout here anyway; there is no reason for a basic os.getpid call to fail or take a long time.

Basically, when I tested it, the timeout in the killer thread was not respected and it never stopped, so I had to do the loop. By the way, the timeout was working on macOS with Python 3.11.

Contributor:
I think the overall killing mechanism could be simplified to avoid having a killer thread and a cancellable counter.

We could use an inter-process data structure to communicate pids between a spawned worker process and the main process: multiprocessing.Queue, for example. In that case, the first thing the pool process would do is identify its pid and add it to the queue, then do the actual work. In the main process we could use concurrent.futures._base.as_completed to wrap the future with a timeout, and then call .add_done_callback(result_handler) as we do now. When the timeout expires, we would look up the PID associated with the task and kill it from the main process.

WDYT?
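
A minimal sketch of this Queue-based idea (illustrative only, not the code merged in this PR; the helper names are made up, the public concurrent.futures.as_completed is used in place of the _base path, and POSIX signals are assumed):

import os
import signal
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from concurrent.futures import TimeoutError as FutureTimeout
from multiprocessing import Manager


def run_with_pid(pid_queue, fn, *args, **kwargs):
    # First thing the worker does: announce its pid, then do the actual work
    pid_queue.put(os.getpid())
    return fn(*args, **kwargs)


def submit_with_timeout(executor, pid_queue, fn, timeout, *args, **kwargs):
    future = executor.submit(run_with_pid, pid_queue, fn, *args, **kwargs)
    pid = pid_queue.get()  # blocks until a worker has picked up the task
    try:
        next(as_completed([future], timeout=timeout))
    except FutureTimeout:
        # Timeout expired: kill the worker by pid, from the main process
        os.kill(pid, signal.SIGINT)
    return future


def slow_task():
    time.sleep(60)


if __name__ == "__main__":
    with Manager() as manager, ProcessPoolExecutor(max_workers=1) as executor:
        submit_with_timeout(executor, manager.Queue(), slow_task, timeout=1.0)

As the author notes in the reply below, killing a worker can still leave the executor in a broken state, which is part of why the merged code runs timed-out tasks in a dedicated one-shot executor instead.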

Contributor:
@Hartorn, actually, it looks like https://pypi.org/project/Pebble/ does exactly what we need:

from pebble import ProcessPool

pool = ProcessPool(max_workers=5)

pool.schedule(long_fn, args=(a, b, c), timeout=1).add_done_callback(done_callback)

Member Author (@Hartorn), Oct 11, 2023:
(Quoting the multiprocessing.Queue suggestion above.)

The cancellable counter is only for logs; we can remove it if we want. I just wanted to check whether it was working, and possibly to avoid launching too many processes.

For the callback of a future, you need everything to be picklable, so you cannot pass a lock, a process pool, and so on. So I don't know how to properly shut down the one-shot executor we started.

Also, it cannot be the main executor, since killing a process breaks the pool, which only raises BrokenProcessPool after that.

I'm pretty sure we could avoid the thread if the whole code we were running were async, since a coroutine could do this job, but here I would not be confident.
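
To illustrate the point about breaking the pool, here is a small standalone sketch (not part of the PR; POSIX-only, and SIGKILL is used instead of the PR's SIGINT so the worker dies for sure): once a worker process is killed, the executor marks itself broken, the pending future fails, and every later submit raises BrokenProcessPool.

import os
import signal
import time
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool


def slow_task():
    time.sleep(60)


if __name__ == "__main__":
    pool = ProcessPoolExecutor(max_workers=1)
    worker_pid = pool.submit(os.getpid).result()  # the single worker is now alive
    pending = pool.submit(slow_task)              # long-running task occupies the worker
    time.sleep(1)                                 # give the worker time to pick it up
    os.kill(worker_pid, signal.SIGKILL)           # kill the worker out from under the pool
    try:
        pending.result()
    except BrokenProcessPool as exc:
        print("pending future failed:", exc)
    try:
        pool.submit(os.getpid)
    except BrokenProcessPool as exc:
        print("pool is unusable afterwards:", exc)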

Member Author (@Hartorn):
(Quoting the Pebble suggestion above.) Looking at it!

Member Author (@Hartorn):
From what I saw, they do roughly the same thing: when using a separate process, they have a handler thread watching it and handling the timeout.

Should we get this merged and switch to it later, or do you want me to switch now? Although I'm a bit concerned that it's not that widely used, the code looks clean.

Contributor:
Actually, while it has a clean API and is easy to use, Pebble is an LGPL library, so we won't be able to use it (that might be the reason it's not more widely adopted).

I also read their code and found similarities with your implementation. I suggest we stick with your current code (and merge it, since we need these changes ASAP).

As an improvement, in a separate PR we could take inspiration from Pebble's API and encapsulate call_in_pool into our own ProcessPool implementation with a schedule method. WDYT? In the end, Pebble is only a ~5k LOC library.
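
For reference, a rough sketch of what that follow-up API could look like (hypothetical: the schedule function and its signature are made up here to mirror Pebble, and it simply delegates to the call_in_pool introduced in this PR):

from giskard.utils import call_in_pool  # helper changed in this PR


def schedule(fn, args=(), kwargs=None, timeout=None):
    # Pebble-style entry point: submit fn to the shared worker pool, with an optional timeout
    return call_in_pool(fn, *args, timeout=timeout, **(kwargs or {}))


# Usage, mirroring the Pebble example above (long_fn and done_callback are placeholders,
# and the shared worker pool is assumed to have been started already):
# schedule(long_fn, args=(a, b, c), timeout=1).add_done_callback(done_callback)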

Member Author (@Hartorn):
Shame, the code was working well with it too. Sure, having our own worker implementation seems like a good idea; I'm sure it will also help us improve the performance of the library.

[End of review thread]

future = one_shot_executor.submit(fn, *args, **kwargs)
start_killer(timeout, future, pid, one_shot_executor)
with NB_CANCELLABLE_WORKER_LOCK:
POOL.nb_cancellable += 1
return future


def fullname(o):
klass = o.__class__
module = klass.__module__