Skip to content

Commit 467aef5

Browse files
committed
test async mm fix
Signed-off-by: Nick Hill <[email protected]>
1 parent 5efba46 commit 467aef5

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -422,9 +422,12 @@ def __init__(
422422
# cuda event to synchronize use of reused CPU tensors between steps
423423
# when async scheduling is enabled.
424424
self.prepare_inputs_event: torch.cuda.Event | None = None
425+
self.mm_preproc_event: torch.cuda.Event | None = None
425426
if self.use_async_scheduling:
426427
self.async_output_copy_stream = torch.cuda.Stream()
427428
self.prepare_inputs_event = torch.cuda.Event()
429+
if self.supports_mm_inputs or self.model_config.is_encoder_decoder:
430+
self.mm_preproc_event = torch.cuda.Event()
428431

429432
# self.cudagraph_batch_sizes sorts in ascending order.
430433
if (
@@ -2462,19 +2465,19 @@ def _bookkeeping_sync(
24622465
)
24632466

24642467
@contextmanager
2465-
def synchronize_input_prep(self):
2466-
if self.prepare_inputs_event is None:
2468+
def synchronize_async_cpu(self, event: torch.cuda.Event):
2469+
if event is None:
24672470
yield
24682471
return
24692472

24702473
# Ensure prior step has finished with reused CPU tensors.
24712474
# This is required in the async scheduling case because
24722475
# the CPU->GPU transfer happens async.
2473-
self.prepare_inputs_event.synchronize()
2476+
event.synchronize()
24742477
try:
24752478
yield
24762479
finally:
2477-
self.prepare_inputs_event.record()
2480+
event.record()
24782481

24792482
def _model_forward(
24802483
self,
@@ -2521,7 +2524,7 @@ def execute_model(
25212524
)
25222525
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
25232526
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
2524-
with self.synchronize_input_prep():
2527+
with self.synchronize_async_cpu(self.prepare_inputs_event):
25252528
# Update persistent batch states.
25262529
self._update_states(scheduler_output)
25272530

@@ -2602,16 +2605,17 @@ def execute_model(
26022605
scheduler_output.total_num_scheduled_tokens
26032606
)
26042607

2605-
(
2606-
input_ids,
2607-
inputs_embeds,
2608-
positions,
2609-
intermediate_tensors,
2610-
model_kwargs,
2611-
ec_connector_output,
2612-
) = self._preprocess(
2613-
scheduler_output, num_input_tokens, intermediate_tensors
2614-
)
2608+
with self.synchronize_async_cpu(self.mm_preproc_event):
2609+
(
2610+
input_ids,
2611+
inputs_embeds,
2612+
positions,
2613+
intermediate_tensors,
2614+
model_kwargs,
2615+
ec_connector_output,
2616+
) = self._preprocess(
2617+
scheduler_output, num_input_tokens, intermediate_tensors
2618+
)
26152619

26162620
uniform_decode = (
26172621
max_num_scheduled_tokens == self.uniform_decode_query_len

0 commit comments

Comments
 (0)