Skip to content

Commit 3c65afe

Browse files
committed
[PerfFix] Avoid separate thread for MP executor shm spin
Signed-off-by: Nick Hill <[email protected]>
1 parent 3758757 commit 3c65afe

File tree

8 files changed: 127 additions (+127) and 110 deletions (-110)

tests/v1/executor/test_executor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import asyncio
55
import os
66
from collections.abc import Callable
7+
from concurrent.futures import Future
78
from typing import Any
89

910
import pytest
@@ -27,7 +28,7 @@ def collective_rpc(
2728
kwargs: dict | None = None,
2829
non_block: bool = False,
2930
unique_reply_rank: int | None = None,
30-
) -> list[Any]:
31+
) -> Any | list[Any] | Future[Any | list[Any]]:
3132
# Drop marker to show that this was run
3233
with open(".marker", "w"):
3334
...

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -221,39 +221,24 @@ def update_finished_set(
221221

222222
def async_aggregate(
223223
self,
224-
output_futures: Sequence[Future[ModelRunnerOutput | None]],
224+
output_future: Future[Sequence[ModelRunnerOutput | None]],
225225
output_rank: int = 0,
226226
) -> Future[ModelRunnerOutput | None]:
227-
"""Takes a list of futures and returns a single future which resolves
228-
to the respective list of outputs."""
227+
"""Takes a future that resolves to a list of outputs and returns a future
228+
which resolves to a single aggregated output."""
229229
result_future: Future[ModelRunnerOutput | None] = Future()
230230

231-
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
232-
remaining = len(output_futures)
233-
234-
def make_callback(idx):
235-
def callback(fut):
236-
if result_future.done():
237-
return
238-
239-
try:
240-
outputs[idx] = fut.result()
241-
except CancelledError:
242-
result_future.cancel()
243-
except Exception as e:
244-
result_future.set_exception(e)
245-
246-
# this check assumes io_thread_pool uses a single thread
247-
nonlocal remaining
248-
remaining -= 1
249-
if not remaining:
250-
result_future.set_result(self.aggregate(outputs, output_rank))
251-
252-
return callback
253-
254-
for i, output_future in enumerate(output_futures):
255-
output_future.add_done_callback(make_callback(i))
256-
231+
def callback(fut):
232+
if result_future.done():
233+
return
234+
try:
235+
result_future.set_result(self.aggregate(fut.result(), output_rank))
236+
except CancelledError:
237+
result_future.cancel()
238+
except Exception as e:
239+
result_future.set_exception(e)
240+
241+
output_future.add_done_callback(callback)
257242
return result_future
258243

259244

vllm/v1/executor/abstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def collective_rpc(
171171
args: tuple = (),
172172
kwargs: dict | None = None,
173173
non_block: Literal[True] = True,
174-
) -> list[Future[_R]]:
174+
) -> Future[list[_R]]:
175175
pass
176176

177177
@abstractmethod
@@ -219,7 +219,7 @@ def sample_tokens(
219219

220220
def sample_tokens(
221221
self, grammar_output: GrammarOutput | None, non_block: bool = False
222-
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
222+
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
223223
output = self.collective_rpc( # type: ignore[call-overload]
224224
"sample_tokens", args=(grammar_output,), non_block=non_block
225225
)

vllm/v1/executor/multiproc_executor.py

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
import time
1010
import traceback
1111
import weakref
12+
from collections import deque
1213
from collections.abc import Callable
13-
from concurrent.futures import Future, ThreadPoolExecutor
14+
from concurrent.futures import Future
1415
from dataclasses import dataclass
1516
from enum import Enum, auto
1617
from functools import cached_property, partial
@@ -54,6 +55,28 @@
5455
logger = init_logger(__name__)
5556

5657

58+
class FutureWrapper(Future):
59+
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
60+
self.futures_queue = futures_queue
61+
super().__init__()
62+
63+
def result(self, timeout=None):
64+
if timeout is not None:
65+
raise RuntimeError("timeout not implemented")
66+
# Drain any futures ahead of us in the queue.
67+
while not self.done():
68+
future, get_response = self.futures_queue.pop()
69+
future.update_with_response(get_response)
70+
return super().result()
71+
72+
def update_with_response(self, get_response: Callable):
73+
try:
74+
response = get_response()
75+
self.set_result(response)
76+
except Exception as e:
77+
self.set_exception(e)
78+
79+
5780
class MultiprocExecutor(Executor):
5881
supports_pp: bool = True
5982

@@ -64,7 +87,6 @@ def _init_executor(self) -> None:
6487
self.is_failed = False
6588
self.shutdown_event = threading.Event()
6689
self.failure_callback: FailureCallback | None = None
67-
self.io_thread_pool: ThreadPoolExecutor | None = None
6890

6991
self.world_size = self.parallel_config.world_size
7092
tensor_parallel_size = self.parallel_config.tensor_parallel_size
@@ -132,12 +154,7 @@ def _init_executor(self) -> None:
132154
uw.death_writer.close()
133155
self._ensure_worker_termination([uw.proc for uw in unready_workers])
134156

135-
# Note: must use only 1 IO thread to keep dequeue sequence
136-
# from the response queue.
137-
# _async_aggregate_workers_output also assumes a single IO thread.
138-
self.io_thread_pool = ThreadPoolExecutor(
139-
max_workers=1, thread_name_prefix="mp_exec_io"
140-
)
157+
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
141158

142159
self.output_rank = self._get_output_rank()
143160
self.has_connector = self.vllm_config.kv_transfer_config is not None
@@ -195,14 +212,13 @@ def _execute_with_aggregation(
195212
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
196213
if not self.has_connector:
197214
# get output only from a single worker (output_rank)
198-
(output,) = self.collective_rpc(
215+
return self.collective_rpc(
199216
method,
200217
args=args,
201218
unique_reply_rank=self.output_rank,
202219
non_block=non_block,
203220
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
204221
)
205-
return output
206222

207223
# get output from all workers
208224
outputs = self.collective_rpc(
@@ -223,20 +239,21 @@ def execute_dummy_batch(self) -> None:
223239

224240
def take_draft_token_ids(self) -> DraftTokenIds | None:
225241
# OPTIMIZATION: Get output only from a single worker (output_rank)
226-
outputs = self.collective_rpc(
242+
return self.collective_rpc(
227243
"take_draft_token_ids", unique_reply_rank=self.output_rank
228244
)
229-
return outputs[0]
230245

231-
def collective_rpc(
246+
def collective_rpc( # type: ignore[override]
232247
self,
233248
method: str | Callable,
234249
timeout: float | None = None,
235250
args: tuple = (),
236251
kwargs: dict | None = None,
237252
non_block: bool = False,
238253
unique_reply_rank: int | None = None,
239-
) -> list[Any]:
254+
) -> Any | list[Any] | Future[Any | list[Any]]:
255+
"""Returns single result if unique_reply_rank is provided, otherwise list."""
256+
240257
if self.is_failed:
241258
raise RuntimeError("Executor failed.")
242259

@@ -246,63 +263,52 @@ def collective_rpc(
246263
# NOTE: If the args are heterogeneous, then we pack them into a list,
247264
# and unpack them in the method of every worker, because every worker
248265
# knows their own rank.
249-
try:
250-
if isinstance(method, str):
251-
send_method = method
252-
else:
253-
send_method = cloudpickle.dumps(
254-
method, protocol=pickle.HIGHEST_PROTOCOL
255-
)
256-
self.rpc_broadcast_mq.enqueue(
257-
(send_method, args, kwargs, unique_reply_rank)
258-
)
259266

260-
workers = (
261-
(self.workers[unique_reply_rank],)
262-
if unique_reply_rank is not None
263-
else self.workers
264-
)
265-
responses = []
267+
if isinstance(method, str):
268+
send_method = method
269+
else:
270+
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
271+
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
266272

267-
def get_response(
268-
w: WorkerProcHandle,
269-
dequeue_timeout: float | None = None,
270-
cancel_event: threading.Event | None = None,
271-
):
272-
status, result = w.worker_response_mq.dequeue(
273-
timeout=dequeue_timeout, cancel=cancel_event
274-
)
273+
workers = (
274+
(self.workers[unique_reply_rank],)
275+
if unique_reply_rank is not None
276+
else self.workers
277+
)
275278

276-
if status != WorkerProc.ResponseStatus.SUCCESS:
277-
raise RuntimeError(
278-
f"Worker failed with error '{result}', please check the"
279-
" stack trace above for the root cause"
280-
)
281-
return result
279+
shutdown_event = self.shutdown_event
282280

281+
def get_response():
282+
responses = []
283283
for w in workers:
284284
dequeue_timeout = (
285285
None if deadline is None else (deadline - time.monotonic())
286286
)
287-
288-
if self.io_thread_pool is not None:
289-
# We must consume worker_response_mq from a single thread.
290-
result = self.io_thread_pool.submit( # type: ignore
291-
get_response, w, dequeue_timeout, self.shutdown_event
287+
try:
288+
status, result = w.worker_response_mq.dequeue(
289+
timeout=dequeue_timeout, cancel=shutdown_event
292290
)
293-
if not non_block:
294-
result = result.result()
295-
elif not non_block:
296-
result = get_response(w, dequeue_timeout, self.shutdown_event)
297-
else:
291+
except TimeoutError as e:
292+
raise TimeoutError(f"RPC call to {method} timed out.") from e
293+
if status != WorkerProc.ResponseStatus.SUCCESS:
298294
raise RuntimeError(
299-
"non_block can only be used when max_concurrent_batches > 1"
295+
f"Worker failed with error '{result}', please check the"
296+
" stack trace above for the root cause"
300297
)
301298
responses.append(result)
299+
return responses[0] if unique_reply_rank is not None else responses
300+
301+
if non_block:
302+
future = FutureWrapper(self.futures_queue)
303+
self.futures_queue.appendleft((future, get_response))
304+
return future
305+
306+
# First drain any pending futures in the queue.
307+
while self.futures_queue:
308+
future, get_fut_response = self.futures_queue.pop()
309+
future.update_with_response(get_fut_response)
302310

303-
return responses
304-
except TimeoutError as e:
305-
raise TimeoutError(f"RPC call to {method} timed out.") from e
311+
return get_response()
306312

307313
@staticmethod
308314
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
@@ -348,9 +354,6 @@ def shutdown(self):
348354
self._ensure_worker_termination([w.proc for w in workers])
349355

350356
self.shutdown_event.set()
351-
if self.io_thread_pool is not None:
352-
self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
353-
del self.io_thread_pool
354357

355358
self.rpc_broadcast_mq = None
356359

vllm/v1/executor/ray_executor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,20 +441,19 @@ def sample_tokens( # type: ignore[override]
441441
assert self.kv_output_aggregator is not None
442442
if not non_block:
443443
# Block and get results from all workers
444-
outputs = [ref.get() for ref in refs]
445-
return self.kv_output_aggregator.aggregate(outputs)
444+
return self.kv_output_aggregator.aggregate(ray.get(refs))
446445

447446
# Return a future that will aggregate outputs from all workers
448447
return FutureWrapper(refs, self.kv_output_aggregator)
449448

450-
def collective_rpc(
449+
def collective_rpc( # type: ignore[override]
451450
self,
452451
method: str | Callable,
453452
timeout: float | None = None,
454453
args: tuple = (),
455454
kwargs: dict[str, Any] | None = None,
456455
non_block: bool = False,
457-
) -> list[Any]:
456+
) -> list[Any] | Future[list[Any]]:
458457
"""Runs the given method on all workers."""
459458
sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
460459
del method
@@ -470,7 +469,7 @@ def collective_rpc(
470469

471470
# Get the results of the ray workers.
472471
if non_block:
473-
return [FutureWrapper((output,)) for output in ray_worker_outputs]
472+
return FutureWrapper(ray_worker_outputs)
474473

475474
return ray.get(ray_worker_outputs, timeout=timeout)
476475

vllm/v1/executor/ray_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,19 +141,19 @@ class FutureWrapper(Future):
141141
the result() call. If not, only the first worker's output is returned.
142142
"""
143143

144-
def __init__(self, refs, aggregator: KVOutputAggregator | None = None):
144+
def __init__(self, ref, aggregator: KVOutputAggregator | None = None):
145145
super().__init__()
146-
self.refs = refs
146+
self.ref = ref
147147
self.aggregator = aggregator
148148

149149
def result(self, timeout=None):
150150
if timeout is not None:
151151
raise NotImplementedError("timeout is not supported")
152152

153+
outputs = ray.get(self.ref, timeout=timeout)
153154
if self.aggregator is None:
154-
return self.refs[0].get()
155+
return outputs[0]
155156

156-
outputs = [ref.get() for ref in self.refs]
157157
return self.aggregator.aggregate(outputs, output_rank=0)
158158

159159

0 commit comments

Comments
 (0)