@@ -636,7 +636,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
         self.model_inputs["not_need_stop"][0] = True
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
+    def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_run: bool = False, substep: int = 0):
         """
         Initialize forward meta and attention meta data
         """
@@ -672,7 +672,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
-        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.draft_model_use_cudagraph
+        # Notes(liuzichang):
+        # 1. CUDA Graph capture sizes must be recorded in descending order (large → small).
+        # 2. In multi-step execution, only the first step should be captured.
+        self.forward_meta.step_use_cudagraph = (
+            step_use_cudagraph and self.draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)
+        )
 
     def exist_prefill(self):
         """
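For readers skimming the hunk above: the new gate enables graph capture/replay only when both the caller and the draft model allow it, and it additionally skips capture on later substeps of a dummy (warm-up) run. The standalone helper below is purely illustrative; its name and the assert-based usage are not part of this patch, it merely mirrors the same boolean expression so the substep / is_dummy_run interaction is easy to check in isolation.

```python
def should_use_cudagraph(
    step_use_cudagraph: bool,
    draft_model_use_cudagraph: bool,
    substep: int,
    is_dummy_run: bool,
) -> bool:
    # Same predicate as the one assigned to forward_meta.step_use_cudagraph above:
    # capture only when both flags are on, and never on substep > 0 of a dummy run.
    return step_use_cudagraph and draft_model_use_cudagraph and not (substep > 0 and is_dummy_run)


# During a dummy capture run, only the first substep is eligible for capture:
assert should_use_cudagraph(True, True, substep=0, is_dummy_run=True)
assert not should_use_cudagraph(True, True, substep=1, is_dummy_run=True)
# Normal (non-dummy) multi-step decoding is unaffected by the substep index:
assert should_use_cudagraph(True, True, substep=1, is_dummy_run=False)
```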
@@ -827,7 +832,9 @@ def _propose(self, step_use_cudagraph: bool = False):
            self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
 
            # Initialize forward meta data
-           self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
+           self._initialize_forward_meta(
+               step_use_cudagraph=step_use_cudagraph, is_dummy_run=is_dummy_run, substep=substep
+           )
            self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
 
            # Padding inputs for cuda graph
@@ -852,9 +859,10 @@ def _propose(self, step_use_cudagraph: bool = False):
                top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                share_inputs=self.model_inputs,
            )
-
+           # Note(liuzichang):
+           # paddle.clone would raise error 700 in cudaGraph mode
            if self.num_model_steps > 1:
-               self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])
+               self.last_seq_lens_this_time.copy_(self.model_inputs["seq_lens_this_time"], False)
 
            model_output = self.model(
                ids_remove_padding=self.model_inputs["ids_remove_padding"],
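The clone → copy_ change in the last hunk follows the usual CUDA Graph discipline: buffers touched inside a captured region should be allocated once up front and then updated in place, whereas `paddle.clone` allocates a fresh tensor on every call (the author's note reports this surfacing as error 700 under capture). A minimal sketch of that pattern follows; the buffer size, the helper name, and the standalone setup are illustrative assumptions rather than code from this PR, and only the `copy_(..., False)` call mirrors the patch.

```python
import paddle

MAX_NUM_SEQS = 8  # illustrative capacity, not taken from the PR

# Allocate the snapshot buffer once, outside any CUDA Graph capture/replay.
last_seq_lens_this_time = paddle.zeros([MAX_NUM_SEQS], dtype="int32")


def snapshot_seq_lens(model_inputs: dict) -> None:
    # In-place copy keeps the destination address stable across graph replays,
    # unlike paddle.clone, which would allocate a new tensor every step.
    last_seq_lens_this_time.copy_(model_inputs["seq_lens_this_time"], False)


model_inputs = {"seq_lens_this_time": paddle.to_tensor([3, 1, 1, 0, 0, 0, 0, 0], dtype="int32")}
snapshot_seq_lens(model_inputs)
```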