Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/litserve/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import dataclasses
import logging
from abc import ABC
from enum import Enum
from typing import List, Union

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class EventTypes:
class EventTypes(Enum):
BEFORE_SETUP = "on_before_setup"
AFTER_SETUP = "on_after_setup"
BEFORE_DECODE_REQUEST = "on_before_decode_request"
Expand Down
9 changes: 2 additions & 7 deletions src/litserve/callbacks/defaults/metric_callback.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@
import time
import typing
from logging import getLogger

from litserve.callbacks.base import Callback

if typing.TYPE_CHECKING:
from litserve import LitAPI

logger = getLogger(__name__)


class PredictionTimeLogger(Callback):
def on_before_predict(self, lit_api: "LitAPI"):
t0 = time.perf_counter()
self._start_time = t0
self._start_time = time.perf_counter()

def on_after_predict(self, lit_api: "LitAPI"):
t1 = time.perf_counter()
elapsed = t1 - self._start_time
elapsed = time.perf_counter() - self._start_time
print(f"Prediction took {elapsed:.2f} seconds", flush=True)


Expand Down
4 changes: 2 additions & 2 deletions src/litserve/loops/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,15 @@ def inference_worker(
callback_runner: CallbackRunner,
loop: Union[str, _BaseLoop],
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_SETUP.value, lit_api=lit_api)
try:
lit_api.setup(device)
except Exception:
logger.exception(f"Error setting up worker {worker_id}.")
workers_setup_status[worker_id] = WorkerSetupStatus.ERROR
return
lit_api.device = device
callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_SETUP.value, lit_api=lit_api)

print(f"Setup complete for worker {worker_id}.")

Expand Down
36 changes: 18 additions & 18 deletions src/litserve/loops/simple_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,29 @@ def run_single_loop(
if hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(context, x_enc)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
x = _inject_context(
context,
lit_api.decode_request,
x_enc,
)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y = _inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
y_enc = _inject_context(
context,
lit_api.encode_response,
y,
)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)
self.put_response(
transport=transport,
response_queue_id=response_queue_id,
Expand Down Expand Up @@ -135,29 +135,29 @@ async def _process_single_request(
if hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(context, x_enc)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
x = await _async_inject_context(
context,
lit_api.decode_request,
x_enc,
)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y = await _async_inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
y_enc = await _async_inject_context(
context,
lit_api.encode_response,
y,
)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)
self.put_response(
transport=transport,
response_queue_id=response_queue_id,
Expand Down Expand Up @@ -306,7 +306,7 @@ def run_batched_loop(
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
x = [
_inject_context(
context,
Expand All @@ -315,13 +315,13 @@ def run_batched_loop(
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

x = lit_api.batch(x)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y = _inject_context(contexts, lit_api.predict, x)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

outputs = lit_api.unbatch(y)

Expand All @@ -332,12 +332,12 @@ def run_batched_loop(
)
raise HTTPException(500, "Batch size mismatch")

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
y_enc_list = []
for response_queue_id, y, uid, context in zip(response_queue_ids, outputs, uids, contexts):
y_enc = _inject_context(context, lit_api.encode_response, y)
y_enc_list.append((response_queue_id, uid, y_enc))
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

for response_queue_id, uid, y_enc in y_enc_list:
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
Expand Down
34 changes: 17 additions & 17 deletions src/litserve/loops/streaming_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def run_streaming_loop(
x_enc,
)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y_gen = _inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
y_enc_gen = _inject_context(
context,
lit_api.encode_response,
Expand All @@ -87,8 +87,8 @@ def run_streaming_loop(
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)

callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

except HTTPException as e:
self.put_response(
Expand Down Expand Up @@ -123,23 +123,23 @@ async def _process_streaming_request(
if hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(context, x_enc)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
x = await _async_inject_context(
context,
lit_api.decode_request,
x_enc,
)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y_gen = await _async_inject_context(
context,
lit_api.predict,
x,
)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)

# When using async, predict should return an async generator
# and encode_response should handle async generators
Expand All @@ -158,7 +158,7 @@ async def _process_streaming_request(
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)

self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

except HTTPException as e:
self.put_response(
Expand Down Expand Up @@ -282,7 +282,7 @@ def run_batched_streaming_loop(
for input, context in zip(inputs, contexts):
lit_spec.populate_context(context, input)

callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_DECODE_REQUEST.value, lit_api=lit_api)
x = [
_inject_context(
context,
Expand All @@ -291,19 +291,19 @@ def run_batched_streaming_loop(
)
for input, context in zip(inputs, contexts)
]
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_DECODE_REQUEST.value, lit_api=lit_api)

x = lit_api.batch(x)

callback_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=lit_api)
y_iter = _inject_context(contexts, lit_api.predict, x)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=lit_api)

unbatched_iter = lit_api.unbatch(y_iter)

callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.BEFORE_ENCODE_RESPONSE.value, lit_api=lit_api)
y_enc_iter = _inject_context(contexts, lit_api.encode_response, unbatched_iter)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE.value, lit_api=lit_api)

# y_enc_iter -> [[response-1, response-2], [response-1, response-2]]
for y_batch in y_enc_iter:
Expand Down
12 changes: 6 additions & 6 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ async def lifespan(self, app: FastAPI):
try:
yield
finally:
self._callback_runner.trigger_event(EventTypes.ON_SERVER_END, litserver=self)
self._callback_runner.trigger_event(EventTypes.ON_SERVER_END.value, litserver=self)

# Cancel the task
task.cancel()
Expand Down Expand Up @@ -412,7 +412,7 @@ def active_requests(self):

def register_endpoints(self):
"""Register endpoint routes for the FastAPI app and setup middlewares."""
self._callback_runner.trigger_event(EventTypes.ON_SERVER_START, litserver=self)
self._callback_runner.trigger_event(EventTypes.ON_SERVER_START.value, litserver=self)
workers_ready = False

@self.app.get("/", dependencies=[Depends(self.setup_auth())])
Expand Down Expand Up @@ -449,7 +449,7 @@ async def info(request: Request) -> Response:

async def predict(request: self.request_type) -> self.response_type:
self._callback_runner.trigger_event(
EventTypes.ON_REQUEST,
EventTypes.ON_REQUEST.value,
active_requests=self.active_requests,
litserver=self,
)
Expand Down Expand Up @@ -478,12 +478,12 @@ async def predict(request: self.request_type) -> self.response_type:
if status == LitAPIStatus.ERROR:
logger.error("Error in request: %s", response)
raise HTTPException(status_code=500)
self._callback_runner.trigger_event(EventTypes.ON_RESPONSE, litserver=self)
self._callback_runner.trigger_event(EventTypes.ON_RESPONSE.value, litserver=self)
return response

async def stream_predict(request: self.request_type) -> self.response_type:
self._callback_runner.trigger_event(
EventTypes.ON_REQUEST,
EventTypes.ON_REQUEST.value,
active_requests=self.active_requests,
litserver=self,
)
Expand All @@ -502,7 +502,7 @@ async def stream_predict(request: self.request_type) -> self.response_type:
response = call_after_stream(
self.data_streamer(q, data_available=event),
self._callback_runner.trigger_event,
EventTypes.ON_RESPONSE,
EventTypes.ON_RESPONSE.value,
litserver=self,
)
return StreamingResponse(response)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def test_metric_logger(capfd):
cb_runner = CallbackRunner()
cb_runner._add_callbacks(cb)
assert cb_runner._callbacks == [cb], "Callback not added to runner"
cb_runner.trigger_event(EventTypes.BEFORE_PREDICT, lit_api=None)
cb_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=None)
cb_runner.trigger_event(EventTypes.BEFORE_PREDICT.value, lit_api=None)
cb_runner.trigger_event(EventTypes.AFTER_PREDICT.value, lit_api=None)

captured = capfd.readouterr()
pattern = r"Prediction took \d+\.\d{2} seconds"
Expand Down