@@ -422,9 +422,12 @@ def __init__(
422422 # cuda event to synchronize use of reused CPU tensors between steps
423423 # when async scheduling is enabled.
424424 self .prepare_inputs_event : torch .cuda .Event | None = None
425+ self .mm_preproc_event : torch .cuda .Event | None = None
425426 if self .use_async_scheduling :
426427 self .async_output_copy_stream = torch .cuda .Stream ()
427428 self .prepare_inputs_event = torch .cuda .Event ()
429+ if self .supports_mm_inputs or self .model_config .is_encoder_decoder :
430+ self .mm_preproc_event = torch .cuda .Event ()
428431
429432 # self.cudagraph_batch_sizes sorts in ascending order.
430433 if (
@@ -2462,19 +2465,19 @@ def _bookkeeping_sync(
24622465 )
24632466
24642467 @contextmanager
2465- def synchronize_input_prep (self ):
2466- if self . prepare_inputs_event is None :
2468+ def synchronize_async_cpu (self , event : torch . cuda . Event | None ):  # event is None when the corresponding async path was not set up in __init__
2469+ if event is None :  # nothing to synchronize against: run the wrapped body as-is
24672470 yield
24682471 return
24692472
24702473 # Ensure prior step has finished with reused CPU tensors.
24712474 # This is required in the async scheduling case because
24722475 # the CPU->GPU transfer happens async.
2473- self . prepare_inputs_event .synchronize ()
2476+ event .synchronize ()  # wait for the work captured by the previous record() on this event
24742477 try :
24752478 yield
24762479 finally :
2477- self . prepare_inputs_event .record ()
2480+ event .record ()  # runs even if the wrapped body raised, so the next step has a fresh marker to wait on
24782481
24792482 def _model_forward (
24802483 self ,
@@ -2521,7 +2524,7 @@ def execute_model(
25212524 )
25222525 num_scheduled_tokens = scheduler_output .total_num_scheduled_tokens
25232526 with record_function_or_nullcontext ("gpu_model_runner: preprocess" ):
2524- with self .synchronize_input_prep ( ):
2527+ with self .synchronize_async_cpu ( self . prepare_inputs_event ):
25252528 # Update persistent batch states.
25262529 self ._update_states (scheduler_output )
25272530
@@ -2602,16 +2605,17 @@ def execute_model(
26022605 scheduler_output .total_num_scheduled_tokens
26032606 )
26042607
2605- (
2606- input_ids ,
2607- inputs_embeds ,
2608- positions ,
2609- intermediate_tensors ,
2610- model_kwargs ,
2611- ec_connector_output ,
2612- ) = self ._preprocess (
2613- scheduler_output , num_input_tokens , intermediate_tensors
2614- )
2608+ with self .synchronize_async_cpu (self .mm_preproc_event ):
2609+ (
2610+ input_ids ,
2611+ inputs_embeds ,
2612+ positions ,
2613+ intermediate_tensors ,
2614+ model_kwargs ,
2615+ ec_connector_output ,
2616+ ) = self ._preprocess (
2617+ scheduler_output , num_input_tokens , intermediate_tensors
2618+ )
26152619
26162620 uniform_decode = (
26172621 max_num_scheduled_tokens == self .uniform_decode_query_len
0 commit comments