88import weakref
99from abc import ABC , abstractmethod
1010from concurrent .futures import Future
11+ from dataclasses import dataclass
1112from threading import Thread
1213from typing import Any , Dict , List , Optional , Type , Union
1314
@@ -169,6 +170,31 @@ def add_lora(self, lora_request: LoRARequest) -> None:
169170 self .engine_core .add_lora (lora_request )
170171
171172
173+ @dataclass
174+ class BackgroundResources :
175+ """Used as a finalizer for clean shutdown, avoiding
176+ circular reference back to the client object."""
177+
178+ ctx : Union [zmq .Context , zmq .asyncio .Context ] = None
179+ output_socket : Union [zmq .Socket , zmq .asyncio .Socket ] = None
180+ input_socket : Union [zmq .Socket , zmq .asyncio .Socket ] = None
181+ proc_handle : Optional [BackgroundProcHandle ] = None
182+
183+ def __call__ (self ):
184+ """Clean up background resources."""
185+
186+ if self .proc_handle is not None :
187+ self .proc_handle .shutdown ()
188+ # ZMQ context termination can hang if the sockets
189+ # aren't explicitly closed first.
190+ if self .output_socket is not None :
191+ self .output_socket .close (linger = 0 )
192+ if self .input_socket is not None :
193+ self .input_socket .close (linger = 0 )
194+ if self .ctx is not None :
195+ self .ctx .destroy (linger = 0 )
196+
197+
172198class MPClient (EngineCoreClient ):
173199 """
174200 MPClient: base client for multi-proc EngineCore.
@@ -212,21 +238,22 @@ def sigusr1_handler(signum, frame):
212238 zmq .asyncio .Context () # type: ignore[attr-defined]
213239 if asyncio_mode else zmq .Context ()) # type: ignore[attr-defined]
214240
215- # Note(rob): shutdown function cannot be a bound method,
216- # else the gc cannot collect the object.
217- self ._finalizer = weakref .finalize (self , lambda x : x .destroy (linger = 0 ),
218- self .ctx )
241+ # This will ensure resources created so far are closed
242+ # when the client is garbage collected, even if an
243+ # exception is raised mid-construction.
244+ resources = BackgroundResources (ctx = self .ctx )
245+ self ._finalizer = weakref .finalize (self , resources )
219246
220247 # Paths and sockets for IPC.
221248 output_path = get_open_zmq_ipc_path ()
222249 input_path = get_open_zmq_ipc_path ()
223- self .output_socket = make_zmq_socket (self .ctx , output_path ,
224- zmq .constants .PULL )
225- self .input_socket = make_zmq_socket (self .ctx , input_path ,
226- zmq .constants .PUSH )
250+ resources .output_socket = make_zmq_socket (self .ctx , output_path ,
251+ zmq .constants .PULL )
252+ resources .input_socket = make_zmq_socket (self .ctx , input_path ,
253+ zmq .constants .PUSH )
227254
228255 # Start EngineCore in background process.
229- self .proc_handle = BackgroundProcHandle (
256+ resources .proc_handle = BackgroundProcHandle (
230257 input_path = input_path ,
231258 output_path = output_path ,
232259 process_name = "EngineCore" ,
@@ -237,13 +264,11 @@ def sigusr1_handler(signum, frame):
237264 "log_stats" : log_stats ,
238265 })
239266
267+ self .output_socket = resources .output_socket
268+ self .input_socket = resources .input_socket
240269 self .utility_results : Dict [int , AnyFuture ] = {}
241270
242271 def shutdown (self ):
243- """Clean up background resources."""
244- if hasattr (self , "proc_handle" ):
245- self .proc_handle .shutdown ()
246-
247272 self ._finalizer ()
248273
249274
0 commit comments