
Commit ade9b2f

adding shutdown endpoint part 2
1 parent 9f1b4f3 commit ade9b2f


src/litserve/server.py

Lines changed: 149 additions & 58 deletions
@@ -34,7 +34,7 @@
 
 import uvicorn
 import uvicorn.server
-from fastapi import Depends, FastAPI, HTTPException, Request, Response
+from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
 from fastapi.responses import JSONResponse, StreamingResponse
 from fastapi.security import APIKeyHeader
 from starlette.formparsers import MultiPartParser
@@ -337,7 +337,7 @@ async def handle_request(self, request, request_type) -> StreamingResponse:
             self.server.data_streamer(response_queue, data_available=event),
             self.server._callback_runner.trigger_event,
             EventTypes.ON_RESPONSE.value,
-            litserver=self.server,
+            litserve=self.server,
         )
 
         return StreamingResponse(response_generator)
@@ -357,6 +357,7 @@ def __init__(
         timeout: Union[float, bool] = 30,
         healthcheck_path: str = "/health",
         info_path: str = "/info",
+        shutdown_path: str = "/shutdown",
         model_metadata: Optional[dict] = None,
         spec: Optional[LitSpec] = None,
         max_payload_size=None,
@@ -370,6 +371,7 @@ def __init__(
         stream: bool = False,
         api_path: Optional[str] = None,
         loop: Optional[Union[str, LitLoop]] = None,
+        uvicorn_graceful_timeout: int = 30,
     ):
         """Initialize a LitServer instance for high-performance model inference.
 
@@ -394,6 +396,9 @@ def __init__(
 
             info_path (str, optional):
                 Server info endpoint path showing metadata and configuration. Defaults to "/info".
+
+            shutdown_path (str, optional):
+                Server shutdown endpoint path that terminates and cleans up all worker and server processes. Defaults to "/shutdown".
 
             model_metadata (dict, optional):
                 Model metadata displayed at info endpoint (e.g., {"version": "1.0"}). Defaults to None.
@@ -419,6 +424,10 @@ def __init__(
             max_batch_size, batch_timeout, stream, spec, api_path, loop:
                 **Deprecated**: Configure these in your LitAPI implementation instead.
 
+            uvicorn_graceful_timeout (int, optional):
+                Timeout in seconds for Uvicorn to gracefully shut down its workers.
+                Defaults to 30.
+
         Example:
             >>> # Basic
             >>> server = LitServer(MyLitAPI())
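For orientation, here is a minimal sketch of how the two options introduced in this commit would be passed to the constructor once this version is installed. SimpleLitAPI is a hypothetical stand-in for a user-defined LitAPI; the values shown are simply the new defaults.

```python
import litserve as ls


class SimpleLitAPI(ls.LitAPI):  # hypothetical example API
    def setup(self, device):
        self.model = lambda x: x ** 2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        return self.model(x)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    server = ls.LitServer(
        SimpleLitAPI(),
        shutdown_path="/shutdown",      # new in this commit; shown with its default value
        uvicorn_graceful_timeout=30,    # new in this commit; seconds Uvicorn gets to drain workers
    )
    server.run(port=8000)
```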
@@ -441,7 +450,7 @@ def __init__(
                 "Old usage:\n"
                 " server = LitServer(api, max_batch_size=N, batch_timeout=T, ...)\n\n"
                 "New usage:\n"
-                " api = LitAPI(max_batch_size=N, batch_timeout=T, ...)\n"
+                f" api = LitAPI(max_batch_size=N, batch_timeout=T, ...)\n"
                 " server = LitServer(api, ...)",
                 DeprecationWarning,
                 stacklevel=2,
@@ -507,6 +516,7 @@ def __init__(
 
         self.healthcheck_path = healthcheck_path
         self.info_path = info_path
+        self._shutdown_path = shutdown_path
         self.track_requests = track_requests
         self.timeout = timeout
         self.litapi_connector.set_request_timeout(timeout)
@@ -531,6 +541,12 @@ def __init__(
         self.use_zmq = fast_queue
         self.transport_config = None
         self.litapi_request_queues = {}
+        self.inference_workers: List[mp.Process] = []
+        self.uvicorn_workers: List[Union[mp.Process, threading.Thread]] = []
+        self.manager: Optional[mp.Manager] = None
+        self._shutdown_event: Optional[mp.Event] = None
+        self._server_port: Optional[int] = None
+        self.uvicorn_graceful_timeout = uvicorn_graceful_timeout
 
         accelerator = self._connector.accelerator
         devices = self._connector.devices
@@ -542,8 +558,8 @@ def __init__(
             device_list = range(devices)
         self.devices = [self.device_identifiers(accelerator, device) for device in device_list]
 
-        self.inference_workers = self.devices * self.workers_per_device
-        self.transport_config = TransportConfig(transport_config="zmq" if self.use_zmq else "mp")
+        self.inference_workers_config = self.devices * self.workers_per_device
+        self.transport_config = TransportConfig(transport_type="zmq" if self.use_zmq else "mp")
         self.register_endpoints()
 
     def launch_inference_worker(self, lit_api: LitAPI):
@@ -557,7 +573,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
 
         process_list = []
         endpoint = lit_api.api_path.split("/")[-1]
-        for worker_id, device in enumerate(self.inference_workers):
+        for worker_id, device in enumerate(self.inference_workers_config):
             if len(device) == 1:
                 device = device[0]
 
@@ -578,6 +594,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
             )
             process.start()
             process_list.append(process)
+        self.inference_workers.extend(process_list)
         return process_list
 
     @asynccontextmanager
@@ -698,6 +715,15 @@ def register_endpoints(self):
                 response_type = Response
             self._register_api_endpoints(lit_api, request_type, response_type)
 
+        # Register the new shutdown endpoint
+        @self.app.post(self._shutdown_path, status_code=status.HTTP_200_OK, dependencies=[Depends(self.setup_auth())])
+        async def shutdown_endpoint():
+            logger.info("Received shutdown request via /shutdown endpoint. Signaling main process.")
+            if self._shutdown_event:
+                self._shutdown_event.set()
+            return Response(content="Server is initiating graceful shutdown.", status_code=status.HTTP_200_OK)
+
+
     def _get_request_queue(self, api_path: str):
         return self.litapi_request_queues[api_path]
 
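The endpoint above is guarded by the same setup_auth() dependency as the other routes. Below is a hedged client-side sketch of triggering it on a locally running server; the host, port, and API key are assumptions, and the X-API-Key header only matters when LIT_SERVER_API_KEY (or a custom authorize hook) is configured.

```python
import requests

# POST to the shutdown endpoint of a server started with the defaults shown earlier.
resp = requests.post(
    "http://127.0.0.1:8000/shutdown",
    headers={"X-API-Key": "my-secret-key"},  # hypothetical key; drop this header if auth is disabled
    timeout=10,
)
print(resp.status_code, resp.text)  # expected: 200 and "Server is initiating graceful shutdown."
```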
@@ -770,19 +796,105 @@ def verify_worker_status(self):
         logger.debug("One or more workers are ready to serve requests")
 
     def _init_manager(self, num_api_servers: int):
-        manager = self.transport_config.manager = mp.Manager()
+        self.manager = mp.Manager()
+        self.transport_config.manager = self.manager
         self.transport_config.num_consumers = num_api_servers
-        self.workers_setup_status = manager.dict()
+        self.workers_setup_status = self.manager.dict()
+        self._shutdown_event = self.manager.Event()
 
         # create request queues for each unique lit_api api_path
         for lit_api in self.litapi_connector:
-            self.litapi_request_queues[lit_api.api_path] = manager.Queue()
+            self.litapi_request_queues[lit_api.api_path] = self.manager.Queue()
 
         if self._logger_connector._loggers:
-            self.logger_queue = manager.Queue()
+            self.logger_queue = self.manager.Queue()
             self._logger_connector.run(self)
         self._transport = create_transport_from_config(self.transport_config)
-        return manager
+        return self.manager
+
+    def _perform_graceful_shutdown(self):
+        """Encapsulates the graceful shutdown logic for LitServe."""
+        logger.info("Starting graceful shutdown of LitServe components.")
+
+        # 1. Close the message transport first to stop new messages
+        if self._transport:
+            logger.info("Closing message transport...")
+            self._transport.close()
+
+        # 2. Terminate and join inference worker processes
+        if self.inference_workers:
+            logger.info(f"Terminating {len(self.inference_workers)} inference workers...")
+            for i, iw in enumerate(self.inference_workers):
+                worker_pid = iw.pid
+                logger.info(f"Worker {i} (PID: {worker_pid}): Checking status BEFORE termination attempt. is_alive(): {iw.is_alive()}")
+
+                if iw.is_alive():
+                    try:
+                        iw.terminate()
+                        logger.info(f"Worker {i} (PID: {worker_pid}): Sent SIGTERM. Joining...")
+                        iw.join(timeout=5)
+
+                        logger.info(f"Worker {i} (PID: {worker_pid}): Status AFTER SIGTERM & JOIN. is_alive(): {iw.is_alive()}")
+                        if iw.is_alive():
+                            logger.warning(f"Worker {i} (PID: {worker_pid}): Did not terminate gracefully. Forcibly killing (SIGKILL).")
+                            iw.kill()
+                            logger.info(f"Worker {i} (PID: {worker_pid}): Status AFTER SIGKILL. is_alive(): {iw.is_alive()}")
+                        else:
+                            logger.info(f"Worker {i} (PID: {worker_pid}): Terminated gracefully.")
+                    except Exception as e:
+                        logger.error(f"Error during termination of worker {i} (PID: {worker_pid}): {e}")
+                else:
+                    logger.info(f"Worker {i} (PID: {worker_pid}): Was already not alive before termination attempt.")
+            logger.info("Inference workers termination loop completed.")
+
+        # 3. Terminate Uvicorn API server workers tracked by LitServe (the master processes/threads)
+        if self.uvicorn_workers:
+            logger.info(f"Terminating {len(self.uvicorn_workers)} Uvicorn API server master processes/threads...")
+            for i, uw in enumerate(self.uvicorn_workers):
+                logger.info(f"PID: {uw.pid}; {type(uw)}")
+                uvicorn_runner_pid = uw.pid
+                if uvicorn_runner_pid:
+                    log_prefix = f"Uvicorn Master {'Process' if uvicorn_runner_pid else 'Thread'} {i}"
+                    log_prefix += f" (PID: {uvicorn_runner_pid})"
+
+                    if isinstance(uw, threading.Thread):
+                        logger.info(f"{log_prefix}: will terminate with the main process.")
+                    else:
+                        logger.info(f"{log_prefix}: Checking status BEFORE termination. is_alive(): {uw.is_alive()}")
+                        if uw.is_alive():
+                            try:
+                                uw.terminate()
+                                logger.info(f"{log_prefix}: Sent SIGTERM. Joining (timeout={self.uvicorn_graceful_timeout}s)...")
+                                uw.join(timeout=self.uvicorn_graceful_timeout)
+                                logger.info(f"{log_prefix}: Status AFTER SIGTERM & JOIN. is_alive(): {uw.is_alive()}")
+                                if uw.is_alive():
+                                    logger.warning(f"{log_prefix}: Did not terminate gracefully. Forcibly killing.")
+                                    uw.kill()
+                                    logger.info(f"{log_prefix}: Status AFTER SIGKILL. is_alive(): {uw.is_alive()}")
+                                else:
+                                    logger.info(f"{log_prefix}: Terminated gracefully.")
+                            except Exception as e:
+                                logger.error(f"Error during termination of {log_prefix}: {e}")
+                        else:
+                            logger.info(f"{log_prefix}: Already not alive.")
+            logger.info("Uvicorn API server master processes/threads termination loop completed.")
+
+
+        # 5. Shut down the multiprocessing manager
+        if self.manager:
+            logger.info("Shutting down multiprocessing manager...")
+            self.manager.shutdown()
+            logger.info("Multiprocessing manager shut down.")
+
+        logger.info("All LitServe components gracefully shut down.")
+
+        # 6. Exit the main process
+        def exit_process():
+            time.sleep(0.5)
+            logger.info("Exiting main LitServe process.")
+            os._exit(0)
+
+        threading.Thread(target=exit_process).start()
 
     def run(
         self,
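The per-worker teardown in the new _perform_graceful_shutdown above follows a terminate, join-with-timeout, then kill escalation. Here is a standalone sketch of that pattern, independent of LitServe (the helper name and timeout are illustrative):

```python
import logging
import multiprocessing as mp

logger = logging.getLogger(__name__)


def stop_process(proc: mp.Process, join_timeout: float = 5.0) -> None:
    """Ask a process to exit with SIGTERM, then escalate to SIGKILL if it does not comply."""
    if not proc.is_alive():
        return
    proc.terminate()                  # sends SIGTERM on POSIX
    proc.join(timeout=join_timeout)   # give it a grace period to exit
    if proc.is_alive():               # still running after the grace period
        logger.warning("PID %s ignored SIGTERM; killing it.", proc.pid)
        proc.kill()                   # sends SIGKILL (Python 3.7+)
        proc.join()
```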
@@ -824,21 +936,6 @@ def run(
             **kwargs:
                 Additional uvicorn server options (ssl_keyfile, ssl_certfile, etc.). See uvicorn docs.
 
-        Example:
-            >>> server.run() # Basic
-
-            >>> server.run( # Production
-            ... port=8080,
-            ... num_api_servers=4,
-            ... log_level="warning"
-            ... )
-
-            >>> server.run( # Development
-            ... log_level="debug",
-            ... pretty_logs=True,
-            ... generate_client_file=True
-            ... )
-
         """
         if generate_client_file:
             LitServer.generate_client_file(port=port)
@@ -859,9 +956,11 @@ def run(
         configure_logging(log_level, use_rich=pretty_logs)
         config = uvicorn.Config(app=self.app, host=host, port=port, log_level=log_level, **kwargs)
         sockets = [config.bind_socket()]
+
+        self._server_port = port
 
         if num_api_servers is None:
-            num_api_servers = len(self.inference_workers)
+            num_api_servers = len(self.inference_workers_config)
 
         if num_api_servers < 1:
             raise ValueError("num_api_servers must be greater than 0")
@@ -874,34 +973,26 @@ def run(
         elif api_server_worker_type is None:
             api_server_worker_type = "process"
 
-        manager = self._init_manager(num_api_servers)
+        self._init_manager(num_api_servers)
         self._logger_connector.run(self)
-        inference_workers = []
         for lit_api in self.litapi_connector:
-            _inference_workers = self.launch_inference_worker(lit_api)
-            inference_workers.extend(_inference_workers)
+            self.launch_inference_worker(lit_api)
 
         self.verify_worker_status()
         try:
-            uvicorn_workers = self._start_server(
+            self.uvicorn_workers = self._start_server(
                 port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs
            )
             print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
-            # On Linux, kill signal will be captured by uvicorn.
-            # => They will join and raise a KeyboardInterrupt, allowing to Shutdown server.
-            for i, uw in enumerate(uvicorn_workers):
-                uw: Union[Process, Thread]
-                if isinstance(uw, Process):
-                    print(f"Uvicorn worker {i} : [{uw.pid}]")
-                uw.join()
+
+            while not self._shutdown_event.is_set():
+                time.sleep(0.1)
+
+        except KeyboardInterrupt:
+            logger.info("KeyboardInterrupt received. Initiating graceful shutdown.")
         finally:
-            print("Shutting down LitServe")
-            self._transport.close()
-            for iw in inference_workers:
-                iw: Process
-                iw.terminate()
-                iw.join()
-            manager.shutdown()
+            self._perform_graceful_shutdown()
+
 
     def _prepare_app_run(self, app: FastAPI):
         # Add middleware to count active requests
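With this change, run() no longer joins the Uvicorn workers; it polls a Manager-backed event that any child API-server process (for example, the one handling POST /shutdown) can set. Below is a minimal, self-contained illustration of that cross-process signal; it is not LitServe code, just the same polling pattern in isolation.

```python
import multiprocessing as mp
import time


def child(shutdown_event):
    time.sleep(0.5)        # stand-in for "a /shutdown request arrived"
    shutdown_event.set()   # signal the parent process


if __name__ == "__main__":
    manager = mp.Manager()
    shutdown_event = manager.Event()   # proxy object usable across processes
    mp.Process(target=child, args=(shutdown_event,)).start()

    while not shutdown_event.is_set():  # same polling pattern as LitServer.run()
        time.sleep(0.1)
    print("shutdown signal received; cleaning up")
    manager.shutdown()
```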
@@ -920,17 +1011,17 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
         app: FastAPI = copy.copy(self.app)
 
         self._prepare_app_run(app)
-        config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
-        if sys.platform == "win32" and num_uvicorn_servers > 1:
-            logger.debug("Enable Windows explicit socket sharing...")
-            # We make sure sockets is listening...
-            # It prevents further [WinError 10022]
-            for sock in sockets:
-                sock.listen(config.backlog)
-            # We add worker to say unicorn to use a shared socket (win32)
-            # https://github.com/encode/uvicorn/pull/802
-            config.workers = num_uvicorn_servers
-        server = uvicorn.Server(config=config)
+        uvicorn_config = uvicorn.Config(
+            app=app,
+            host="0.0.0.0",
+            port=port,
+            log_level=log_level,
+            workers=num_uvicorn_servers if uvicorn_worker_type == "process" else 1,
+            timeout_graceful_shutdown=self.uvicorn_graceful_timeout,
+            **kwargs
+        )
+
+        server = uvicorn.Server(config=uvicorn_config)
         if uvicorn_worker_type == "process":
             ctx = mp.get_context("fork")
             w = ctx.Process(target=server.run, args=(sockets,))
@@ -947,4 +1038,4 @@ def setup_auth(self):
             return self.lit_api.authorize
         if LIT_SERVER_API_KEY:
             return api_key_auth
-        return no_auth
+        return no_auth
