diff --git a/src/litserve/server.py b/src/litserve/server.py index 29017687..ad0f16ed 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -130,6 +130,7 @@ def __init__( api_path: str = "/predict", healthcheck_path: str = "/health", info_path: str = "/info", + shutdown_path: str = "/shutdown", model_metadata: Optional[dict] = None, stream: bool = False, spec: Optional[LitSpec] = None, @@ -154,6 +155,7 @@ def __init__( api_path: URL path for the prediction endpoint. healthcheck_path: URL path for the health check endpoint. info_path: URL path for the server and model information endpoint. + shutdown_path: URL path for the server shutdown endpoint. model_metadata: Metadata about the model, shown at the info endpoint. stream: Whether to enable streaming responses. spec: Specification for the API, such as OpenAISpec or custom specs. @@ -220,6 +222,9 @@ def __init__( "info_path must start with '/'. Please provide a valid api path like '/info', '/details', or '/v1/info'" ) + if not shutdown_path.startswith("/"): + raise ValueError("shutdown_path must start with '/'. Please provide a valid api path like '/shutdown'") + try: json.dumps(model_metadata) except (TypeError, ValueError): @@ -243,6 +248,7 @@ def __init__( self.api_path = api_path self.healthcheck_path = healthcheck_path self.info_path = info_path + self.shutdown_path = shutdown_path self.track_requests = track_requests self.timeout = timeout lit_api.stream = stream @@ -447,6 +453,16 @@ async def info(request: Request) -> Response: } ) + @self.app.post(self.shutdown_path, dependencies=[Depends(self.setup_auth())]) + async def shutdown(request: Request): + server = self.app.state.server + print("Initiating shutdown...") + if server.should_exit: + return Response(content="Shutdown already in progress", status_code=400) + server.should_exit = True + + return Response(content="Server has been shutdown") + async def predict(request: self.request_type) -> self.response_type: self._callback_runner.trigger_event( EventTypes.ON_REQUEST.value, @@ -617,12 +633,14 @@ def run( print(f"Uvicorn worker {i} : [{uw.pid}]") uw.join() finally: - print("Shutting down LitServe") self._transport.close() + print("Transport closed") for iw in inference_workers: iw: Process + print(f"Terminating worker [PID {iw.pid}]") iw.terminate() iw.join() + print("Shutting down LitServe") manager.shutdown() def _prepare_app_run(self, app: FastAPI): @@ -651,6 +669,7 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w # https://github.com/encode/uvicorn/pull/802 config.workers = num_uvicorn_servers server = uvicorn.Server(config=config) + self.app.state.server = server if uvicorn_worker_type == "process": ctx = mp.get_context("fork") w = ctx.Process(target=server.run, args=(sockets,)) diff --git a/tests/test_simple.py b/tests/test_simple.py index 8ab1934c..05e1e5b2 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -15,6 +15,7 @@ import time from concurrent.futures import ThreadPoolExecutor from contextlib import ExitStack +from types import SimpleNamespace import numpy as np import pytest @@ -147,6 +148,30 @@ def test_workers_health_with_custom_health_method(use_zmq): assert response.text == "ok" +def test_shutdown_endpoint(): + server = LitServer( + SlowSetupLitAPI(), + accelerator="cpu", + shutdown_path="/shutdown", + devices=1, + workers_per_device=1, + ) + + server.app.state.server = SimpleNamespace(should_exit=False) + + with wrap_litserve_start(server) as server, TestClient(server.app) as client: + response = client.post("/shutdown") + assert response.status_code == 200 + assert "shutdown" in response.text.lower() + time.sleep(0.5) + assert server.app.state.server.should_exit is True, "Server should be marked for shutdown" + + time.sleep(1) + response = client.post("/shutdown") + assert response.status_code == 400 + assert "shutdown already" in response.text.lower() + + def make_load_request(server, outputs): with TestClient(server.app) as client: for i in range(100):