
Commit 07a3814

[main] Support MTP shape with ACLgraph
Signed-off-by: lilinsiman <[email protected]>
1 parent 755b635 commit 07a3814

File tree

vllm_ascend/utils.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +62, -12


vllm_ascend/utils.py

Lines changed: 62 additions & 2 deletions
@@ -452,6 +452,55 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     update_cudagraph_capture_sizes(vllm_config,
                                    new_cudagraph_capture_sizes)
 
+    # Modify the default capture_sizes for the num_speculative_tokens >= 1 scenario.
+    # When MTP is combined with Full Graph, the FIA operator must pad its input to
+    # match its actual_seq_lengths parameter. The padding expands each request to the
+    # maximum request count under MTP, so the input shape must be a multiple of the
+    # MTP layer count (k+1). Assuming k=2, capture_sizes = [3, 6, 9, 12, 15, 18, ...].
+    # Consequently, the default captured graph shapes of Full Graph must be modified
+    # to accommodate this requirement of the FIA operator.
+    # TODO: The initialization of capture shapes for the MTP-adapted FIA full graph
+    # belongs in the vLLM community, so this section will be removed once it has
+    # been migrated there.
+    from vllm.config.compilation import CUDAGraphMode
+    aclgraph_mode = vllm_config.compilation_config.cudagraph_mode
+    if vllm_config.speculative_config is not None and \
+            aclgraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
+        original_sizes, vllm_config.compilation_config.cudagraph_capture_sizes = \
+            vllm_config.compilation_config.cudagraph_capture_sizes, None
+        assert len(original_sizes) > 0
+        assert max_num_seqs > 0
+        assert num_speculative_tokens > 0
+        if num_speculative_tokens > 1:
+            if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
+                new_original_sizes = sorted(
+                    set(
+                        list(range(1, min(10, max_num_seqs + 1), 2)) +
+                        list(range(8, max_num_seqs + 1, 4))))
+                enlarged_sizes = [(num_speculative_tokens + 1) * sizes
+                                  for sizes in new_original_sizes]
+                if enlarged_sizes[-1] < target_sizes:
+                    enlarged_sizes.append(target_sizes)
+                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
+                logger.info(
+                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
+                    original_sizes, enlarged_sizes)
+            else:
+                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
+        if num_speculative_tokens == 1:
+            padding_sizes = original_sizes.copy()
+            if padding_sizes[-1] < target_sizes:
+                padding_sizes.append(target_sizes)
+                update_cudagraph_capture_sizes(vllm_config, padding_sizes)
+                logger.info(
+                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
+                    original_sizes, padding_sizes)
+            else:
+                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
 
 def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
     """Update ACL graph capture sizes based on hardware limitations"""
@@ -571,13 +620,24 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
         max_num_seqs = vllm_config.scheduler_config.max_num_seqs
         original_sizes, compilation_config.cudagraph_capture_sizes = \
             compilation_config.cudagraph_capture_sizes, None
+        new_original_sizes = sorted(
+            set(
+                list(range(1, min(10, max_num_seqs + 1), 2)) +
+                list(range(8, max_num_seqs + 1, 4))))
+        step = (len(new_original_sizes) - 1) / (max_num_batch_sizes - 1)
+        indices = [round(i * step) for i in range(max_num_batch_sizes)]
+        indices[0], indices[-1] = 0, len(new_original_sizes) - 1
+        new_sampled_sizes = [new_original_sizes[i] for i in indices]
+        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
         assert len(original_sizes) > 0
         if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
             enlarged_sizes = [(num_speculative_tokens + 1) * size
-                              for size in original_sizes]
+                              for size in new_sampled_sizes]
+            if enlarged_sizes[-1] < target_sizes:
+                enlarged_sizes[-1] = target_sizes
             update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
             logger.info(
-                "Adjusted ACL graphs: %s → %s for speculative decoding",
+                "Adjusted PieceWise ACL graphs: %s → %s for speculative decoding",
                 original_sizes, enlarged_sizes)
         else:
             compilation_config.cudagraph_capture_sizes = original_sizes
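
The piecewise path above thins the candidate list down to max_num_batch_sizes entries by sampling indices at an even stride before scaling them. A short sketch of just that sampling step (sample_evenly is a hypothetical name; the commit computes it inline):

# Hypothetical helper mirroring the even-stride index sampling above.
def sample_evenly(candidates: list[int], max_num_batch_sizes: int) -> list[int]:
    # Guard added for the sketch; the commit assumes the list is long enough.
    if len(candidates) <= max_num_batch_sizes:
        return candidates
    # Fractional stride across the candidate indices.
    step = (len(candidates) - 1) / (max_num_batch_sizes - 1)
    indices = [round(i * step) for i in range(max_num_batch_sizes)]
    # Pin the endpoints so the smallest and largest candidates survive rounding.
    indices[0], indices[-1] = 0, len(candidates) - 1
    return [candidates[i] for i in indices]

candidates = [1, 3, 5, 7, 8, 9, 12, 16, 20, 24, 28, 32]
print(sample_evenly(candidates, 5))  # [1, 7, 12, 20, 32]

Each sampled size is then multiplied by (num_speculative_tokens + 1), and the last entry is raised to target_sizes when it falls short, so the largest piecewise graph still covers a full MTP batch.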

vllm_ascend/worker/model_runner_v1.py

Lines changed: 0 additions & 10 deletions
@@ -4010,16 +4010,6 @@ def _capture_model(self):
             and x >= self.uniform_decode_query_len
         ]
         compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
-        # TODO: refactor this when vLLM supports mtp>1
-        if not all(x % self.uniform_decode_query_len == 0
-                   for x in decode_cudagraph_batch_sizes):
-            raise ValueError(
-                "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
-                f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
-                f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
-                f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
-                "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48]."
-            )
         self._capture_aclgraphs(
             compilation_cases=compilation_cases_decode,
             aclgraph_runtime_mode=CUDAGraphMode.FULL,
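
The deleted ValueError appears to be redundant now that vllm_ascend/utils.py generates full-graph capture sizes as multiples of (num_speculative_tokens + 1) by construction. A quick check of that invariant, reusing the k=2, max_num_seqs=16 sizes from the earlier sketch (all names here are illustrative):

# The invariant the removed check used to enforce at capture time.
num_speculative_tokens = 2                      # MTP depth k
uniform_decode_query_len = num_speculative_tokens + 1
max_num_seqs = 16
capture_sizes = [3, 9, 15, 21, 24, 27, 36, 48]  # from the earlier sketch

# Every decode graph size is a multiple of (k + 1), and the largest does not
# exceed max_num_seqs * (k + 1), which is what the removed error message asked for.
assert all(s % uniform_decode_query_len == 0 for s in capture_sizes)
assert capture_sizes[-1] <= max_num_seqs * uniform_decode_query_len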
