diff --git a/.gitignore b/.gitignore
index ceb8008f2..d54b7612c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -80,3 +80,5 @@ target/
 
 # Notebooks
 tests/**.ipynb
+dask-worker-space/
+tests/functional/commands/*.json
diff --git a/docs/src/code/client.rst b/docs/src/code/client.rst
index 82dc287ab..d3d960472 100644
--- a/docs/src/code/client.rst
+++ b/docs/src/code/client.rst
@@ -9,6 +9,7 @@ Client helper functions
    client/cli
    client/experiment
    client/manual
+   client/extensions
 
 .. automodule:: orion.client
    :members:
diff --git a/docs/src/code/client/extensions.rst b/docs/src/code/client/extensions.rst
new file mode 100644
index 000000000..140887af7
--- /dev/null
+++ b/docs/src/code/client/extensions.rst
@@ -0,0 +1,7 @@
+Extensions
+==========
+
+.. automodule:: orion.ext.extensions
+   :members:
+
+
diff --git a/setup.py b/setup.py
index 07f67bcaf..819d7f99b 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@
     "orion.client",
     "orion.core",
     "orion.executor",
+    "orion.ext",
     "orion.plotting",
     "orion.serving",
     "orion.storage",
diff --git a/src/orion/client/experiment.py b/src/orion/client/experiment.py
index 8e2180d11..d242751ae 100644
--- a/src/orion/client/experiment.py
+++ b/src/orion/client/experiment.py
@@ -26,6 +26,7 @@
 from orion.core.worker.trial import Trial, TrialCM
 from orion.core.worker.trial_pacemaker import TrialPacemaker
 from orion.executor.base import Executor
+from orion.ext.extensions import OrionExtensionManager
 from orion.plotting.base import PlotAccessor
 from orion.storage.base import FailedUpdate
 
@@ -72,6 +73,11 @@ class ExperimentClient:
     producer: `orion.core.worker.producer.Producer`
         Producer object used to produce new trials.
 
+    Notes
+    -----
+
+    Users can write generic extensions to ExperimentClient through
+    `orion.ext.extensions.OrionExtension`.
     """
 
     def __init__(self, experiment, producer, executor=None, heartbeat=None):
@@ -87,6 +93,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None):
                 **orion.core.config.worker.executor_configuration,
             )
         self.plot = PlotAccessor(self)
+        self.extensions = OrionExtensionManager()
 
     ###
     # Attributes
@@ -320,6 +327,16 @@ def fetch_noncompleted_trials(self, with_evc_tree=False):
     ###
     # Actions
     ###
+    def register_extension(self, ext):
+        """Register a third party extension
+
+        Parameters
+        ----------
+        ext: OrionExtension
+            object that implements the OrionExtension interface
+
+        """
+        return self.extensions.register(ext)
 
     # pylint: disable=unused-argument
     def insert(self, params, results=None, reserve=False):
@@ -753,19 +770,20 @@ def workon(
         self._experiment.max_trials = max_trials
         self._experiment.algorithms.algorithm.max_trials = max_trials
 
-        trials = self.executor.wait(
-            self.executor.submit(
-                self._optimize,
-                fct,
-                pool_size,
-                max_trials_per_worker,
-                max_broken,
-                trial_arg,
-                on_error,
-                **kwargs,
+        with self.extensions.experiment(self._experiment):
+            trials = self.executor.wait(
+                self.executor.submit(
+                    self._optimize,
+                    fct,
+                    pool_size,
+                    max_trials_per_worker,
+                    max_broken,
+                    trial_arg,
+                    on_error,
+                    **kwargs,
+                )
+                for _ in range(n_workers)
             )
-            for _ in range(n_workers)
-        )
 
         return sum(trials)
 
@@ -776,6 +794,7 @@ def _optimize(
         trials = 0
         kwargs = flatten(kwargs)
         max_trials = min(max_trials, self.max_trials)
+
         while not self.is_done and trials - worker_broken_trials < max_trials:
             try:
                 with self.suggest(pool_size=pool_size) as trial:
@@ -786,10 +805,11 @@ def _optimize(
                         kwargs[trial_arg] = trial
 
                     try:
-                        results = self.executor.wait(
-                            [self.executor.submit(fct, **unflatten(kwargs))]
-                        )[0]
-                        self.observe(trial, results=results)
+                        with self.extensions.trial(trial):
+                            results = self.executor.wait(
+                                [self.executor.submit(fct, **unflatten(kwargs))]
+                            )[0]
+                            self.observe(trial, results=results)
                     except (KeyboardInterrupt, InvalidResult):
                         raise
                     except BaseException as e:
@@ -808,6 +828,7 @@ def _optimize(
                             )
                         else:
                             self.release(trial, status="broken")
+
             except CompletedExperiment as e:
                 log.warning(e)
                 break
diff --git a/src/orion/ext/extensions.py b/src/orion/ext/extensions.py
new file mode 100644
index 000000000..21bdb7dbf
--- /dev/null
+++ b/src/orion/ext/extensions.py
@@ -0,0 +1,256 @@
+"""Defines extension mechanism for third party to hook into Orion"""
+import logging
+
+log = logging.getLogger(__name__)
+
+
+class EventDelegate:
+    """Allow extensions to listen to incoming events from Orion.
+
+    Orion broadcasts events which trigger extensions callbacks.
+
+    Parameters
+    ----------
+    name: str
+        name of the event we are creating, this is useful for error reporting
+
+    deferred: bool
+        If false, events are triggered as soon as broadcast is called.
+        If true, the events are queued until :meth:`execute` is called.
+    """
+
+    def __init__(self, name, deferred=False) -> None:
+        self.handlers = []
+        self.deferred_calls = []
+        self.name = name
+        self.deferred = deferred
+        self.bad_handlers = []
+        self.manager = None
+
+    def remove(self, function) -> bool:
+        """Remove an event handler, return True if it was registered"""
+        try:
+            self.handlers.remove(function)
+            return True
+        except ValueError:
+            return False
+
+    def add(self, function):
+        """Add an event handler to our handler list"""
+        self.handlers.append(function)
+
+    def broadcast(self, *args, **kwargs):
+        """Broadcast an event to all our handlers"""
+        if not self.deferred:
+            self._execute(args, kwargs)
+            return
+
+        self.deferred_calls.append((args, kwargs))
+
+    def _execute(self, args, kwargs):
+        """Call every handler, reporting failures instead of raising them"""
+        for fun in self.handlers:
+            try:
+                fun(*args, **kwargs)
+            except Exception as err:  # pylint: disable=broad-except
+                self._report(fun, err, args, kwargs)
+
+    def _report(self, fun, err, args, kwargs):
+        """Forward a handler failure to the ``on_extension_error`` listeners"""
+        reporter = self.manager.on_extension_error if self.manager else None
+
+        # A failing ``on_extension_error`` handler must not be reported to
+        # itself, that would recurse until the interpreter gives up.
+        if reporter is None or reporter is self:
+            log.error(
+                "Extension callback %r failed on event %s: %r", fun, self.name, err
+            )
+            return
+
+        reporter.broadcast(self.name, fun, err, args=(args, kwargs))
+
+    def execute(self):
+        """Execute all our deferred handlers if any"""
+        # Swap the queue out first so every deferred call fires exactly once
+        pending, self.deferred_calls = self.deferred_calls, []
+        for args, kwargs in pending:
+            self._execute(args, kwargs)
+
+
+class _DelegateStartEnd:
+    """Context manager broadcasting start, end and error events.
+
+    ``start`` is broadcast on entry. ``end`` is always broadcast on exit,
+    followed by ``error`` when the body raised, in which case the exception
+    details are appended to the positional arguments. The exception is
+    never suppressed.
+    """
+
+    def __init__(self, start, error, end, *args, **kwargs):
+        self.args = args
+        self.kwargs = kwargs
+        self.start = start
+        self.end = end
+        self.error = error
+
+    def __enter__(self):
+        self.start.broadcast(*self.args, **self.kwargs)
+        return self
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        self.end.broadcast(*self.args, **self.kwargs)
+
+        if exception_value is not None:
+            self.error.broadcast(
+                *self.args,
+                exception_type,
+                exception_value,
+                exception_traceback,
+                **self.kwargs
+            )
+
+
+class OrionExtensionManager:
+    """Manages third party extensions for Orion"""
+
+    def __init__(self):
+        self._events = {}
+        self._get_event("on_extension_error")
+
+        # -- Trials
+        self._get_event("new_trial")
+        self._get_event("on_trial_error")
+        self._get_event("end_trial")
+
+        # -- Experiments
+        self._get_event("start_experiment")
+        self._get_event("on_experiment_error")
+        self._get_event("end_experiment")
+
+    @property
+    def on_extension_error(self):
+        """Called when an extension is throwing an exception"""
+        return self._get_event("on_extension_error")
+
+    def experiment(self, *args, **kwargs):
+        """Initialize a context manager that will call start/error/end events automatically"""
+        return _DelegateStartEnd(
+            self._get_event("start_experiment"),
+            self._get_event("on_experiment_error"),
+            self._get_event("end_experiment"),
+            *args,
+            **kwargs
+        )
+
+    def trial(self, *args, **kwargs):
+        """Initialize a context manager that will call start/error/end events automatically"""
+        return _DelegateStartEnd(
+            self._get_event("new_trial"),
+            self._get_event("on_trial_error"),
+            self._get_event("end_trial"),
+            *args,
+            **kwargs
+        )
+
+    def broadcast(self, name, *args, **kwargs):
+        """Broadcast the event ``name``, creating its delegate if needed"""
+        return self._get_event(name).broadcast(*args, **kwargs)
+
+    def _get_event(self, key):
+        """Retrieve event delegate
+
+        Will generate one if not defined already.
+        """
+        delegate = self._events.get(key)
+
+        if delegate is None:
+            delegate = EventDelegate(key)
+            delegate.manager = self
+            self._events[key] = delegate
+
+        return delegate
+
+    def register(self, ext):
+        """Register a new extension
+
+        Parameters
+        ----------
+        ext: ``OrionExtension``
+            object implementing :class:`OrionExtension` methods
+
+        Returns
+        -------
+        the number of callbacks that were registered
+        """
+        registered_callbacks = 0
+        for name, delegate in self._events.items():
+            if hasattr(ext, name):
+                delegate.add(getattr(ext, name))
+                registered_callbacks += 1
+
+        return registered_callbacks
+
+    def unregister(self, ext):
+        """Remove an extension if it was already registered
+
+        Returns
+        -------
+        the number of callbacks that were actually removed
+        """
+        unregistered_callbacks = 0
+        for name, delegate in self._events.items():
+            if hasattr(ext, name) and delegate.remove(getattr(ext, name)):
+                unregistered_callbacks += 1
+
+        return unregistered_callbacks
+
+
+class OrionExtension:
+    """Base orion extension interface you need to implement"""
+
+    def on_extension_error(self, name, fun, exception, args):
+        """Called when an extension callback raises an exception
+
+        Parameters
+        ----------
+        name: str
+            name of the event whose handler failed
+
+        fun: callable
+            handler that raised the error
+
+        exception:
+            raised exception
+
+        args: tuple
+            ``(args, kwargs)`` the handler was called with
+        """
+        return
+
+    def on_trial_error(
+        self, trial, exception_type, exception_value, exception_traceback
+    ):
+        """Called when an error occurs while a trial is executed"""
+        return
+
+    def new_trial(self, trial):
+        """Called when the trial starts with a new configuration"""
+        return
+
+    def end_trial(self, trial):
+        """Called when the trial finished"""
+        return
+
+    def on_experiment_error(
+        self, experiment, exception_type, exception_value, exception_traceback
+    ):
+        """Called when an error occurs during the optimization process"""
+        return
+
+    def start_experiment(self, experiment):
+        """Called at the start of the optimization process before the worker starts"""
+        return
+
+    def end_experiment(self, experiment):
+        """Called at the end of the optimization process after the worker exits"""
+        return
diff --git a/tests/unittests/ext/test_extension.py b/tests/unittests/ext/test_extension.py
new file mode 100644
index 000000000..3f217990e
--- /dev/null
+++ b/tests/unittests/ext/test_extension.py
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""Example usage and tests for :mod:`orion.ext.extensions`."""
+from collections import defaultdict
+
+import pytest
+
+from orion.core.utils.exceptions import BrokenExperiment
+from orion.testing import create_experiment
+
+config = dict(
+    name="supernaekei",
+    space={"x": "uniform(0, 200)"},
+    metadata={
+        "user": "tsirif",
+        "orion_version": "XYZ",
+        "VCS": {
+            "type": "git",
+            "is_dirty": False,
+            "HEAD_sha": "test",
+            "active_branch": None,
+            "diff_sha": "diff",
+        },
+    },
+    version=1,
+    max_trials=10,
+    max_broken=5,
+    working_dir="",
+    algorithms={"random": {"seed": 1}},
+    producer={"strategy": "NoParallelStrategy"},
+    refers=dict(root_id="supernaekei", parent_id=None, adapter=[]),
+)
+
+base_trial = {
+    "experiment": 0,
+    "status": "new",  # new, reserved, suspended, completed, broken
+    "worker": None,
+    "start_time": None,
+    "end_time": None,
+    "heartbeat": None,
+    "results": [],
+    "params": [],
+}
+
+
+class OrionExtensionTest:
+    """Extension counting how many times each callback was invoked"""
+
+    def __init__(self) -> None:
+        self.calls = defaultdict(int)
+
+    def on_experiment_error(self, *args, **kwargs):
+        self.calls["on_experiment_error"] += 1
+
+    def on_trial_error(self, *args, **kwargs):
+        self.calls["on_trial_error"] += 1
+
+    def start_experiment(self, *args, **kwargs):
+        self.calls["start_experiment"] += 1
+
+    def new_trial(self, *args, **kwargs):
+        self.calls["new_trial"] += 1
+
+    def end_trial(self, *args, **kwargs):
+        self.calls["end_trial"] += 1
+
+    def end_experiment(self, *args, **kwargs):
+        self.calls["end_experiment"] += 1
+
+
+def test_client_extension():
+    ext = OrionExtensionTest()
+    with create_experiment(config, base_trial) as (cfg, experiment, client):
+        registered_callback = client.extensions.register(ext)
+        assert registered_callback == 6, "All ext callbacks got registered"
+
+        def foo(x):
+            if len(client.fetch_trials()) > 5:
+                raise RuntimeError()
+            return [dict(name="result", type="objective", value=x * 2)]
+
+        MAX_TRIALS = 10
+        MAX_BROKEN = 5
+        assert client.max_trials == MAX_TRIALS
+
+        with pytest.raises(BrokenExperiment):
+            client.workon(foo, max_trials=MAX_TRIALS, max_broken=MAX_BROKEN)
+
+        n_trials = len(experiment.fetch_trials_by_status("completed"))
+        n_broken = len(experiment.fetch_trials_by_status("broken"))
+        n_reserved = len(experiment.fetch_trials_by_status("reserved"))
+
+        assert (
+            ext.calls["new_trial"] == n_trials + n_broken - n_reserved
+        ), "all trials should have triggered callbacks"
+        assert (
+            ext.calls["end_trial"] == n_trials + n_broken - n_reserved
+        ), "all trials should have triggered callbacks"
+        assert (
+            ext.calls["on_trial_error"] == n_broken
+        ), "failed trial should be reported "
+
+        assert ext.calls["start_experiment"] == 1, "experiment should have started"
+        assert ext.calls["end_experiment"] == 1, "experiment should have ended"
+        assert ext.calls["on_experiment_error"] == 1, "failed experiment "
+
+        unregistered_callback = client.extensions.unregister(ext)
+        assert unregistered_callback == 6, "All ext callbacks got unregistered"
+
+
+class BadOrionExtensionTest:
+    """Extension whose ``new_trial`` callback always raises"""
+
+    def __init__(self) -> None:
+        self.calls = defaultdict(int)
+
+    def on_extension_error(self, name, fun, exception, args):
+        self.calls["on_extension_error"] += 1
+
+    def on_experiment_error(self, *args, **kwargs):
+        self.calls["on_experiment_error"] += 1
+
+    def on_trial_error(self, *args, **kwargs):
+        self.calls["on_trial_error"] += 1
+
+    def new_trial(self, *args, **kwargs):
+        raise RuntimeError()
+
+
+def test_client_bad_extension():
+    ext = BadOrionExtensionTest()
+    with create_experiment(config, base_trial) as (cfg, experiment, client):
+        registered_callback = client.extensions.register(ext)
+        assert registered_callback == 4, "All ext callbacks got registered"
+
+        def foo(x):
+            return [dict(name="result", type="objective", value=x * 2)]
+
+        MAX_TRIALS = 10
+        MAX_BROKEN = 5
+        assert client.max_trials == MAX_TRIALS
+        client.workon(foo, max_trials=MAX_TRIALS, max_broken=MAX_BROKEN)
+
+        assert ext.calls["on_trial_error"] == 0, "Orion worked as expected"
+        assert ext.calls["on_experiment_error"] == 0, "Orion worked as expected"
+        assert ext.calls["on_extension_error"] == 9, "Extension error got reported"
+
+        unregistered_callback = client.extensions.unregister(ext)
+        assert unregistered_callback == 4, "All ext callbacks got unregistered"