Skip to content

Commit f3634c6

Browse files
committed
Add Orion Extension concept [OC-343]
1 parent da00153 commit f3634c6

File tree

3 files changed

+413
-48
lines changed

3 files changed

+413
-48
lines changed

src/orion/client/experiment.py

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from orion.executor.base import Executor
2929
from orion.plotting.base import PlotAccessor
3030
from orion.storage.base import FailedUpdate
31+
from orion.ext.extensions import OrionExtensionManager
3132

3233
log = logging.getLogger(__name__)
3334

@@ -87,6 +88,7 @@ def __init__(self, experiment, producer, executor=None, heartbeat=None):
8788
**orion.core.config.worker.executor_configuration,
8889
)
8990
self.plot = PlotAccessor(self)
91+
self.extensions = OrionExtensionManager()
9092

9193
###
9294
# Attributes
@@ -753,66 +755,81 @@ def workon(
753755
self._experiment.max_trials = max_trials
754756
self._experiment.algorithms.algorithm.max_trials = max_trials
755757

756-
trials = self.executor.wait(
757-
self.executor.submit(
758-
self._optimize,
759-
fct,
760-
pool_size,
761-
max_trials_per_worker,
762-
max_broken,
763-
trial_arg,
764-
on_error,
765-
**kwargs,
758+
with self.extensions.experiment(self._experiment):
759+
trials = self.executor.wait(
760+
self.executor.submit(
761+
self._optimize,
762+
fct,
763+
pool_size,
764+
max_trials_per_worker,
765+
max_broken,
766+
trial_arg,
767+
on_error,
768+
**kwargs,
769+
)
770+
for _ in range(n_workers)
766771
)
767-
for _ in range(n_workers)
768-
)
769772

770773
return sum(trials)
771774

775+
def _optimize_trial(
    self, fct, trial, trial_arg, kwargs, worker_broken_trials, max_broken, on_error
):
    """Run ``fct`` on a single trial and record the outcome.

    Parameters
    ----------
    fct: callable
        Function to execute; it is called with the unflattened ``kwargs``.
    trial
        Trial whose ``params`` are flattened and merged into ``kwargs``.
    trial_arg: str or None
        When truthy, keyword under which ``trial`` itself is passed to ``fct``.
    kwargs: dict
        Flattened keyword arguments. Updated **in place** with the trial
        parameters, so the caller's dictionary is modified.
    worker_broken_trials: int
        Number of trials this worker has counted as broken so far.
    max_broken: int
        Threshold of broken trials at which ``BrokenExperiment`` is raised.
    on_error: callable or None
        Called as ``on_error(self, trial, exception, worker_broken_trials)``.
        The failure is counted as broken when ``on_error`` is ``None`` or
        returns a truthy value.

    Returns
    -------
    int
        The updated number of broken trials. Integers are passed by value, so
        the caller must rebind its own counter to this result; incrementing
        the argument here is not visible to the caller otherwise.

    Raises
    ------
    KeyboardInterrupt, InvalidResult
        Re-raised untouched.
    BrokenExperiment
        When the broken trial count reaches ``max_broken``.
    """
    kwargs.update(flatten(trial.params))

    if trial_arg:
        kwargs[trial_arg] = trial

    try:
        # The extension context broadcasts the trial start/end/error events
        # around the execution of the user function.
        with self.extensions.trial(trial):
            results = self.executor.wait(
                [self.executor.submit(fct, **unflatten(kwargs))]
            )[0]
            self.observe(trial, results=results)
    except (KeyboardInterrupt, InvalidResult):
        raise
    except BaseException as e:
        if on_error is None or on_error(self, trial, e, worker_broken_trials):
            log.error(traceback.format_exc())
            worker_broken_trials += 1
        else:
            log.error(str(e))
            log.debug(traceback.format_exc())

        if worker_broken_trials >= max_broken:
            raise BrokenExperiment("Worker has reached broken trials threshold")
        else:
            self.release(trial, status="broken")

    return worker_broken_trials


def _optimize(
    self, fct, pool_size, max_trials, max_broken, trial_arg, on_error, **kwargs
):
    """Worker loop: suggest trials and execute ``fct`` on each of them.

    Parameters
    ----------
    fct: callable
        Function executed for every suggested trial.
    pool_size: int
        Forwarded to ``self.suggest``.
    max_trials: int
        Upper bound of non-broken trials for this worker; capped by
        ``self.max_trials``.
    max_broken: int
        Number of broken trials after which ``BrokenExperiment`` is raised.
    trial_arg: str or None
        Keyword under which the trial object is passed to ``fct``, if any.
    on_error: callable or None
        Error callback forwarded to ``_optimize_trial``.
    **kwargs
        Extra keyword arguments for ``fct``; flattened once and then shared
        (and updated in place) across all trials.

    Returns
    -------
    int
        Number of loop iterations that completed, broken trials included.
    """
    worker_broken_trials = 0
    trials = 0
    kwargs = flatten(kwargs)
    max_trials = min(max_trials, self.max_trials)

    while not self.is_done and trials - worker_broken_trials < max_trials:
        try:
            with self.suggest(pool_size=pool_size) as trial:
                # Rebind the counter: the helper cannot mutate our local int,
                # and both the loop condition and the ``max_broken`` threshold
                # depend on it being up to date.
                worker_broken_trials = self._optimize_trial(
                    fct,
                    trial,
                    trial_arg,
                    kwargs,
                    worker_broken_trials,
                    max_broken,
                    on_error,
                )
        except CompletedExperiment as e:
            log.warning(e)
            break

        trials += 1

    return trials
818835

src/orion/ext/extensions.py

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""Defines extension mechanism for third party to hook into Orion"""
2+
3+
4+
class EventDelegate:
    """Allow extensions to listen to incoming events from Orion.

    Orion broadcasts events which trigger the callbacks registered by extensions.

    Parameters
    ----------
    name: str
        name of the event we are creating, this is useful for error reporting

    deferred: bool
        if false, handlers are called as soon as ``broadcast`` is called;
        if true, broadcasts are queued until ``execute`` is called manually
    """

    def __init__(self, name, deferred=False) -> None:
        self.handlers = []
        # Queue of ``(args, kwargs)`` pairs waiting for ``execute`` (deferred mode only)
        self.deferred_calls = []
        self.name = name
        self.deferred = deferred
        # Not used by this class itself; kept for backward compatibility
        self.bad_handlers = []
        # Object exposing an ``on_extension_error`` delegate; set by the owner
        self.manager = None

    def remove(self, function) -> bool:
        """Remove an event handler; return ``True`` if it was registered"""
        try:
            self.handlers.remove(function)
        except ValueError:
            return False
        return True

    def add(self, function):
        """Add an event handler to our handler list"""
        self.handlers.append(function)

    def broadcast(self, *args, **kwargs):
        """Broadcast an event to all our handlers (or queue it when deferred)"""
        if self.deferred:
            self.deferred_calls.append((args, kwargs))
            return

        self._execute(args, kwargs)

    def _execute(self, args, kwargs):
        """Call every handler, isolating the failure of any one of them"""
        # Iterate over a snapshot so a handler may add/remove handlers safely
        for fun in list(self.handlers):
            try:
                fun(*args, **kwargs)
            except Exception as err:
                self._report_failure(fun, err, args, kwargs)

    def _report_failure(self, fun, err, args, kwargs):
        """Forward a handler failure to ``on_extension_error``, else log it"""
        import logging

        manager = self.manager
        error_event = getattr(manager, "on_extension_error", None) if manager else None

        # Never re-broadcast a failure of the error event onto itself: the
        # failing handler would be called again and recurse without bound.
        if error_event is not None and error_event is not self:
            error_event.broadcast(self.name, fun, err, args=(args, kwargs))
            return

        logging.getLogger(__name__).error(
            "Extension handler %r failed while handling event %r",
            fun,
            self.name,
            exc_info=err,
        )

    def execute(self):
        """Execute all our deferred calls, if any, and empty the queue"""
        # Swap the queue out first: each deferred call runs exactly once, and
        # broadcasts made by handlers during the flush wait for the next one.
        pending, self.deferred_calls = self.deferred_calls, []
        for args, kwargs in pending:
            self._execute(args, kwargs)
57+
58+
59+
class _DelegateStartEnd:
    """Context manager wrapping a block of work with start/end/error events.

    On entry the ``start`` event is broadcast. On exit the ``end`` event is
    always broadcast; if the block raised, the ``error`` event is broadcast
    afterwards with the exception type, value and traceback appended to the
    positional arguments. The exception itself is never suppressed.
    """

    def __init__(self, start, error, end, *args, **kwargs):
        self.start, self.error, self.end = start, error, end
        # Arguments forwarded to every broadcast made by this context
        self.args = args
        self.kwargs = kwargs

    def __enter__(self):
        self.start.broadcast(*self.args, **self.kwargs)
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.end.broadcast(*self.args, **self.kwargs)

        if exception_value is None:
            return

        failure = (exception_type, exception_value, exception_traceback)
        self.error.broadcast(*self.args, *failure, **self.kwargs)
82+
83+
84+
class OrionExtensionManager:
    """Manages third party extensions for Orion

    Every supported event is an :class:`EventDelegate` stored in ``_events``
    and exposed as an attribute of the same name (e.g. ``manager.new_trial``).
    """

    def __init__(self):
        self._events = {}
        self._get_event('on_extension_error')

        # -- Trials
        self._get_event('new_trial')
        self._get_event('on_trial_error')
        self._get_event('end_trial')

        # -- Experiments
        self._get_event('start_experiment')
        self._get_event('on_experiment_error')
        self._get_event('end_experiment')

    def experiment(self, *args, **kwargs):
        """Initialize a context manager that will call start/error/end events automatically"""
        return _DelegateStartEnd(
            self.start_experiment,
            self.on_experiment_error,
            self.end_experiment,
            *args,
            **kwargs
        )

    def trial(self, *args, **kwargs):
        """Initialize a context manager that will call start/error/end events automatically"""
        return _DelegateStartEnd(
            self.new_trial,
            self.on_trial_error,
            self.end_trial,
            *args,
            **kwargs
        )

    def __getattr__(self, name):
        """Expose registered events as attributes

        Raises
        ------
        AttributeError
            if ``name`` is not a registered event, so that ``hasattr`` and
            misspelled event names behave as for any other object
        """
        # ``__getattr__`` only runs when normal lookup failed. Read ``_events``
        # through ``__dict__`` so an instance without it (created without
        # ``__init__``) does not re-enter ``__getattr__`` endlessly.
        events = self.__dict__.get('_events', {})
        try:
            return events[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def _get_event(self, key):
        """Retrieve or generate a new event delegate"""
        delegate = self._events.get(key)

        if delegate is None:
            delegate = EventDelegate(key)
            # Lets the delegate report handler failures to ``on_extension_error``
            delegate.manager = self
            self._events[key] = delegate

        return delegate

    def register(self, ext):
        """Register a new extension

        Parameters
        ----------
        ext
            object implementing :class:`OrionExtension` methods

        Returns
        -------
        the number of callbacks that were registered
        """
        registered_callbacks = 0
        for name, delegate in self._events.items():
            if hasattr(ext, name):
                delegate.add(getattr(ext, name))
                registered_callbacks += 1

        return registered_callbacks

    def unregister(self, ext):
        """Remove an extension if it was already registered

        Returns
        -------
        the number of callbacks that were actually removed
        """
        unregistered_callbacks = 0
        for name, delegate in self._events.items():
            # Only count handlers the delegate really held and removed
            if hasattr(ext, name) and delegate.remove(getattr(ext, name)):
                unregistered_callbacks += 1

        return unregistered_callbacks
165+
166+
167+
class OrionExtension:
    """Base Orion extension interface.

    Every hook is a no-op returning ``None``; subclasses override the ones
    they are interested in.
    """

    # -- Extensions

    def on_extension_error(self, name, fun, exception, args):
        """Called when an extension callback raises an exception

        Parameters
        ----------
        name: str
            name of the event whose handler failed

        fun: callable
            handler that raised the error

        exception:
            raised exception

        args: tuple
            ``(args, kwargs)`` pair the handler was called with
        """

    # -- Experiments

    def start_experiment(self, experiment):
        """Called at the beginning of the optimization process, before the worker starts"""

    def end_experiment(self, experiment):
        """Called at the end of the optimization process, after the worker exits"""

    def on_experiment_error(self, experiment, exception_type, exception_value, exception_traceback):
        """Called when an error occurs during the optimization process"""

    # -- Trials

    def new_trial(self, trial):
        """Called when the trial starts with a new configuration"""

    def end_trial(self, trial):
        """Called when the trial has finished"""

    def on_trial_error(self, trial, exception_type, exception_value, exception_traceback):
        """Called when an error occurs while a trial is running"""
209+

0 commit comments

Comments
 (0)