Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 124 additions & 78 deletions chimerapy/workerui/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from fastapi import FastAPI
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.exceptions import HTTPException
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket, WebSocketDisconnect, WebSocketState
from starlette.websockets import WebSocketState

from chimerapy.engine.worker import Worker
from chimerapy.workerui.models import WorkerConfig
from chimerapy.workerui.state_updater import StateUpdater
from chimerapy.workerui.utils import instantiate_worker
from chimerapy.workerui.worker_state_broadcaster import WorkerStateBroadcaster

STATIC_DIR = Path(__file__).parent / "build"


async def relay(
q: asyncio.Queue, ws: WebSocket, is_sentinel, signal: str = "update"
) -> None:
async def relay(q: asyncio.Queue, ws: WebSocket, is_sentinel) -> None:
"""Relay messages from the queue to the websocket."""
while True:
message = await q.get()
Expand All @@ -28,7 +28,7 @@ async def relay(
if is_sentinel(message): # Received Sentinel
break
try:
await ws.send_json({"signal": signal, "data": message})
await ws.send_json({"data": message})
except WebSocketDisconnect:
break

Expand All @@ -42,18 +42,25 @@ async def poll(ws: WebSocket) -> None:
break


STATIC_DIR = Path(__file__).parent / "build"
async def _connect_to_manager(worker_instance: Worker, config: WorkerConfig) -> bool:
"""Connect the worker to the manager."""
success = await worker_instance.async_connect(
port=config.port,
host=config.ip,
timeout=config.timeout,
method="zeroconf" if config.zeroconf else "ip",
)

return success


class ChimeraPyWorkerUI(FastAPI):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.worker_instance: Optional[Worker] = None
self.updates_broadcaster: Optional[
WorkerStateBroadcaster
] = WorkerStateBroadcaster()
self.state_updater = StateUpdater()
self._add_routes()
self._serve_static_files()
# self._serve_static_files()

def _serve_static_files(self):
if STATIC_DIR.exists():
Expand All @@ -65,90 +72,76 @@ def _serve_static_files(self):

def _add_routes(self):
self.add_api_route("/state", self._get_worker_state, methods=["GET"])
self.add_api_route("/connect", self._instantiate_worker, methods=["POST"])
self.add_api_route("/start", self._start_worker, methods=["POST"])
self.add_api_route("/connect", self._connect_worker, methods=["POST"])
self.add_api_route("/disconnect", self._disconnect_worker, methods=["POST"])
self.add_api_route("/shutdown", self._shutdown_worker, methods=["POST"])
self.add_websocket_route("/updates", self._handle_updates)
self.add_websocket_route("/updates", self._update_state)

async def _get_worker_state(self) -> Dict[str, Any]:
return (
state = (
self.worker_instance.state.to_dict(encode_json=False)
if self.worker_instance
else {}
)
if self.worker_instance:
state["connected_to_manager"] = self._is_worker_connected()
return state

async def _instantiate_worker(self, config: WorkerConfig) -> Dict[str, Any]:
if self.worker_instance is not None:
# Method Not Allowed
async def _connect_worker(self, config: WorkerConfig) -> Dict[str, Any]:
if self.worker_instance is None:
raise HTTPException(
status_code=405,
detail="Worker already instantiated. Please restart the server.",
)
try:
self.worker_instance = instantiate_worker(
name=config.name,
id=config.id or None,
wport=config.wport or 0,
delete_temp=config.delete_temp,
status_code=404,
detail="Worker not instantiated. Please instantiate the worker first.",
)
await self.worker_instance.aserve()

await self.worker_instance.async_connect(
port=config.port,
host=config.ip,
timeout=config.timeout,
method="zeroconf" if config.zeroconf else "ip",
)
print("Connected to manager.")
await self._initialize_updater()
except TimeoutError:
self.worker_instance = None
raise HTTPException( # noqa: B904
status_code=408, detail="Connection to manager timed out."
success = await _connect_to_manager(self.worker_instance, config)
if not success:
raise HTTPException(
status_code=500,
detail="Connection to manager failed. "
"Please retry when the manager is running.",
)
except Exception as e:
self.worker_instance = None
raise HTTPException(status_code=500, detail=str(e)) # noqa: B904

return self.worker_instance.state.to_dict(encode_json=False)
await self.state_updater.broadcast_state()

return await self._get_worker_state()

async def _disconnect_worker(self) -> Dict[str, Any]:
can, reason = await self._can_shutdown_or_disconnect()

async def _initialize_updater(self):
if not can:
raise HTTPException(status_code=409, detail=reason)

else:
assert self.worker_instance is not None
await self.worker_instance.async_deregister()
await self.state_updater.broadcast_state()
return await self._get_worker_state()

async def _start_worker(self, config: WorkerConfig) -> Dict[str, Any]:
if self.worker_instance is not None:
await self.updates_broadcaster.initialize(
state=self.worker_instance.state, eventbus=self.worker_instance.eventbus
# Method Not Allowed
raise HTTPException(
status_code=405,
detail="Worker already instantiated. Please restart the server.",
)

async def _handle_updates(self, ws: WebSocket) -> None:
await ws.accept()
self.worker_instance = instantiate_worker(
name=config.name,
id=config.id or None,
wport=config.wport or 0,
delete_temp=config.delete_temp,
)

if self.updates_broadcaster is None:
await ws.send_json(
{"signal": "error", "data": {"Worker not instantiated."}}
)
await self.worker_instance.aserve()
await self.state_updater.deinitialize()
await self.state_updater.initialize(self.worker_instance)
await self.state_updater.broadcast_state()

if self.updates_broadcaster is not None:
update_queue: asyncio.Queue = asyncio.Queue()
relay_task = asyncio.create_task(
relay(
q=update_queue,
ws=ws,
is_sentinel=lambda message: message is None,
signal="update",
)
)
poll_task = asyncio.create_task(poll(ws))

try:
done, pending = await asyncio.wait(
[relay_task, poll_task],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
finally:
await self.updates_broadcaster.remove_client(update_queue)
await ws.close()

async def _can_shutdown_worker(self) -> Tuple[bool, str]:
return await self._get_worker_state()

async def _can_shutdown_or_disconnect(self) -> Tuple[bool, str]:
if self.worker_instance is None:
return False, "Worker not instantiated."
if len(self.worker_instance.state.nodes) > 0:
Expand All @@ -157,15 +150,68 @@ async def _can_shutdown_worker(self) -> Tuple[bool, str]:
return True, "Worker can be shutdown."

async def _shutdown_worker(self) -> Dict[str, Any]:
can, reason = await self._can_shutdown_worker()
can, reason = await self._can_shutdown_or_disconnect()
if not can:
raise HTTPException(status_code=409, detail=reason)
else:
assert self.worker_instance is not None
await self.state_updater.deinitialize()
await self.worker_instance.async_shutdown()
await self.state_updater.broadcast_state()
self.worker_instance = None
return {}

async def _connect_to_manager(self, config: WorkerConfig) -> Dict[str, Any]:
if self.worker_instance is None:
raise HTTPException(status_code=404, detail="Worker not instantiated.")

if self._is_worker_connected():
raise HTTPException(
status_code=409, detail="Worker already connected to manager."
)

try:
await self.worker_instance.async_connect(
port=config.port,
host=config.ip,
timeout=config.timeout,
method="zeroconf" if config.zeroconf else "ip",
)
return await self._get_worker_state()
except TimeoutError as e:
raise HTTPException(
status_code=408, detail="Connection to manager timed out."
) from e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

async def _update_state(self, websocket: WebSocket):
await websocket.accept()
update_queue: asyncio.Queue = asyncio.Queue()

relay_task = asyncio.create_task(
relay(update_queue, websocket, is_sentinel=lambda m: m is None)
)

poll_task = asyncio.create_task(poll(websocket))
await self.state_updater.add_client(update_queue)
try:
done, pending = await asyncio.wait(
[relay_task, poll_task],
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
finally:
await self.state_updater.remove_client(update_queue)
# await websocket.close()

def _is_worker_connected(self) -> bool:
if self.worker_instance is not None:
return self.worker_instance.http_client.connected_to_manager

return False


def create_worker_ui_app():
return ChimeraPyWorkerUI()
68 changes: 68 additions & 0 deletions chimerapy/workerui/state_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from asyncio import Queue
from typing import Any, Dict, Optional

from chimerapy.engine.eventbus import EventBus, TypedObserver
from chimerapy.engine.worker import Worker


class StateUpdater:
def __init__(self):
self.clients = set()
self.observers: Dict[str, TypedObserver] = {}
self.eventbus: Optional[EventBus] = None
self.worker: Optional[Worker] = None

async def deinitialize(self):
if self.eventbus is not None:
for ob in self.observers.values():
await self.eventbus.aunsubscribe(ob)
self.observers = {}
self.eventbus = None
self.worker = None

async def initialize(self, worker: Worker):

self.eventbus = None
self.worker = None

self.eventbus = worker.eventbus
self.observers = {
"WorkerState.changed": TypedObserver(
"WorkerState.changed",
on_asend=self.on_worker_state_changed,
handle_event="drop",
)
}

for ob in self.observers.values():
await self.eventbus.asubscribe(ob)

self.worker = worker

async def on_worker_state_changed(self):
state = self.get_state()
for client in self.clients:
await client.put(state)

async def broadcast_state(self):
for client in self.clients:
await client.put(self.get_state())

async def add_client(self, client: Queue):
self.clients.add(client)
await client.put(self.get_state())

async def remove_client(self, client: Queue):
self.clients.discard(client)

async def enqueue_sentinel(self):
for client in self.clients:
await client.put(None)

def get_state(self) -> Dict[str, Any]:
if self.worker is None:
return {}

state = self.worker.state.to_dict(encode_json=False)
state["connected_to_manager"] = self.worker.http_client.connected_to_manager
return state
Loading