@@ -452,6 +452,55 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
452452 update_cudagraph_capture_sizes (vllm_config ,
453453 new_cudagraph_capture_sizes )
454454
455+ # Modify the default capture_sizes for the num_speculative_tokens >= 1 scenario.
456+ # This is mainly because, when MTP is combined with Full Graph, the FIA operator needs to perform
457+ # padding operations to adapt to its actual_seq_lengths parameter. The padding operation
458+ # expands each request to the maximum request count under MTP; therefore the input shape must be
459+ # a multiple of (k+1), where k is num_speculative_tokens. Assuming k=2 and a large enough max_num_seqs, capture_sizes = [3, 9, 15, 21, 24, ...].
460+ # Consequently, the default captured graph shapes of Full Graph have to be modified to
461+ # accommodate this requirement of the FIA operator.
462+ # TODO: The capture-shape initialization for the full graph of the FIA operator adapted for MTP
463+ # would be more appropriately placed in the vLLM community code. This section will therefore be removed
464+ # once it has been migrated to the vLLM community.
465+ from vllm .config .compilation import CUDAGraphMode
466+ aclgraph_mode = vllm_config .compilation_config .cudagraph_mode
467+ if vllm_config .speculative_config is not None and \
468+ aclgraph_mode == CUDAGraphMode .FULL_DECODE_ONLY :
469+ num_speculative_tokens = vllm_config .speculative_config .num_speculative_tokens
470+ max_num_seqs = vllm_config .scheduler_config .max_num_seqs
471+ target_sizes = (num_speculative_tokens + 1 ) * max_num_seqs
472+ original_sizes , vllm_config .compilation_config .cudagraph_capture_sizes = \
473+ vllm_config .compilation_config .cudagraph_capture_sizes , None
474+ assert len (original_sizes ) > 0
475+ assert max_num_seqs > 0
476+ assert num_speculative_tokens > 0
477+ if num_speculative_tokens > 1 :
478+ if original_sizes [0 ] < (num_speculative_tokens + 1 ) * max_num_seqs :
479+ new_original_sizes = sorted (
480+ set (
481+ list (range (1 , min (10 , max_num_seqs + 1 ), 2 )) +
482+ list (range (8 , max_num_seqs + 1 , 4 ))))
483+ enlarged_sizes = [(num_speculative_tokens + 1 ) * sizes
484+ for sizes in new_original_sizes ]
485+ if enlarged_sizes [- 1 ] < target_sizes :
486+ enlarged_sizes .append (target_sizes )
487+ update_cudagraph_capture_sizes (vllm_config , enlarged_sizes )
488+ logger .info (
489+ "Adjusted ACL full graphs: %s → %s for speculative decoding" ,
490+ original_sizes , enlarged_sizes )
491+ else :
492+ vllm_config .compilation_config .cudagraph_capture_sizes = original_sizes
493+ if num_speculative_tokens == 1 :
494+ padding_sizes = original_sizes .copy ()
495+ if padding_sizes [- 1 ] < target_sizes :
496+ padding_sizes .append (target_sizes )
497+ update_cudagraph_capture_sizes (vllm_config , padding_sizes )
498+ logger .info (
499+ "Adjusted ACL full graphs: %s → %s for speculative decoding" ,
500+ original_sizes , padding_sizes )
501+ else :
502+ vllm_config .compilation_config .cudagraph_capture_sizes = original_sizes
503+
455504
456505def update_aclgraph_sizes (vllm_config : VllmConfig ) -> None :
457506 """Update ACL graph capture sizes based on hardware limitations"""
@@ -571,13 +620,24 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
571620 max_num_seqs = vllm_config .scheduler_config .max_num_seqs
572621 original_sizes , compilation_config .cudagraph_capture_sizes = \
573622 compilation_config .cudagraph_capture_sizes , None
623+ new_original_sizes = sorted (
624+ set (
625+ list (range (1 , min (10 , max_num_seqs + 1 ), 2 )) +
626+ list (range (8 , max_num_seqs + 1 , 4 ))))
627+ step = (len (new_original_sizes ) - 1 ) / (max_num_batch_sizes - 1 )
628+ indices = [round (i * step ) for i in range (max_num_batch_sizes )]
629+ indices [0 ], indices [- 1 ] = 0 , len (new_original_sizes ) - 1
630+ new_sampled_sizes = [new_original_sizes [i ] for i in indices ]
631+ target_sizes = (num_speculative_tokens + 1 ) * max_num_seqs
574632 assert len (original_sizes ) > 0
575633 if original_sizes [0 ] < (num_speculative_tokens + 1 ) * max_num_seqs :
576634 enlarged_sizes = [(num_speculative_tokens + 1 ) * size
577- for size in original_sizes ]
635+ for size in new_sampled_sizes ]
636+ if enlarged_sizes [- 1 ] < target_sizes :
637+ enlarged_sizes [- 1 ] = target_sizes
578638 update_cudagraph_capture_sizes (vllm_config , enlarged_sizes )
579639 logger .info (
580- "Adjusted ACL graphs: %s → %s for speculative decoding" ,
640+ "Adjusted PieceWise ACL graphs: %s → %s for speculative decoding" ,
581641 original_sizes , enlarged_sizes )
582642 else :
583643 compilation_config .cudagraph_capture_sizes = original_sizes
0 commit comments