Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def __init__(
self.timeout = timeout
self.litapi_connector.set_request_timeout(timeout)
self.app = FastAPI(lifespan=self.lifespan, openapi_url="" if disable_openapi_url else "/openapi.json")

self.app.response_queue_id = None
self.response_buffer = {}
# gzip does not play nicely with streaming, see https://github.com/tiangolo/fastapi/discussions/8448
Expand Down Expand Up @@ -791,6 +792,8 @@ def __init__(
self.inference_workers_config = self.devices * self.workers_per_device
self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
self.register_endpoints()
# register middleware
self._register_middleware()

def launch_inference_worker(self, lit_api: LitAPI):
specs = [lit_api.spec] if lit_api.spec else []
Expand Down Expand Up @@ -947,6 +950,7 @@ async def shutdown_endpoint():

def register_endpoints(self):
self._register_internal_endpoints()

for lit_api in self.litapi_connector:
decode_request_signature = inspect.signature(lit_api.decode_request)
encode_response_signature = inspect.signature(lit_api.encode_response)
Expand Down Expand Up @@ -987,9 +991,6 @@ async def endpoint_handler(request: request_type) -> response_type:
# Handle specs
self._register_spec_endpoints(lit_api)

# Register middleware
self._register_middleware()

def _register_spec_endpoints(self, lit_api: LitAPI):
specs = [lit_api.spec] if lit_api.spec else []
for spec in specs:
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,13 @@ def test_middlewares_inputs():

with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))


def test_middleware_multiple_initialization():
api1 = ls.test_examples.SimpleLitAPI(api_path="/api1")
api2 = ls.test_examples.SimpleLitAPI(api_path="/api2")
api3 = ls.test_examples.SimpleLitAPI(api_path="/api3")
api4 = ls.test_examples.SimpleLitAPI(api_path="/api4")

server = ls.LitServer([api1, api2, api3, api4])
assert len(server.app.user_middleware) == 1, "Default/User middleware should be initialized only once"
Loading