99import time
1010import traceback
1111import weakref
12+ from collections import deque
1213from collections .abc import Callable
13- from concurrent .futures import Future , ThreadPoolExecutor
14+ from concurrent .futures import Future
1415from dataclasses import dataclass
1516from enum import Enum , auto
1617from functools import cached_property , partial
# Module-level logger, created via the project-wide init_logger helper.
logger = init_logger(__name__)


class FutureWrapper(Future):
    """A Future tied to a shared deque of pending (future, response-getter) pairs.

    Entries are enqueued with appendleft() and consumed with pop(), so the
    queue behaves FIFO: asking any wrapper for its result first resolves every
    entry queued before it, keeping response consumption in issue order.
    """

    def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
        self.futures_queue = futures_queue
        super().__init__()

    def result(self, timeout=None):
        """Block until this future is resolved, draining earlier entries first."""
        if timeout is not None:
            raise RuntimeError("timeout not implemented")
        # Resolve queued futures, oldest first, until this one is filled in
        # (possibly by an earlier drain performed on another wrapper).
        while not self.done():
            entry = self.futures_queue.pop()
            entry[0].update_with_response(entry[1])
        return super().result()

    def update_with_response(self, get_response: Callable):
        """Resolve this future with the value (or exception) of get_response()."""
        try:
            self.set_result(get_response())
        except Exception as e:
            self.set_exception(e)

79+
class MultiprocExecutor(Executor):
    """Executor that drives a set of worker processes over message queues.

    RPCs are broadcast via ``rpc_broadcast_mq`` and responses collected from
    per-worker response queues; non-blocking calls return futures that are
    resolved lazily in FIFO order.
    """

    # This executor supports pipeline parallelism.
    supports_pp: bool = True
@@ -64,7 +87,6 @@ def _init_executor(self) -> None:
        # Set once a worker failure is observed; collective_rpc refuses to
        # issue new calls after this flips to True.
        self.is_failed = False
        # Signalled on shutdown; used to cancel blocking response dequeues.
        self.shutdown_event = threading.Event()
        # Optional callback invoked on worker failure — registered elsewhere.
        self.failure_callback: FailureCallback | None = None

        self.world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +154,7 @@ def _init_executor(self) -> None:
132154 uw .death_writer .close ()
133155 self ._ensure_worker_termination ([uw .proc for uw in unready_workers ])
134156
135- # Note: must use only 1 IO thread to keep dequeue sequence
136- # from the response queue.
137- # _async_aggregate_workers_output also assumes a single IO thread.
138- self .io_thread_pool = ThreadPoolExecutor (
139- max_workers = 1 , thread_name_prefix = "mp_exec_io"
140- )
        # Pending (future, response-getter) pairs for non-blocking RPCs.
        # Entries are appended on the left and popped from the right, so
        # FutureWrapper.result() resolves them in FIFO (issue) order.
        self.futures_queue = deque[tuple[FutureWrapper, Callable]]()

        # Rank whose reply is used when a single response suffices.
        self.output_rank = self._get_output_rank()
        # True when a KV transfer connector is configured.
        self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +212,13 @@ def _execute_with_aggregation(
195212 ) -> ModelRunnerOutput | None | Future [ModelRunnerOutput | None ]:
        # Fast path: without a KV connector only the output rank's result is
        # needed, so ask that single worker and return its reply (or future)
        # directly.
        if not self.has_connector:
            # get output only from a single worker (output_rank)
            return self.collective_rpc(
                method,
                args=args,
                unique_reply_rank=self.output_rank,
                non_block=non_block,
                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
            )

207223 # get output from all workers
208224 outputs = self .collective_rpc (
@@ -223,20 +239,21 @@ def execute_dummy_batch(self) -> None:
223239
224240 def take_draft_token_ids (self ) -> DraftTokenIds | None :
225241 # OPTIMIZATION: Get output only from a single worker (output_rank)
226- outputs = self .collective_rpc (
242+ return self .collective_rpc (
227243 "take_draft_token_ids" , unique_reply_rank = self .output_rank
228244 )
229- return outputs [0 ]
230245
    def collective_rpc(  # type: ignore[override]
        self,
        method: str | Callable,
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict | None = None,
        non_block: bool = False,
        unique_reply_rank: int | None = None,
    ) -> Any | list[Any] | Future[Any | list[Any]]:
        """Returns single result if unique_reply_rank is provided, otherwise list.

        Args:
            method: Name of a worker method, or a callable shipped to the
                workers via cloudpickle.
            timeout: Overall deadline (seconds) for collecting responses.
            args: Positional arguments forwarded to the method.
            kwargs: Keyword arguments forwarded to the method.
            non_block: If True, return a FutureWrapper that is resolved
                lazily instead of blocking for the responses.
            unique_reply_rank: If set, only this rank replies.

        Raises:
            RuntimeError: If the executor has already failed, or a worker
                reports a failed response.
        """

        # Fail fast once a worker failure has been recorded.
        if self.is_failed:
            raise RuntimeError("Executor failed.")
@@ -246,63 +263,52 @@ def collective_rpc(
246263 # NOTE: If the args are heterogeneous, then we pack them into a list,
247264 # and unpack them in the method of every worker, because every worker
248265 # knows their own rank.
249- try :
250- if isinstance (method , str ):
251- send_method = method
252- else :
253- send_method = cloudpickle .dumps (
254- method , protocol = pickle .HIGHEST_PROTOCOL
255- )
256- self .rpc_broadcast_mq .enqueue (
257- (send_method , args , kwargs , unique_reply_rank )
258- )
259266
260- workers = (
261- (self .workers [unique_reply_rank ],)
262- if unique_reply_rank is not None
263- else self .workers
264- )
265- responses = []
267+ if isinstance (method , str ):
268+ send_method = method
269+ else :
270+ send_method = cloudpickle .dumps (method , protocol = pickle .HIGHEST_PROTOCOL )
271+ self .rpc_broadcast_mq .enqueue ((send_method , args , kwargs , unique_reply_rank ))
266272
267- def get_response (
268- w : WorkerProcHandle ,
269- dequeue_timeout : float | None = None ,
270- cancel_event : threading .Event | None = None ,
271- ):
272- status , result = w .worker_response_mq .dequeue (
273- timeout = dequeue_timeout , cancel = cancel_event
274- )
273+ workers = (
274+ (self .workers [unique_reply_rank ],)
275+ if unique_reply_rank is not None
276+ else self .workers
277+ )
275278
276- if status != WorkerProc .ResponseStatus .SUCCESS :
277- raise RuntimeError (
278- f"Worker failed with error '{ result } ', please check the"
279- " stack trace above for the root cause"
280- )
281- return result
279+ shutdown_event = self .shutdown_event
282280
281+ def get_response ():
282+ responses = []
283283 for w in workers :
284284 dequeue_timeout = (
285285 None if deadline is None else (deadline - time .monotonic ())
286286 )
287-
288- if self .io_thread_pool is not None :
289- # We must consume worker_response_mq from a single thread.
290- result = self .io_thread_pool .submit ( # type: ignore
291- get_response , w , dequeue_timeout , self .shutdown_event
287+ try :
288+ status , result = w .worker_response_mq .dequeue (
289+ timeout = dequeue_timeout , cancel = shutdown_event
292290 )
293- if not non_block :
294- result = result .result ()
295- elif not non_block :
296- result = get_response (w , dequeue_timeout , self .shutdown_event )
297- else :
291+ except TimeoutError as e :
292+ raise TimeoutError (f"RPC call to { method } timed out." ) from e
293+ if status != WorkerProc .ResponseStatus .SUCCESS :
298294 raise RuntimeError (
299- "non_block can only be used when max_concurrent_batches > 1"
295+ f"Worker failed with error '{ result } ', please check the"
296+ " stack trace above for the root cause"
300297 )
301298 responses .append (result )
299+ return responses [0 ] if unique_reply_rank is not None else responses
300+
301+ if non_block :
302+ future = FutureWrapper (self .futures_queue )
303+ self .futures_queue .appendleft ((future , get_response ))
304+ return future
305+
306+ # First drain any pending futures in the queue.
307+ while self .futures_queue :
308+ future , get_fut_response = self .futures_queue .pop ()
309+ future .update_with_response (get_fut_response )
302310
303- return responses
304- except TimeoutError as e :
305- raise TimeoutError (f"RPC call to { method } timed out." ) from e
311+ return get_response ()
306312
307313 @staticmethod
308314 def _ensure_worker_termination (worker_procs : list [BaseProcess ]):
@@ -348,9 +354,6 @@ def shutdown(self):
        # Make sure every spawned worker process has actually exited.
        self._ensure_worker_termination([w.proc for w in workers])

        # Unblock any response-queue dequeues waiting on this cancel event.
        self.shutdown_event.set()

        # Drop the broadcast queue so no further RPCs can be issued.
        self.rpc_broadcast_mq = None
356359
0 commit comments