diff --git a/src/litserve/callbacks/base.py b/src/litserve/callbacks/base.py
index 7a1dc7d0..9144bb7e 100644
--- a/src/litserve/callbacks/base.py
+++ b/src/litserve/callbacks/base.py
@@ -1,13 +1,12 @@
-import dataclasses
 import logging
 from abc import ABC
+from enum import Enum
 from typing import List, Union

 logger = logging.getLogger(__name__)


-@dataclasses.dataclass
-class EventTypes:
+class EventTypes(Enum):
     BEFORE_SETUP = "on_before_setup"
     AFTER_SETUP = "on_after_setup"
     BEFORE_DECODE_REQUEST = "on_before_decode_request"
diff --git a/src/litserve/callbacks/defaults/metric_callback.py b/src/litserve/callbacks/defaults/metric_callback.py
index f0065849..8782b615 100644
--- a/src/litserve/callbacks/defaults/metric_callback.py
+++ b/src/litserve/callbacks/defaults/metric_callback.py
@@ -1,23 +1,18 @@
 import time
 import typing
-from logging import getLogger

 from litserve.callbacks.base import Callback

 if typing.TYPE_CHECKING:
     from litserve import LitAPI

-logger = getLogger(__name__)
-

 class PredictionTimeLogger(Callback):
     def on_before_predict(self, lit_api: "LitAPI"):
-        t0 = time.perf_counter()
-        self._start_time = t0
+        self._start_time = time.perf_counter()

     def on_after_predict(self, lit_api: "LitAPI"):
-        t1 = time.perf_counter()
-        elapsed = t1 - self._start_time
+        elapsed = time.perf_counter() - self._start_time
         print(f"Prediction took {elapsed:.2f} seconds", flush=True)


diff --git a/src/litserve/loops/loops.py b/src/litserve/loops/loops.py
index 3959c471..5fbd3978 100644
--- a/src/litserve/loops/loops.py
+++ b/src/litserve/loops/loops.py
@@ -71,7 +71,7 @@ def inference_worker(
     callback_runner: CallbackRunner,
     loop: Union[str, _BaseLoop],
 ):
-    callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
+    callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api)
     try:
         lit_api.setup(device)
     except Exception:
@@ -79,7 +79,7 @@ def inference_worker(
         workers_setup_status[worker_id] = WorkerSetupStatus.ERROR
         return
     lit_api.device = device
-    callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)
+    callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api)

     print(f"Setup complete for worker {worker_id}.")

diff --git a/src/litserve/loops/simple_loops.py b/src/litserve/loops/simple_loops.py
index e254835e..06323a51 100644
--- a/src/litserve/loops/simple_loops.py
+++ b/src/litserve/loops/simple_loops.py
@@ -68,29 +68,29 @@ def run_single_loop(
                 if hasattr(lit_spec, "populate_context"):
                     lit_spec.populate_context(context, x_enc)

-                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
                 x = _inject_context(
                     context,
                     lit_api.decode_request,
                     x_enc,
                 )
-                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

-                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
                 y = _inject_context(
                     context,
                     lit_api.predict,
                     x,
                 )
-                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

-                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
                 y_enc = _inject_context(
                     context,
                     lit_api.encode_response,
                     y,
                 )
-                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)
                 self.put_response(
                     transport=transport,
                     response_queue_id=response_queue_id,
@@ -135,29 +135,29 @@ async def _process_single_request(
         if hasattr(lit_spec, "populate_context"):
             lit_spec.populate_context(context, x_enc)

-        callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
         x = await _async_inject_context(
             context,
             lit_api.decode_request,
             x_enc,
         )
-        callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

-        callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
         y = await _async_inject_context(
             context,
             lit_api.predict,
             x,
         )
-        callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

-        callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
         y_enc = await _async_inject_context(
             context,
             lit_api.encode_response,
             y,
         )
-        callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)
         self.put_response(
             transport=transport,
             response_queue_id=response_queue_id,
@@ -306,7 +306,7 @@ def run_batched_loop(
                     for input, context in zip(inputs, contexts):
                         lit_spec.populate_context(context, input)

-                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
                 x = [
                     _inject_context(
                         context,
@@ -315,13 +315,13 @@ def run_batched_loop(
                     )
                     for input, context in zip(inputs, contexts)
                 ]
-                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

                 x = lit_api.batch(x)

-                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
                 y = _inject_context(contexts, lit_api.predict, x)
-                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

                 outputs = lit_api.unbatch(y)

@@ -332,12 +332,12 @@ def run_batched_loop(
                     )
                     raise HTTPException(500, "Batch size mismatch")

-                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
                 y_enc_list = []
                 for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
                     y_enc = _inject_context(context, lit_api.encode_response, y)
                     y_enc_list.append((response_queue_id, uid, y_enc))
-                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

                 for response_queue_id, uid, y_enc in y_enc_list:
                     self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
diff --git a/src/litserve/loops/streaming_loops.py b/src/litserve/loops/streaming_loops.py
index 64032571..de3e295d 100644
--- a/src/litserve/loops/streaming_loops.py
+++ b/src/litserve/loops/streaming_loops.py
@@ -68,15 +68,15 @@ def run_streaming_loop(
                     x_enc,
                 )

-                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
                 y_gen = _inject_context(
                     context,
                     lit_api.predict,
                     x,
                 )
-                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

-                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
                 y_enc_gen = _inject_context(
                     context,
                     lit_api.encode_response,
@@ -87,8 +87,8 @@ def run_streaming_loop(
                     self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
                 self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)

-                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
-                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

             except HTTPException as e:
                 self.put_response(
@@ -123,23 +123,23 @@ async def _process_streaming_request(
         if hasattr(lit_spec, "populate_context"):
             lit_spec.populate_context(context, x_enc)

-        callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
         x = await _async_inject_context(
             context,
             lit_api.decode_request,
             x_enc,
         )
-        callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

-        callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
         y_gen = await _async_inject_context(
             context,
             lit_api.predict,
             x,
         )
-        callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

-        callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)

         # When using async, predict should return an async generator
         # and encode_response should handle async generators
@@ -158,7 +158,7 @@ async def _process_streaming_request(
             self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
         self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)

-        callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+        callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

     except HTTPException as e:
         self.put_response(
@@ -282,7 +282,7 @@ def run_batched_streaming_loop(
                     for input, context in zip(inputs, contexts):
                         lit_spec.populate_context(context, input)

-                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
                 x = [
                     _inject_context(
                         context,
@@ -291,19 +291,19 @@ def run_batched_streaming_loop(
                         x_enc,
                     )
                     for input, context in zip(inputs, contexts)
                 ]
-                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

                 x = lit_api.batch(x)

-                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
                 y_iter = _inject_context(contexts, lit_api.predict, x)
-                callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

                 unbatched_iter = lit_api.unbatch(y_iter)

-                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
                 y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)
-                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
+                callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)
                 # y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
                 for y_batch in y_enc_iter:
diff --git a/src/litserve/server.py b/src/litserve/server.py
index 168268ae..29017687 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -369,7 +369,7 @@ async def lifespan(self, app: FastAPI):
         try:
             yield
         finally:
-            self._callback_runner.trigger_event(EventTypes.ON_SERVER_END, litserver=self)
+            self._callback_runner.trigger_event(EventTypes.ON_SERVER_END.value, litserver=self)

             # Cancel the task
             task.cancel()
@@ -412,7 +412,7 @@ def active_requests(self):

     def register_endpoints(self):
         """Register endpoint routes for the FastAPI app and setup middlewares."""
-        self._callback_runner.trigger_event(EventTypes.ON_SERVER_START, litserver=self)
+        self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self)
         workers_ready = False

         @self.app.get("/", dependencies=[Depends(self.setup_auth())])
@@ -449,7 +449,7 @@ async def info(request: Request) -> Response:

         async def predict(request: self.request_type) -> self.response_type:
             self._callback_runner.trigger_event(
-                EventTypes.ON_REQUEST,
+                EventTypes.ON_REQUEST.value,
                 active_requests=self.active_requests,
                 litserver=self,
             )
@@ -478,12 +478,12 @@ async def predict(request: self.request_type) -> self.response_type:
             if status == LitAPIStatus.ERROR:
                 logger.error("Error in request: %s", response)
                 raise HTTPException(status_code=500)
-            self._callback_runner.trigger_event(EventTypes.ON_RESPONSE, litserver=self)
+            self._callback_runner.trigger_event(EventTypes.ON_RESPONSE.value, litserver=self)
             return response

         async def stream_predict(request: self.request_type) -> self.response_type:
             self._callback_runner.trigger_event(
-                EventTypes.ON_REQUEST,
+                EventTypes.ON_REQUEST.value,
                 active_requests=self.active_requests,
                 litserver=self,
             )
@@ -502,7 +502,7 @@ async def stream_predict(request: self.request_type) -> self.response_type:
             response = call_after_stream(
                 self.data_streamer(q, data_available=event),
                 self._callback_runner.trigger_event,
-                EventTypes.ON_RESPONSE,
+                EventTypes.ON_RESPONSE.value,
                 litserver=self,
             )
             return StreamingResponse(response)
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index a374ef4c..27740c88 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -58,8 +58,8 @@ def test_metric_logger(capfd):
     cb_runner = CallbackRunner()
     cb_runner._add_callbacks(cb)
     assert cb_runner._callbacks == [cb], "Callback not added to runner"
-    cb_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=None)
-    cb_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=None)
+    cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
+    cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)

     captured = capfd.readouterr()
     pattern = r"Prediction took \d+\.\d{2} seconds"
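Below is a minimal, standalone sketch of the dispatch pattern these call sites appear to rely on once `EventTypes` becomes an `Enum`: a runner that resolves each event to a hook method by its string name, which is why callers now pass `EventTypes.<NAME>.value` (the plain string) instead of the enum member. `CallbackRunner`'s real internals are not part of this diff, and the `TimerCallback` class below is purely illustrative.

```python
from enum import Enum
from typing import List


class EventTypes(Enum):
    BEFORE_PREDICT = "on_before_predict"
    AFTER_PREDICT = "on_after_predict"


class Callback:
    # Default hooks are no-ops; subclasses override the ones they care about.
    def on_before_predict(self, lit_api):
        pass

    def on_after_predict(self, lit_api):
        pass


class CallbackRunner:
    # Assumed string-based dispatch: the event name doubles as the hook name.
    def __init__(self, callbacks: List[Callback]):
        self._callbacks = callbacks

    def trigger_event(self, event_name: str, **kwargs):
        for cb in self._callbacks:
            getattr(cb, event_name)(**kwargs)


class TimerCallback(Callback):  # hypothetical example callback
    def on_before_predict(self, lit_api):
        print("starting prediction")


runner = CallbackRunner([TimerCallback()])
# EventTypes.BEFORE_PREDICT is an Enum member; .value is the string
# "on_before_predict" that getattr() needs to locate the hook.
runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)
```

An alternative design would be for the runner to accept the enum member and read `.value` internally (or to declare a `str`-subclassing enum), leaving call sites unchanged; passing `.value` at each site keeps the runner purely string-keyed.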