diff --git a/chimerapy/workerui/server.py b/chimerapy/workerui/server.py index a817417..a041b04 100644 --- a/chimerapy/workerui/server.py +++ b/chimerapy/workerui/server.py @@ -3,21 +3,21 @@ from pathlib import Path from typing import Any, Dict, Optional, Tuple -from fastapi import FastAPI +from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.exceptions import HTTPException from fastapi.responses import RedirectResponse from fastapi.staticfiles import StaticFiles -from fastapi.websockets import WebSocket, WebSocketDisconnect, WebSocketState +from starlette.websockets import WebSocketState from chimerapy.engine.worker import Worker from chimerapy.workerui.models import WorkerConfig +from chimerapy.workerui.state_updater import StateUpdater from chimerapy.workerui.utils import instantiate_worker -from chimerapy.workerui.worker_state_broadcaster import WorkerStateBroadcaster + +STATIC_DIR = Path(__file__).parent / "build" -async def relay( - q: asyncio.Queue, ws: WebSocket, is_sentinel, signal: str = "update" -) -> None: +async def relay(q: asyncio.Queue, ws: WebSocket, is_sentinel) -> None: """Relay messages from the queue to the websocket.""" while True: message = await q.get() @@ -28,7 +28,7 @@ async def relay( if is_sentinel(message): # Received Sentinel break try: - await ws.send_json({"signal": signal, "data": message}) + await ws.send_json({"data": message}) except WebSocketDisconnect: break @@ -42,18 +42,25 @@ async def poll(ws: WebSocket) -> None: break -STATIC_DIR = Path(__file__).parent / "build" +async def _connect_to_manager(worker_instance: Worker, config: WorkerConfig) -> bool: + """Connect the worker to the manager.""" + success = await worker_instance.async_connect( + port=config.port, + host=config.ip, + timeout=config.timeout, + method="zeroconf" if config.zeroconf else "ip", + ) + + return success class ChimeraPyWorkerUI(FastAPI): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.worker_instance: Optional[Worker] = None - self.updates_broadcaster: Optional[ - WorkerStateBroadcaster - ] = WorkerStateBroadcaster() + self.state_updater = StateUpdater() self._add_routes() - self._serve_static_files() + # self._serve_static_files() def _serve_static_files(self): if STATIC_DIR.exists(): @@ -65,90 +72,76 @@ def _serve_static_files(self): def _add_routes(self): self.add_api_route("/state", self._get_worker_state, methods=["GET"]) - self.add_api_route("/connect", self._instantiate_worker, methods=["POST"]) + self.add_api_route("/start", self._start_worker, methods=["POST"]) + self.add_api_route("/connect", self._connect_worker, methods=["POST"]) + self.add_api_route("/disconnect", self._disconnect_worker, methods=["POST"]) self.add_api_route("/shutdown", self._shutdown_worker, methods=["POST"]) - self.add_websocket_route("/updates", self._handle_updates) + self.add_websocket_route("/updates", self._update_state) async def _get_worker_state(self) -> Dict[str, Any]: - return ( + state = ( self.worker_instance.state.to_dict(encode_json=False) if self.worker_instance else {} ) + if self.worker_instance: + state["connected_to_manager"] = self._is_worker_connected() + return state - async def _instantiate_worker(self, config: WorkerConfig) -> Dict[str, Any]: - if self.worker_instance is not None: - # Method Not Allowed + async def _connect_worker(self, config: WorkerConfig) -> Dict[str, Any]: + if self.worker_instance is None: raise HTTPException( - status_code=405, - detail="Worker already instantiated. Please restart the server.", - ) - try: - self.worker_instance = instantiate_worker( - name=config.name, - id=config.id or None, - wport=config.wport or 0, - delete_temp=config.delete_temp, + status_code=404, + detail="Worker not instantiated. Please instantiate the worker first.", ) - await self.worker_instance.aserve() - await self.worker_instance.async_connect( - port=config.port, - host=config.ip, - timeout=config.timeout, - method="zeroconf" if config.zeroconf else "ip", - ) - print("Connected to manager.") - await self._initialize_updater() - except TimeoutError: - self.worker_instance = None - raise HTTPException( # noqa: B904 - status_code=408, detail="Connection to manager timed out." + success = await _connect_to_manager(self.worker_instance, config) + if not success: + raise HTTPException( + status_code=500, + detail="Connection to manager failed. " + "Please retry when the manager is running.", ) - except Exception as e: - self.worker_instance = None - raise HTTPException(status_code=500, detail=str(e)) # noqa: B904 - return self.worker_instance.state.to_dict(encode_json=False) + await self.state_updater.broadcast_state() + + return await self._get_worker_state() + + async def _disconnect_worker(self) -> Dict[str, Any]: + can, reason = await self._can_shutdown_or_disconnect() - async def _initialize_updater(self): + if not can: + raise HTTPException(status_code=409, detail=reason) + + else: + assert self.worker_instance is not None + await self.worker_instance.async_deregister() + await self.state_updater.broadcast_state() + return await self._get_worker_state() + + async def _start_worker(self, config: WorkerConfig) -> Dict[str, Any]: if self.worker_instance is not None: - await self.updates_broadcaster.initialize( - state=self.worker_instance.state, eventbus=self.worker_instance.eventbus + # Method Not Allowed + raise HTTPException( + status_code=405, + detail="Worker already instantiated. Please restart the server.", ) - async def _handle_updates(self, ws: WebSocket) -> None: - await ws.accept() + self.worker_instance = instantiate_worker( + name=config.name, + id=config.id or None, + wport=config.wport or 0, + delete_temp=config.delete_temp, + ) - if self.updates_broadcaster is None: - await ws.send_json( - {"signal": "error", "data": {"Worker not instantiated."}} - ) + await self.worker_instance.aserve() + await self.state_updater.deinitialize() + await self.state_updater.initialize(self.worker_instance) + await self.state_updater.broadcast_state() - if self.updates_broadcaster is not None: - update_queue: asyncio.Queue = asyncio.Queue() - relay_task = asyncio.create_task( - relay( - q=update_queue, - ws=ws, - is_sentinel=lambda message: message is None, - signal="update", - ) - ) - poll_task = asyncio.create_task(poll(ws)) - - try: - done, pending = await asyncio.wait( - [relay_task, poll_task], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - finally: - await self.updates_broadcaster.remove_client(update_queue) - await ws.close() - - async def _can_shutdown_worker(self) -> Tuple[bool, str]: + return await self._get_worker_state() + + async def _can_shutdown_or_disconnect(self) -> Tuple[bool, str]: if self.worker_instance is None: return False, "Worker not instantiated." if len(self.worker_instance.state.nodes) > 0: @@ -157,15 +150,68 @@ async def _can_shutdown_worker(self) -> Tuple[bool, str]: return True, "Worker can be shutdown." async def _shutdown_worker(self) -> Dict[str, Any]: - can, reason = await self._can_shutdown_worker() + can, reason = await self._can_shutdown_or_disconnect() if not can: raise HTTPException(status_code=409, detail=reason) else: assert self.worker_instance is not None + await self.state_updater.deinitialize() await self.worker_instance.async_shutdown() + await self.state_updater.broadcast_state() self.worker_instance = None return {} + async def _connect_to_manager(self, config: WorkerConfig) -> Dict[str, Any]: + if self.worker_instance is None: + raise HTTPException(status_code=404, detail="Worker not instantiated.") + + if self._is_worker_connected(): + raise HTTPException( + status_code=409, detail="Worker already connected to manager." + ) + + try: + await self.worker_instance.async_connect( + port=config.port, + host=config.ip, + timeout=config.timeout, + method="zeroconf" if config.zeroconf else "ip", + ) + return await self._get_worker_state() + except TimeoutError as e: + raise HTTPException( + status_code=408, detail="Connection to manager timed out." + ) from e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + async def _update_state(self, websocket: WebSocket): + await websocket.accept() + update_queue: asyncio.Queue = asyncio.Queue() + + relay_task = asyncio.create_task( + relay(update_queue, websocket, is_sentinel=lambda m: m is None) + ) + + poll_task = asyncio.create_task(poll(websocket)) + await self.state_updater.add_client(update_queue) + try: + done, pending = await asyncio.wait( + [relay_task, poll_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + finally: + await self.state_updater.remove_client(update_queue) + # await websocket.close() + + def _is_worker_connected(self) -> bool: + if self.worker_instance is not None: + return self.worker_instance.http_client.connected_to_manager + + return False + def create_worker_ui_app(): return ChimeraPyWorkerUI() diff --git a/chimerapy/workerui/state_updater.py b/chimerapy/workerui/state_updater.py new file mode 100644 index 0000000..9de644a --- /dev/null +++ b/chimerapy/workerui/state_updater.py @@ -0,0 +1,68 @@ +from asyncio import Queue +from typing import Any, Dict, Optional + +from chimerapy.engine.eventbus import EventBus, TypedObserver +from chimerapy.engine.worker import Worker + + +class StateUpdater: + def __init__(self): + self.clients = set() + self.observers: Dict[str, TypedObserver] = {} + self.eventbus: Optional[EventBus] = None + self.worker: Optional[Worker] = None + + async def deinitialize(self): + if self.eventbus is not None: + for ob in self.observers.values(): + await self.eventbus.aunsubscribe(ob) + self.observers = {} + self.eventbus = None + self.worker = None + + async def initialize(self, worker: Worker): + + self.eventbus = None + self.worker = None + + self.eventbus = worker.eventbus + self.observers = { + "WorkerState.changed": TypedObserver( + "WorkerState.changed", + on_asend=self.on_worker_state_changed, + handle_event="drop", + ) + } + + for ob in self.observers.values(): + await self.eventbus.asubscribe(ob) + + self.worker = worker + + async def on_worker_state_changed(self): + state = self.get_state() + for client in self.clients: + await client.put(state) + + async def broadcast_state(self): + for client in self.clients: + await client.put(self.get_state()) + + async def add_client(self, client: Queue): + self.clients.add(client) + await client.put(self.get_state()) + + async def remove_client(self, client: Queue): + self.clients.discard(client) + + async def enqueue_sentinel(self): + for client in self.clients: + await client.put(None) + + def get_state(self) -> Dict[str, Any]: + if self.worker is None: + return {} + + state = self.worker.state.to_dict(encode_json=False) + state["connected_to_manager"] = self.worker.http_client.connected_to_manager + return state diff --git a/chimerapy/workerui/tests/test_server.py b/chimerapy/workerui/tests/test_server.py new file mode 100644 index 0000000..9be73dd --- /dev/null +++ b/chimerapy/workerui/tests/test_server.py @@ -0,0 +1,56 @@ +import asyncio +import tempfile + +import pytest +from fastapi.testclient import TestClient + +from chimerapy.engine.manager import Manager +from chimerapy.workerui.server import ChimeraPyWorkerUI +from chimerapy.workerui.tests.base_test import BaseTest + + +class TestServer(BaseTest): + @pytest.fixture(scope="class") + def event_loop(self): + return asyncio.get_event_loop() + + @pytest.fixture(scope="class") + def anyio_backend(self): + return "asyncio" + + @pytest.fixture(scope="class") + async def test_client_and_app(self, anyio_backend): + app = ChimeraPyWorkerUI() + client = TestClient(app) + yield client, app + + @pytest.mark.anyio + async def test_routes(self, test_client_and_app): + test_client, app = test_client_and_app + app_json = {"Content-Type": "application/json"} + req_json = { + "name": "worker1", + "id": "worker1", + "wport": 56403, + "zeroconf": True, + "timeout": 20, + "ip": "", + "port": 0, + "delete_temp": True, + } + response = test_client.post("/start", headers=app_json, json=req_json) + assert response.status_code == 200 + assert response.json() == await app._get_worker_state() + response = test_client.post("/start", headers=app_json, json=req_json) + assert response.status_code == 405 + + # Create Manager + m = Manager(logdir=tempfile.mkdtemp(), port=0) + await m.aserve() + await m.async_zeroconf(enable=True) + + # Connect Worker via Zeroconf + response = test_client.post("/connect", headers=app_json, json=req_json) + assert response.status_code == 200 + assert response.json() == await app._get_worker_state() + assert response.json()["connected_to_manager"] diff --git a/chimerapy/workerui/worker_state_broadcaster.py b/chimerapy/workerui/worker_state_broadcaster.py deleted file mode 100644 index 22fb469..0000000 --- a/chimerapy/workerui/worker_state_broadcaster.py +++ /dev/null @@ -1,48 +0,0 @@ -import asyncio -from typing import Dict, Optional - -from chimerapy.engine.eventbus import EventBus, TypedObserver -from chimerapy.engine.states import WorkerState - - -class WorkerStateBroadcaster: - def __init__(self): - self.clients = set() - self.observers: Dict[str, TypedObserver] = {} - self.eventbus: Optional[EventBus] = None - self.state: Optional[WorkerState] = None - - async def initialize(self, eventbus: EventBus, state: WorkerState): - if self.eventbus is not None: - for ob in self.observers.values(): - await self.eventbus.aunsubscribe(ob) - - self.state = None - - self.observers = { - "WorkerState.changed": TypedObserver( - "WorkerState.changed", - on_asend=self.on_state_changed, - handle_event="drop", - ) - } - - self.eventbus = eventbus - self.state = state - - for ob in self.observers.values(): - await self.eventbus.asubscribe(ob) - - async def on_state_changed(self): - for client in self.clients: - await client.put(self.state.to_dict(encode_json=False)) - - async def add_client(self, q: asyncio.Queue): - self.clients.add(q) - - async def remove_client(self, q: asyncio.Queue): - self.clients.discard(q) - - async def enqueue_sentinel(self): - for client in self.clients: - await client.put(None) diff --git a/web/src/lib/Client.ts b/web/src/lib/Client.ts index 8866c4e..e7c7ad9 100644 --- a/web/src/lib/Client.ts +++ b/web/src/lib/Client.ts @@ -17,6 +17,18 @@ export class WorkerClient { return response; } + async start(config: WorkerConfig): Promise> { + const response = await this._fetch('/start', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(config) + }); + + return response; + } + async connect(config: WorkerConfig): Promise> { const response = await this._fetch('/connect', { method: 'POST', @@ -35,6 +47,13 @@ export class WorkerClient { return response; } + async disconnect(): Promise> { + const response = await this._fetch('/disconnect', { + method: 'POST' + }); + return response; + } + async _fetch(prefix: string, options: RequestInit): Promise> { const res = await fetch(this.url + prefix, options); if (res.ok) { diff --git a/web/src/lib/ReadableWebSocketStore.ts b/web/src/lib/ReadableWebSocketStore.ts new file mode 100644 index 0000000..05e140e --- /dev/null +++ b/web/src/lib/ReadableWebSocketStore.ts @@ -0,0 +1,97 @@ +import { writable } from 'svelte/store'; +import type { Readable, Subscriber } from 'svelte/store'; + +const reopenTimeouts = [2000, 5000, 10000, 30000, 60000]; + +export default function readableWebSocketStore( + path: string, + initialValue: T | null, + mapper: (data: any) => T, + origin: string | null = null +): Readable { + const { subscribe, update } = writable(initialValue); + const subscribers = new Set>(); + let reopenCount = 0; + let openPromise: Promise | null = null; + let reopenTimeoutHandler: ReturnType | null = null; + let socket: WebSocket | null = null; + + function reopenTimeout() { + const n = reopenCount; + reopenCount++; + return reopenTimeouts[n >= reopenTimeouts.length - 1 ? reopenTimeouts.length - 1 : n]; + } + + const open = () => { + if (reopenTimeoutHandler) { + clearTimeout(reopenTimeoutHandler); + reopenTimeoutHandler = null; + } + + if (openPromise) { + return openPromise; + } + + socket = new WebSocket(`ws://${origin || window.location.host}${path}`); + socket.onmessage = (event) => { + const data = JSON.parse(event.data); + update((value) => { + const newValue = mapper(data); + if (newValue !== value) { + return newValue; + } + return value; + }); + }; + + socket.onclose = (event) => { + reopen(); + }; + + openPromise = new Promise((resolve, reject) => { + if (socket) { + socket.onerror = (event) => { + update(() => null); + reject(event); + openPromise = null; + }; + + socket.onopen = (event) => { + reopenCount = 0; + resolve(); + openPromise = null; + }; + } + }); + + return openPromise; + }; + + const reopen = () => { + close(); + if (subscribers.size > 0) { + reopenTimeoutHandler = setTimeout(() => open(), reopenTimeout()); + } + }; + + const close = () => { + if (socket) { + socket.close(); + socket = null; + } + }; + + return { + subscribe: (subscriber) => { + open(); + subscribers.add(subscriber); + subscribe(subscriber); + return () => { + subscribers.delete(subscriber); + if (subscribers.size === 0) { + close(); + } + }; + } + }; +} diff --git a/web/src/lib/Utils.ts b/web/src/lib/Utils.ts index 340c3a0..4d68733 100644 --- a/web/src/lib/Utils.ts +++ b/web/src/lib/Utils.ts @@ -4,9 +4,11 @@ const isValidIP = (ip: string): boolean => { const ipRegex = /^(\d{1,3}\.){3}\d{1,3}$/; return ipRegex.test(ip) || ip === 'localhost'; }; -export const isValidWorkerCreationConfig = (config: WorkerConfig): boolean => { - const isFalsy = (value: any): boolean => - value === '' || value === null || value === undefined || value === 0; + +const isFalsy = (value: any): boolean => + value === '' || value === null || value === undefined || value === 0; + +export const isValidWorkerConnectionConfig = (config: WorkerConfig): boolean => { if (isFalsy(config.name)) return false; if (!config.zeroconf) { @@ -15,3 +17,8 @@ export const isValidWorkerCreationConfig = (config: WorkerConfig): boolean => { return true; }; + +export const isValidWorkerCreationConfig = (config: WorkerConfig): boolean => { + if (isFalsy(config.name)) return false; + return true; +}; diff --git a/web/src/lib/models.ts b/web/src/lib/models.ts index 3777cfe..e26e9cb 100644 --- a/web/src/lib/models.ts +++ b/web/src/lib/models.ts @@ -55,4 +55,6 @@ export interface WorkerState { ip: string; port: number; tempfolder: string; + + connected_to_manager: boolean; } diff --git a/web/src/lib/stores.ts b/web/src/lib/stores.ts new file mode 100644 index 0000000..283be06 --- /dev/null +++ b/web/src/lib/stores.ts @@ -0,0 +1,17 @@ +import type { WorkerState } from './models'; +import readableWebSocketStore from './ReadableWebSocketStore'; +import { dev } from '$app/environment'; + +const stores = new Map(); +export function populateStores() { + const workerStateStore = readableWebSocketStore( + dev ? '/api/updates' : '/updates', + null, + (payload) => payload.data + ); + stores.set('worker', workerStateStore); +} + +export function getStore(name: string): T { + return stores.get(name); +} diff --git a/web/src/routes/+layout.ts b/web/src/routes/+layout.ts index 6d63253..5a6d815 100644 --- a/web/src/routes/+layout.ts +++ b/web/src/routes/+layout.ts @@ -1,2 +1,10 @@ // This can be false if you're using a fallback (i.e. SPA mode) +export const ssr = false; export const prerender = true; +import type { LayoutLoad } from './$types'; +import { populateStores } from '$lib/stores'; + +export const load: LayoutLoad = ({ fetch }) => { + populateStores(); + return {}; +}; diff --git a/web/src/routes/+page.svelte b/web/src/routes/+page.svelte index 8f08948..6ba8fe9 100644 --- a/web/src/routes/+page.svelte +++ b/web/src/routes/+page.svelte @@ -2,19 +2,16 @@ import { onMount } from 'svelte'; import { workerClient } from '$lib/services'; import type { WorkerConfig } from '../lib/models'; - import { isValidWorkerCreationConfig } from '$lib/Utils'; + import { isValidWorkerCreationConfig, isValidWorkerConnectionConfig } from '$lib/Utils'; import Alert from '$lib/Modal/Alert.svelte'; - import { - Table, - TableBody, - TableBodyCell, - TableBodyRow, - TableHead, - TableHeadCell - } from 'flowbite-svelte'; + import { getStore } from '../lib/stores'; + + import WorkerStarter from './Components/WorkerStarter.svelte'; + import WorkerDetails from './Components/WorkerDetails.svelte'; let errorDisplay; let canCreateWorker: boolean = false; + let canConnectWorker: boolean = false; let workerConfig: WorkerConfig = { name: '', wport: 0, @@ -25,71 +22,20 @@ ip: '', timeout: 20 }; - - let connecting: boolean = false; - let connected: boolean = false; - let workerState = {}; + let started: boolean = false; + + const workerStateStore = getStore('worker'); onMount(async () => { - (await workerClient.getWorkerState()) - .map((state) => { - workerState = state; - connected = state.name !== undefined; - }) - .mapError((e) => { - // Display error - errorDisplay.display({ - title: 'Error Getting Worker State', - content: e - }); - }); + // Get worker state + (await workerClient.getWorkerState()).map((state) => { + workerState = state; + }); }); - async function handleStartAndConnect() { - connecting = true; - (await workerClient.connect(workerConfig)) - .map((state) => { - workerState = state; - connecting = false; - connected = true; - }) - .mapError((e) => { - // Display error - errorDisplay.display({ - title: 'Error Starting and Connecting Worker', - content: e - }); - connecting = false; - }); - } - - async function handleShutdown() { - (await workerClient.shutdown()) - .map((state) => { - workerState = state; - connected = false; - }) - .mapError((e) => { - // Display error - errorDisplay.display({ - title: 'Error Shutting down Worker', - content: e - }); - }); - } - $: { - canCreateWorker = isValidWorkerCreationConfig(workerConfig); - } - - function getNodes(workerState) { - if (workerState.nodes === undefined) { - return 'Unknown'; - } - return Object.values(workerState.nodes) - .map((node) => `${node.name}(${node.fsm})`) - .join(', '); + started = $workerStateStore?.name !== undefined; } @@ -100,152 +46,16 @@
- {#if connected} -

Worker Connected

- - - Name - ID - URL - Nodes - - - - {workerState.name} - {workerState.id} - {`http://${workerState.ip}:${workerState.port}`} - {getNodes(workerState)} - - -
-
- -
+ {#if !started} + {:else} -

Connect Worker

- -
-
- - -
-
- - -
-
- - -
- {#if !workerConfig.zeroconf} -
- - -
-
- - -
- {/if} -
- - -
-
- - - - -
-
- -
-
+ {/if}

Worker State

-
{JSON.stringify(workerState, null, 2)}
+
{JSON.stringify($workerStateStore || workerState, null, 2)}
diff --git a/web/src/routes/Components/WorkerDetails.svelte b/web/src/routes/Components/WorkerDetails.svelte new file mode 100644 index 0000000..cad43b5 --- /dev/null +++ b/web/src/routes/Components/WorkerDetails.svelte @@ -0,0 +1,221 @@ + + +

Worker {connected ? 'Connected' : 'Started'}

+ + + Name + ID + URL + Nodes + + + + {$workerStateStore?.name} + {$workerStateStore?.id} + {`http://${$workerStateStore?.ip}:${$workerStateStore?.port}`} + {getNodes($workerStateStore || {})} + + +
+ +{#if !connected} + {#if !workerConfig.zeroconf} +
+ + +
+
+ + +
+ {/if} +
+ + +
+
+ + +
+{/if} + +
+ {#if !connected} + + {:else} + + {/if} + + +
+ + diff --git a/web/src/routes/Components/WorkerStarter.svelte b/web/src/routes/Components/WorkerStarter.svelte new file mode 100644 index 0000000..86cb72c --- /dev/null +++ b/web/src/routes/Components/WorkerStarter.svelte @@ -0,0 +1,83 @@ + + +

Start Worker

+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+ + diff --git a/web/vite.config.ts b/web/vite.config.ts index cdf5f47..206bdc3 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -12,6 +12,7 @@ export default defineConfig({ target: 'http://localhost:8000', changeOrigin: true, secure: false, + ws: true, rewrite: (path) => path.replace(/^\/api/, '') } }