Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def __init__(
self.timeout = timeout
self.litapi_connector.set_request_timeout(timeout)
self.app = FastAPI(lifespan=self.lifespan, openapi_url="" if disable_openapi_url else "/openapi.json")

self.app.response_queue_id = None
self.response_buffer = {}
# gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448
Expand Down Expand Up @@ -791,6 +792,8 @@ def __init__(
self.inference_workers_config = self.devices * self.workers_per_device
self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
self.register_endpoints()
# register middleware
self._register_middleware()

def launch_inference_worker(self, lit_api: LitAPI):
specs = [lit_api.spec] if lit_api.spec else []
Expand Down Expand Up @@ -947,6 +950,7 @@ async def shutdown_endpoint():

def register_endpoints(self):
self._register_internal_endpoints()

for lit_api in self.litapi_connector:
decode_request_signature = inspect.signature(lit_api.decode_request)
encode_response_signature = inspect.signature(lit_api.encode_response)
Expand Down Expand Up @@ -987,9 +991,6 @@ async def endpoint_handler(request: request_type) -> response_type:
# Handle specs
self._register_spec_endpoints(lit_api)

# Register middleware
self._register_middleware()

def _register_spec_endpoints(self, lit_api: LitAPI):
specs = [lit_api.spec] if lit_api.spec else []
for spec in specs:
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ def test_middlewares_inputs():

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))


def test_middleware_multiple_initialization():
api1 = ls.test_examples.SimpleLitAPI(api_path="/api1")
api2 = ls.test_examples.SimpleLitAPI(api_path="/api2")
api3 = ls.test_examples.SimpleLitAPI(api_path="/api3")
api4 = ls.test_examples.SimpleLitAPI(api_path="/api4")

server = ls.LitServer([api1, api2, api3, api4])
assert len(server.app.user_middleware) == 1, "Each middleware should be initialized only once for `n` LitAPIs"
Loading