3434
3535import uvicorn
3636import uvicorn .server
37- from fastapi import Depends , FastAPI , HTTPException , Request , Response
37+ from fastapi import Depends , FastAPI , HTTPException , Request , Response , status
3838from fastapi .responses import JSONResponse , StreamingResponse
3939from fastapi .security import APIKeyHeader
4040from starlette .formparsers import MultiPartParser
@@ -337,7 +337,7 @@ async def handle_request(self, request, request_type) -> StreamingResponse:
337337 self .server .data_streamer (response_queue , data_available = event ),
338338 self .server ._callback_runner .trigger_event ,
339339 EventTypes .ON_RESPONSE .value ,
340- litserver = self .server ,
340+ litserve = self .server ,
341341 )
342342
343343 return StreamingResponse (response_generator )
@@ -357,6 +357,7 @@ def __init__(
357357 timeout : Union [float , bool ] = 30 ,
358358 healthcheck_path : str = "/health" ,
359359 info_path : str = "/info" ,
360+ shutdown_path : str = "/shutdown" ,
360361 model_metadata : Optional [dict ] = None ,
361362 spec : Optional [LitSpec ] = None ,
362363 max_payload_size = None ,
@@ -370,6 +371,7 @@ def __init__(
370371 stream : bool = False ,
371372 api_path : Optional [str ] = None ,
372373 loop : Optional [Union [str , LitLoop ]] = None ,
374+ uvicorn_graceful_timeout : int = 30 ,
373375 ):
374376 """Initialize a LitServer instance for high-performance model inference.
375377
@@ -394,6 +396,9 @@ def __init__(
394396
395397 info_path (str, optional):
396398 Server info endpoint path showing metadata and configuration. Defaults to "/info".
399+
400+ shutdown_path (str, optional):
401+ Server shutdown endpoint path that terminates and cleans up all worker and server processes. Defaults to "/shutdown".
397402
398403 model_metadata (dict, optional):
399404 Model metadata displayed at info endpoint (e.g., {"version": "1.0"}). Defaults to None.
@@ -419,6 +424,10 @@ def __init__(
419424 max_batch_size, batch_timeout, stream, spec, api_path, loop:
420425 **Deprecated**: Configure these in your LitAPI implementation instead.
421426
427+ uvicorn_graceful_timeout (int, optional):
428+ Timeout in seconds for Uvicorn to gracefully shut down its workers.
429+ Defaults to 30.
430+
422431 Example:
423432 >>> # Basic
424433 >>> server = LitServer(MyLitAPI())
@@ -441,7 +450,7 @@ def __init__(
441450 "Old usage:\n "
442451 " server = LitServer(api, max_batch_size=N, batch_timeout=T, ...)\n \n "
443452 "New usage:\n "
444- " api = LitAPI(max_batch_size=N, batch_timeout=T, ...)\n "
453+ f " api = LitAPI(max_batch_size=N, batch_timeout=T, ...)\n "
445454 " server = LitServer(api, ...)" ,
446455 DeprecationWarning ,
447456 stacklevel = 2 ,
@@ -507,6 +516,7 @@ def __init__(
507516
508517 self .healthcheck_path = healthcheck_path
509518 self .info_path = info_path
519+ self ._shutdown_path = shutdown_path
510520 self .track_requests = track_requests
511521 self .timeout = timeout
512522 self .litapi_connector .set_request_timeout (timeout )
@@ -531,6 +541,12 @@ def __init__(
531541 self .use_zmq = fast_queue
532542 self .transport_config = None
533543 self .litapi_request_queues = {}
544+ self .inference_workers : List [mp .Process ] = []
545+ self .uvicorn_workers : List [Union [mp .Process , threading .Thread ]] = []
546+ self .manager : Optional [mp .Manager ] = None
547+ self ._shutdown_event : Optional [mp .Event ] = None
548+ self ._server_port : Optional [int ] = None
549+ self .uvicorn_graceful_timeout = uvicorn_graceful_timeout
534550
535551 accelerator = self ._connector .accelerator
536552 devices = self ._connector .devices
@@ -542,8 +558,8 @@ def __init__(
542558 device_list = range (devices )
543559 self .devices = [self .device_identifiers (accelerator , device ) for device in device_list ]
544560
545- self .inference_workers = self .devices * self .workers_per_device
546- self .transport_config = TransportConfig (transport_config = "zmq" if self .use_zmq else "mp" )
561+ self .inference_workers_config = self .devices * self .workers_per_device
562+ self .transport_config = TransportConfig (transport_type = "zmq" if self .use_zmq else "mp" )
547563 self .register_endpoints ()
548564
549565 def launch_inference_worker (self , lit_api : LitAPI ):
@@ -557,7 +573,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
557573
558574 process_list = []
559575 endpoint = lit_api .api_path .split ("/" )[- 1 ]
560- for worker_id , device in enumerate (self .inference_workers ):
576+ for worker_id , device in enumerate (self .inference_workers_config ):
561577 if len (device ) == 1 :
562578 device = device [0 ]
563579
@@ -578,6 +594,7 @@ def launch_inference_worker(self, lit_api: LitAPI):
578594 )
579595 process .start ()
580596 process_list .append (process )
597+ self .inference_workers .extend (process_list )
581598 return process_list
582599
583600 @asynccontextmanager
@@ -698,6 +715,15 @@ def register_endpoints(self):
698715 response_type = Response
699716 self ._register_api_endpoints (lit_api , request_type , response_type )
700717
718+ # Register the new shutdown endpoint
719+ @self .app .post (self ._shutdown_path , status_code = status .HTTP_200_OK , dependencies = [Depends (self .setup_auth ())])
720+ async def shutdown_endpoint ():
721+ logger .info ("Received shutdown request via /shutdown endpoint. Signaling main process." )
722+ if self ._shutdown_event :
723+ self ._shutdown_event .set ()
724+ return Response (content = "Server is initiating graceful shutdown." , status_code = status .HTTP_200_OK )
725+
726+
701727 def _get_request_queue (self , api_path : str ):
702728 return self .litapi_request_queues [api_path ]
703729
@@ -770,19 +796,105 @@ def verify_worker_status(self):
770796 logger .debug ("One or more workers are ready to serve requests" )
771797
772798 def _init_manager (self , num_api_servers : int ):
773- manager = self .transport_config .manager = mp .Manager ()
799+ self .manager = mp .Manager ()
800+ self .transport_config .manager = self .manager
774801 self .transport_config .num_consumers = num_api_servers
775- self .workers_setup_status = manager .dict ()
802+ self .workers_setup_status = self .manager .dict ()
803+ self ._shutdown_event = self .manager .Event ()
776804
777805 # create request queues for each unique lit_api api_path
778806 for lit_api in self .litapi_connector :
779- self .litapi_request_queues [lit_api .api_path ] = manager .Queue ()
807+ self .litapi_request_queues [lit_api .api_path ] = self . manager .Queue ()
780808
781809 if self ._logger_connector ._loggers :
782- self .logger_queue = manager .Queue ()
810+ self .logger_queue = self . manager .Queue ()
783811 self ._logger_connector .run (self )
784812 self ._transport = create_transport_from_config (self .transport_config )
785- return manager
813+ return self .manager
814+
815+ def _perform_graceful_shutdown (self ):
816+ """Encapsulates the graceful shutdown logic for LitServe."""
817+ logger .info ("Starting graceful shutdown of LitServe components." )
818+
819+ # 1. Close the message transport first to stop new messages
820+ if self ._transport :
821+ logger .info ("Closing message transport..." )
822+ self ._transport .close ()
823+
824+ # 2. Terminate and join inference worker processes
825+ if self .inference_workers :
826+ logger .info (f"Terminating { len (self .inference_workers )} inference workers..." )
827+ for i , iw in enumerate (self .inference_workers ):
828+ worker_pid = iw .pid
829+ logger .info (f"Worker { i } (PID: { worker_pid } ): Checking status BEFORE termination attempt. is_alive(): { iw .is_alive ()} " )
830+
831+ if iw .is_alive ():
832+ try :
833+ iw .terminate ()
834+ logger .info (f"Worker { i } (PID: { worker_pid } ): Sent SIGTERM. Joining..." )
835+ iw .join (timeout = 5 )
836+
837+ logger .info (f"Worker { i } (PID: { worker_pid } ): Status AFTER SIGTERM & JOIN. is_alive(): { iw .is_alive ()} " )
838+ if iw .is_alive ():
839+ logger .warning (f"Worker { i } (PID: { worker_pid } ): Did not terminate gracefully. Forcibly killing (SIGKILL)." )
840+ iw .kill ()
841+ logger .info (f"Worker { i } (PID: { worker_pid } ): Status AFTER SIGKILL. is_alive(): { iw .is_alive ()} " )
842+ else :
843+ logger .info (f"Worker { i } (PID: { worker_pid } ): Terminated gracefully." )
844+ except Exception as e :
845+ logger .error (f"Error during termination of worker { i } (PID: { worker_pid } ): { e } " )
846+ else :
847+ logger .info (f"Worker { i } (PID: { worker_pid } ): Was already not alive before termination attempt." )
848+ logger .info ("Inference workers termination loop completed." )
849+
850+ # 3. Terminate Uvicorn API server workers tracked by LitServe (the master processes/threads)
851+ if self .uvicorn_workers :
852+ logger .info (f"Terminating { len (self .uvicorn_workers )} Uvicorn API server master processes/threads..." )
853+ for i , uw in enumerate (self .uvicorn_workers ):
854+ logger .info (f"PID: { uw .pid } ; { type (uw )} " )
855+ uvicorn_runner_pid = uw .pid
856+ if uvicorn_runner_pid :
857+ log_prefix = f"Uvicorn Master { 'Process' if uvicorn_runner_pid else 'Thread' } { i } "
858+ log_prefix += f" (PID: { uvicorn_runner_pid } )"
859+
860+ if isinstance (uw , threading .Thread ):
861+ logger .info (f"{ log_prefix } : will terminate with the main process." )
862+ else :
863+ logger .info (f"{ log_prefix } : Checking status BEFORE termination. is_alive(): { uw .is_alive ()} " )
864+ if uw .is_alive ():
865+ try :
866+ uw .terminate ()
867+ logger .info (f"{ log_prefix } : Sent SIGTERM. Joining (timeout={ self .uvicorn_graceful_timeout } s)..." )
868+ uw .join (timeout = self .uvicorn_graceful_timeout )
869+ logger .info (f"{ log_prefix } : Status AFTER SIGTERM & JOIN. is_alive(): { uw .is_alive ()} " )
870+ if uw .is_alive ():
871+ logger .warning (f"{ log_prefix } : Did not terminate gracefully. Forcibly killing." )
872+ uw .kill ()
873+ logger .info (f"{ log_prefix } : Status AFTER SIGKILL. is_alive(): { uw .is_alive ()} " )
874+ else :
875+ logger .info (f"{ log_prefix } : Terminated gracefully." )
876+ except Exception as e :
877+ logger .error (f"Error during termination of { log_prefix } : { e } " )
878+ else :
879+ logger .info (f"{ log_prefix } : Already not alive." )
880+ logger .info ("Uvicorn API server master processes/threads termination loop completed." )
881+
882+
883+ # 5. Shut down the multiprocessing manager
884+ if self .manager :
885+ logger .info ("Shutting down multiprocessing manager..." )
886+ self .manager .shutdown ()
887+ logger .info ("Multiprocessing manager shut down." )
888+
889+ logger .info ("All LitServe components gracefully shut down." )
890+
891+ # 6. Exit the main process
892+ def exit_process ():
893+ time .sleep (0.5 )
894+ logger .info ("Exiting main LitServe process." )
895+ os ._exit (0 )
896+
897+ threading .Thread (target = exit_process ).start ()
786898
787899 def run (
788900 self ,
@@ -824,21 +936,6 @@ def run(
824936 **kwargs:
825937 Additional uvicorn server options (ssl_keyfile, ssl_certfile, etc.). See uvicorn docs.
826938
827- Example:
828- >>> server.run() # Basic
829-
830- >>> server.run( # Production
831- ... port=8080,
832- ... num_api_servers=4,
833- ... log_level="warning"
834- ... )
835-
836- >>> server.run( # Development
837- ... log_level="debug",
838- ... pretty_logs=True,
839- ... generate_client_file=True
840- ... )
841-
842939 """
843940 if generate_client_file :
844941 LitServer .generate_client_file (port = port )
@@ -859,9 +956,11 @@ def run(
859956 configure_logging (log_level , use_rich = pretty_logs )
860957 config = uvicorn .Config (app = self .app , host = host , port = port , log_level = log_level , ** kwargs )
861958 sockets = [config .bind_socket ()]
959+
960+ self ._server_port = port
862961
863962 if num_api_servers is None :
864- num_api_servers = len (self .inference_workers )
963+ num_api_servers = len (self .inference_workers_config )
865964
866965 if num_api_servers < 1 :
867966 raise ValueError ("num_api_servers must be greater than 0" )
@@ -874,34 +973,26 @@ def run(
874973 elif api_server_worker_type is None :
875974 api_server_worker_type = "process"
876975
877- manager = self ._init_manager (num_api_servers )
976+ self ._init_manager (num_api_servers )
878977 self ._logger_connector .run (self )
879- inference_workers = []
880978 for lit_api in self .litapi_connector :
881- _inference_workers = self .launch_inference_worker (lit_api )
882- inference_workers .extend (_inference_workers )
979+ self .launch_inference_worker (lit_api )
883980
884981 self .verify_worker_status ()
885982 try :
886- uvicorn_workers = self ._start_server (
983+ self . uvicorn_workers = self ._start_server (
887984 port , num_api_servers , log_level , sockets , api_server_worker_type , ** kwargs
888985 )
889986 print (f"Swagger UI is available at http://0.0.0.0:{ port } /docs" )
890- # On Linux, kill signal will be captured by uvicorn.
891- # => They will join and raise a KeyboardInterrupt, allowing to Shutdown server.
892- for i , uw in enumerate (uvicorn_workers ):
893- uw : Union [Process , Thread ]
894- if isinstance (uw , Process ):
895- print (f"Uvicorn worker { i } : [{ uw .pid } ]" )
896- uw .join ()
987+
988+ while not self ._shutdown_event .is_set ():
989+ time .sleep (0.1 )
990+
991+ except KeyboardInterrupt :
992+ logger .info ("KeyboardInterrupt received. Initiating graceful shutdown." )
897993 finally :
898- print ("Shutting down LitServe" )
899- self ._transport .close ()
900- for iw in inference_workers :
901- iw : Process
902- iw .terminate ()
903- iw .join ()
904- manager .shutdown ()
994+ self ._perform_graceful_shutdown ()
995+
905996
906997 def _prepare_app_run (self , app : FastAPI ):
907998 # Add middleware to count active requests
@@ -920,17 +1011,17 @@ def _start_server(self, port, num_uvicorn_servers, log_level, sockets, uvicorn_w
9201011 app : FastAPI = copy .copy (self .app )
9211012
9221013 self ._prepare_app_run (app )
923- config = uvicorn .Config (app = app , host = "0.0.0.0" , port = port , log_level = log_level , ** kwargs )
924- if sys . platform == "win32" and num_uvicorn_servers > 1 :
925- logger . debug ( "Enable Windows explicit socket sharing..." )
926- # We make sure sockets is listening...
927- # It prevents further [WinError 10022]
928- for sock in sockets :
929- sock . listen ( config . backlog )
930- # We add worker to say unicorn to use a shared socket (win32)
931- # https://github.com/encode/uvicorn/pull/802
932- config . workers = num_uvicorn_servers
933- server = uvicorn .Server (config = config )
1014+ uvicorn_config = uvicorn .Config (
1015+ app = app ,
1016+ host = "0.0.0.0" ,
1017+ port = port ,
1018+ log_level = log_level ,
1019+ workers = num_uvicorn_servers if uvicorn_worker_type == "process" else 1 ,
1020+ timeout_graceful_shutdown = self . uvicorn_graceful_timeout ,
1021+ ** kwargs
1022+ )
1023+
1024+ server = uvicorn .Server (config = uvicorn_config )
9341025 if uvicorn_worker_type == "process" :
9351026 ctx = mp .get_context ("fork" )
9361027 w = ctx .Process (target = server .run , args = (sockets ,))
@@ -947,4 +1038,4 @@ def setup_auth(self):
9471038 return self .lit_api .authorize
9481039 if LIT_SERVER_API_KEY :
9491040 return api_key_auth
950- return no_auth
1041+ return no_auth
0 commit comments