@@ -636,7 +636,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
@@ -672,7 +672,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )
 
     def exist_prefill(self):
         """
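For reference, a minimal standalone sketch of the capture-gating rule added above, assuming a multi-step draft loop indexed by `substep` and a warm-up `is_dummy_run` pass; the capture sizes and loop bound below are hypothetical, not FastDeploy's actual capture loop:

    # Illustration only: the gating rule from this hunk, in isolation.
    capture_sizes = sorted([1, 2, 4, 8, 16], reverse=True)  # note 1: record large → small

    def should_capture(step_use_cudagraph: bool, draft_model_use_cudagraph: bool,
                       substep: int, is_dummy_run: bool) -> bool:
        # Note 2: during a dummy (warm-up) run, only substep 0 records a graph;
        # later substeps replay what substep 0 captured.
        return step_use_cudagraph and draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)

    for batch_size in capture_sizes:
        for substep in range(3):  # e.g. num_model_steps = 3
            print(batch_size, substep, should_capture(True, True, substep, is_dummy_run=True))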
@@ -774,7 +779,7 @@ def _post_process(self, sampled_token_ids):
             self.model_inputs["step_idx"],
         )
 
-    def _propose(self, step_use_cudagraph: bool = False):
+    def _propose(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False):
         """
         Main process for MTP inference.
         Args:
@@ -827,7 +832,9 @@ def _propose(self, step_use_cudagraph: bool = False):
             self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
             # Initialize forward meta data
-            self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+            self._initialize_forward_meta(
+                step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+            )
             self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
             # Padding inputs for cuda graph
@@ -852,9 +859,10 @@ def _propose(self, step_use_cudagraph: bool = False):
                 top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                 share_inputs=self.model_inputs,
             )
-
+            # Note(liuzichang):
+            # paddle.clone would raise CUDA error 700 in CUDA Graph mode; copy in place instead.
             if self.num_model_steps > 1:
-                self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+                self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
 
             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
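The clone-to-copy_ switch matters because `paddle.clone` allocates a fresh tensor on every call; inside a captured CUDA Graph that allocation address gets baked into the graph, and replay can then touch invalid memory (CUDA error 700 is cudaErrorIllegalAddress). A minimal sketch of the pattern, with the buffer name and size hypothetical:

    import paddle

    max_num_seqs = 64  # hypothetical capacity; src and dst shapes must match

    # Allocate once, outside any captured region (e.g. in __init__), so the
    # buffer keeps a stable address across graph replays.
    last_seq_lens = paddle.zeros([max_num_seqs], dtype="int32")

    def record_seq_lens(seq_lens_this_time: paddle.Tensor):
        # In-place, non-blocking copy into the fixed-address buffer;
        # no new allocation is recorded in the graph.
        last_seq_lens.copy_(seq_lens_this_time, False)
        # Avoid: last_seq_lens = paddle.clone(seq_lens_this_time)

    record_seq_lens(paddle.ones([max_num_seqs], dtype="int32"))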
@@ -1017,10 +1025,12 @@ def _extend_draft_token_with_ngram_match(self):
         self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
         self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
 
-    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
+    def _run_impl(
+        self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False
+    ):
         """Execute Draft Model"""
         self._prepare_inputs(full_hidden_states)
-        self._propose(step_use_cudagraph=step_use_cudagraph)
+        self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
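Taken together, the new flag threads from `_run_impl` through `_propose` into `_initialize_forward_meta`. A hedged end-to-end sketch of that call chain (class name and `num_model_steps` value are hypothetical; the real methods also run the model and update state):

    class DraftProposerSketch:
        num_model_steps = 3
        draft_model_use_cudagraph = True

        def _run_impl(self, step_use_cudagraph=False, is_dummy_run=False):
            self._propose(step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run)

        def _propose(self, step_use_cudagraph=False, is_dummy_run=False):
            for substep in range(self.num_model_steps):
                capture = self._initialize_forward_meta(step_use_cudagraph, is_dummy_run, substep)
                print(f"substep={substep} capture={capture}")

        def _initialize_forward_meta(self, step_use_cudagraph, is_dummy_run, substep):
            return step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)

    # Warm-up pass: only substep 0 records a CUDA Graph.
    DraftProposerSketch()._run_impl(step_use_cudagraph=True, is_dummy_run=True)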