Skip to content

Commit eb5f3f9

Browse files
committed
Add Orion Extension concept [OC-343]
1 parent da00153 commit eb5f3f9

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

src/orion/ext/__init__.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Defines extension mechanism for third party to hook into Orion"""
2+
3+
4+
class EventDelegate:
5+
def __init__(self, name, deferred=False, parent=None) -> None:
6+
self.handlers = []
7+
self.deferred_calls = []
8+
self.name = name
9+
self.parent = parent
10+
self.deferred = deferred
11+
event_manager.add(name, self)
12+
self.bad_handlers = []
13+
14+
def remove(self, function) -> bool:
15+
try:
16+
self.handlers.remove(function)
17+
return True
18+
except ValueError:
19+
return False
20+
21+
def add(self, function):
22+
self.handlers.append(function)
23+
24+
def broadcast(self, *args, **kwargs):
25+
if not self.deferred:
26+
self._execute(args, kwargs)
27+
return
28+
29+
self.deferred_calls.append((args, kwargs))
30+
31+
def _execute(self, args, kwargs):
32+
for fun in self.handlers:
33+
try:
34+
fun(*args, _parent=self.parent, **kwargs)
35+
except Exception as err:
36+
event_manager.broadcast(self.name, fun, err, args=(args, kwargs))
37+
38+
def execute(self):
39+
self.bad_handlers = []
40+
41+
for args, kwargs in self.deferred_calls:
42+
self._execute(args, kwargs)
43+
44+
45+
46+
class OrionExtensionManager:
47+
"""Manages third party extensions for Orion"""
48+
49+
def __init__(self):
50+
self._events = {}
51+
52+
self._get_event('error')
53+
self._get_event('start_experiment')
54+
self._get_event('new_trial')
55+
self._get_event('end_trial')
56+
self._get_event('end_experiment')
57+
58+
59+
def _get_event(self, key):
60+
delegate = self._events.get(key)
61+
62+
if delegate is None:
63+
delegate = EventDelegate(key)
64+
self._events[key] = delegate
65+
66+
return delegate
67+
68+
def register(self, ext):
69+
"""Register a new extensions"""
70+
for name, delegate in self._events.items():
71+
if hasattr(ext, name):
72+
delegate.add(getattr(ext, name))
73+
74+
def unregister(self, ext):
75+
"""Remove an extensions if it was already registered"""
76+
for name, delegate in self._events.items():
77+
if hasattr(ext, name):
78+
delegate.remove(getattr(ext, name))
79+
80+
81+
class OrionExtension:
82+
"""Base orion extension interface you need to implement"""
83+
84+
def error(self, *args, **kwargs):
85+
return
86+
87+
def start_experiment(self, *args, **kwargs):
88+
return
89+
90+
def new_trial(self, *args, **kwargs):
91+
return
92+
93+
def end_experiment(self, *args, **kwargs):
94+
return
95+

0 commit comments

Comments
 (0)