66 changes: 63 additions & 3 deletions vllm_ascend/utils.py
@@ -452,6 +452,55 @@ def update_default_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    update_cudagraph_capture_sizes(vllm_config,
                                   new_cudagraph_capture_sizes)

    # Modify the default capture_sizes for the num_speculative_tokens >= 1 scenario.
    # When MTP is combined with the full graph, the FIA operator must pad its inputs
    # to match its actual_seq_lengths parameter. Padding expands each request to the
    # maximum request count under MTP, so the input shape must be a multiple of the
    # MTP layer count (k + 1). Assuming k=2, capture_sizes = [3, 6, 9, 12, 15, 18, ...].
    # Consequently, the default captured graph shapes of the full graph must be
    # modified to meet this requirement of the FIA operator.
    # TODO: Initializing the capture shapes for the FIA-adapted MTP full graph
    # belongs in upstream vLLM; this section will be removed once it has been
    # migrated to the vLLM community.
    from vllm.config.compilation import CUDAGraphMode
    aclgraph_mode = vllm_config.compilation_config.cudagraph_mode
    if vllm_config.speculative_config is not None and \
            aclgraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
        original_sizes, vllm_config.compilation_config.cudagraph_capture_sizes = \
            vllm_config.compilation_config.cudagraph_capture_sizes, None
        assert len(original_sizes) > 0

Review comment: Lack of clear error messages when assertions fail.
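To make the reviewer's point concrete, here is a minimal sketch (not part of the patch) of what message-bearing asserts could look like; the sample values standing in for the real VllmConfig fields are hypothetical.

    # Hypothetical stand-ins for the real config values.
    original_sizes = [512, 256, 128]
    max_num_seqs = 16
    num_speculative_tokens = 2

    # Asserts with messages, so a failure says which config value is invalid.
    assert len(original_sizes) > 0, "cudagraph_capture_sizes must be non-empty"
    assert max_num_seqs > 0, f"max_num_seqs must be positive, got {max_num_seqs}"
    assert num_speculative_tokens > 0, (
        f"num_speculative_tokens must be positive, got {num_speculative_tokens}")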
        assert max_num_seqs > 0
        assert num_speculative_tokens > 0
        if num_speculative_tokens > 1:
            if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
                new_original_sizes = sorted(
                    set(
                        list(range(1, min(10, max_num_seqs + 1), 2)) +
                        list(range(8, max_num_seqs + 1, 4))))

Review comment: Hardcoded range parameters (1, 10, 2 / 8, 4) lack explanation.

Review comment: The case where max_num_seqs < 8 was not considered.
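To ground both comments, a small sketch (not part of the patch) reproduces the hardcoded size ladder for a few values of max_num_seqs, including the max_num_seqs < 8 case the reviewer flags:

    def base_capture_sizes(max_num_seqs: int) -> list[int]:
        # Odd sizes up to min(9, max_num_seqs), plus steps of 4 starting at 8.
        return sorted(
            set(
                list(range(1, min(10, max_num_seqs + 1), 2)) +
                list(range(8, max_num_seqs + 1, 4))))

    print(base_capture_sizes(24))  # [1, 3, 5, 7, 8, 9, 12, 16, 20, 24]
    print(base_capture_sizes(16))  # [1, 3, 5, 7, 8, 9, 12, 16]
    print(base_capture_sizes(4))   # [1, 3] -- max_num_seqs itself is never captured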
                enlarged_sizes = [(num_speculative_tokens + 1) * size
                                  for size in new_original_sizes]
                if enlarged_sizes[-1] < target_sizes:
                    enlarged_sizes.append(target_sizes)
                update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
                logger.info(
                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
                    original_sizes, enlarged_sizes)
            else:
                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
        elif num_speculative_tokens == 1:
            padding_sizes = original_sizes.copy()
            if padding_sizes[-1] < target_sizes:
                padding_sizes.append(target_sizes)
                update_cudagraph_capture_sizes(vllm_config, padding_sizes)
                logger.info(
                    "Adjusted ACL full graphs: %s → %s for speculative decoding",
                    original_sizes, padding_sizes)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comparison before and after the change is not intuitive enough; descriptions of the sizes before and after should be provided.

            else:
                vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
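A self-contained walk-through of the FULL_DECODE_ONLY branch above, using hypothetical values num_speculative_tokens=2 (k=2) and max_num_seqs=16 in place of a real VllmConfig:

    num_speculative_tokens = 2                                  # k
    max_num_seqs = 16
    target_sizes = (num_speculative_tokens + 1) * max_num_seqs  # 48

    base = sorted(
        set(
            list(range(1, min(10, max_num_seqs + 1), 2)) +
            list(range(8, max_num_seqs + 1, 4))))  # [1, 3, 5, 7, 8, 9, 12, 16]
    enlarged = [(num_speculative_tokens + 1) * s for s in base]
    # [3, 9, 15, 21, 24, 27, 36, 48]: every size is a multiple of k + 1 = 3,
    # so the FIA operator's padded inputs always land on a captured graph.
    if enlarged[-1] < target_sizes:
        enlarged.append(target_sizes)  # not needed here: 48 == 48
    print(enlarged)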


def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
    """Update ACL graph capture sizes based on hardware limitations"""

@@ -564,20 +613,31 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

    # Default or user-defined cudagraph_capture_sizes may not consider the
    # num_speculative_tokens > 1 scenario: the maximum size cudagraph_capture_sizes[0]
    # should be greater than or equal to
-   # (num_speculative_tokens+1)*max_num_seqs, otherwise the draft model will run in eager mode
+   # (num_speculative_tokens+1) * max_num_seqs, otherwise the draft model will run in eager mode
    if vllm_config.speculative_config is not None and \
            vllm_config.speculative_config.num_speculative_tokens > 1:
        num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        original_sizes, compilation_config.cudagraph_capture_sizes = \
            compilation_config.cudagraph_capture_sizes, None
        new_original_sizes = sorted(
            set(
                list(range(1, min(10, max_num_seqs + 1), 2)) +
                list(range(8, max_num_seqs + 1, 4))))
        step = (len(new_original_sizes) - 1) / (max_num_batch_sizes - 1)
        indices = [round(i * step) for i in range(max_num_batch_sizes)]
        indices[0], indices[-1] = 0, len(new_original_sizes) - 1
        new_sampled_sizes = [new_original_sizes[i] for i in indices]
        target_sizes = (num_speculative_tokens + 1) * max_num_seqs
        assert len(original_sizes) > 0
        if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
            enlarged_sizes = [(num_speculative_tokens + 1) * size
-                             for size in original_sizes]
+                             for size in new_sampled_sizes]
            if enlarged_sizes[-1] < target_sizes:
                enlarged_sizes[-1] = target_sizes
            update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
            logger.info(
-               "Adjusted ACL graphs: %s → %s for speculative decoding",
+               "Adjusted PieceWise ACL graphs: %s → %s for speculative decoding",
                original_sizes, enlarged_sizes)
        else:
            compilation_config.cudagraph_capture_sizes = original_sizes
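The step/indices logic above down-samples the size ladder to at most max_num_batch_sizes entries while keeping both endpoints; a sketch with hypothetical values:

    new_original_sizes = [1, 3, 5, 7, 8, 9, 12, 16, 20, 24]  # ladder for max_num_seqs=24
    max_num_batch_sizes = 4                                   # hypothetical hardware limit

    step = (len(new_original_sizes) - 1) / (max_num_batch_sizes - 1)  # 3.0
    indices = [round(i * step) for i in range(max_num_batch_sizes)]   # [0, 3, 6, 9]
    indices[0], indices[-1] = 0, len(new_original_sizes) - 1          # endpoints pinned
    print([new_original_sizes[i] for i in indices])                   # [1, 7, 12, 24]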
10 changes: 0 additions & 10 deletions vllm_ascend/worker/model_runner_v1.py
@@ -4010,16 +4010,6 @@ def _capture_model(self):
                and x >= self.uniform_decode_query_len
            ]
            compilation_cases_decode = sorted(decode_cudagraph_batch_sizes)
-           # TODO: refactor this when vLLM supports mtp>1
-           if not all(x % self.uniform_decode_query_len == 0
-                      for x in decode_cudagraph_batch_sizes):
-               raise ValueError(
-                   "In the MTP fullgraph scenario, each graph size must be an integer multiple of "
-                   f"(num_speculative_tokens + 1): {self.uniform_decode_query_len}. "
-                   f"Please modify the cudagraph_capture_sizes variable to be integer multiple of {self.uniform_decode_query_len}, "
-                   f"while ensuring the maximum cudagraph_capture_sizes does not exceed max_num_seqs * (num_speculative_tokens + 1): {max_num_tokens}. "
-                   "For example, with MTP=2 and max_num_seqs=16, we recommend setting cudagraph_capture_sizes to [48].")
            self._capture_aclgraphs(
                compilation_cases=compilation_cases_decode,
                aclgraph_runtime_mode=CUDAGraphMode.FULL,
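The removed ValueError is safe to drop only if the sizes generated in utils.py already satisfy the multiple-of-(num_speculative_tokens + 1) invariant; a quick cross-check under that assumption, with hypothetical values:

    num_speculative_tokens = 2
    max_num_seqs = 16
    uniform_decode_query_len = num_speculative_tokens + 1  # what the removed check used

    base = sorted(
        set(
            list(range(1, min(10, max_num_seqs + 1), 2)) +
            list(range(8, max_num_seqs + 1, 4))))
    enlarged = [uniform_decode_query_len * s for s in base]
    # Holds by construction, so the ValueError path in _capture_model is redundant.
    assert all(x % uniform_decode_query_len == 0 for x in enlarged)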