[main] Support MTP shape with ACLgraph #4523
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for MTP (Multi-Token Prediction) with ACLgraph by adjusting the graph capture sizes to be multiples of (num_speculative_tokens + 1). This is a necessary change for the FIA operator in FULL_DECODE_ONLY graph mode. The changes also remove a runtime check that is now obsolete.
The overall approach is correct, but I've found some issues in the implementation of the size adjustment logic in vllm_ascend/utils.py.
- In `update_default_aclgraph_sizes`, the logic for `num_speculative_tokens == 1` is inconsistent with the stated requirements, and a condition check for `num_speculative_tokens > 1` uses an incorrect index. I've suggested a unified and corrected implementation for this block.
- In `update_aclgraph_sizes`, there's a similar incorrect condition check.
I've provided detailed comments and suggestions to address these points. Addressing these will make the implementation more robust and correct.
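For context on the constraint the review refers to: in FULL_DECODE_ONLY mode with MTP, every decode step runs (num_speculative_tokens + 1) tokens per sequence through the captured graph, which is why the capture sizes need to be multiples of that factor. A minimal sketch with made-up numbers (not values taken from this PR):

    # Illustrative only: why capture sizes are tied to (num_speculative_tokens + 1).
    num_speculative_tokens = 1   # hypothetical value for the example
    batch_size = 24              # hypothetical number of decoding sequences

    # Each decode step feeds the target token plus the draft tokens per sequence
    # through the captured graph, so the padded size must be a multiple of
    # (num_speculative_tokens + 1).
    tokens_per_step = (num_speculative_tokens + 1) * batch_size
    print(tokens_per_step)  # 48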
if num_speculative_tokens > 1:
    if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
        new_original_sizes = sorted(set(list(range(1, min(10, max_num_seqs + 1), 2)) + list(range(8, max_num_seqs + 1, 4))))
        enlarged_sizes = [(num_speculative_tokens + 1) * sizes for sizes in new_original_sizes]
        if enlarged_sizes[-1] < target_sizes:
            enlarged_sizes.append(target_sizes)
        update_cudagraph_capture_sizes(vllm_config, enlarged_sizes)
        logger.info(
            "Adjusted ACL full graphs: %s → %s for speculative decoding",
            original_sizes, enlarged_sizes)
    else:
        vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
if num_speculative_tokens == 1:
    padding_sizes = original_sizes.copy()
    if padding_sizes[-1] < target_sizes:
        padding_sizes.append(target_sizes)
        update_cudagraph_capture_sizes(vllm_config, padding_sizes)
        logger.info(
            "Adjusted ACL full graphs: %s → %s for speculative decoding",
            original_sizes, padding_sizes)
    else:
        vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
There are a couple of issues in this logic block:
- At line 478, `original_sizes[0]` is used in the condition. This seems incorrect, as it checks the smallest capture size. It should probably be `original_sizes[-1]` to check whether the largest capture size is sufficient, similar to the logic at line 491.
- The logic for `num_speculative_tokens == 1` (lines 489-498) does not ensure that capture sizes are multiples of `num_speculative_tokens + 1` (which is 2). This contradicts the comment on lines 456-461, which states this is a requirement for the FIA operator. The logic for `num_speculative_tokens > 1` correctly enforces this multiplication.
To address these issues and improve clarity, the logic for all num_speculative_tokens > 0 can be unified. Here is a suggested implementation:
if num_speculative_tokens > 0:
    # The check should be against the largest original size. Also, the logic is unified
    # for all num_speculative_tokens > 0 to enforce multiples of (num_speculative_tokens + 1).
    if not original_sizes or original_sizes[-1] < target_sizes:
        new_original_sizes = sorted(set(list(range(1, min(10, max_num_seqs + 1), 2)) + list(range(8, max_num_seqs + 1, 4))))
        enlarged_sizes = [(num_speculative_tokens + 1) * size for size in new_original_sizes]
        if not enlarged_sizes or enlarged_sizes[-1] < target_sizes:
            enlarged_sizes.append(target_sizes)
        final_sizes = sorted(list(set(enlarged_sizes)))
        update_cudagraph_capture_sizes(vllm_config, final_sizes)
        logger.info(
            "Adjusted ACL full graphs: %s → %s for speculative decoding",
            original_sizes, final_sizes)
    else:
        vllm_config.compilation_config.cudagraph_capture_sizes = original_sizes
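For a sense of what the suggestion produces, the following standalone sketch recomputes the capture sizes with hypothetical values max_num_seqs = 16 and num_speculative_tokens = 1 (not values from this PR):

    # Standalone sketch of the size computation in the suggestion above,
    # using hypothetical inputs.
    max_num_seqs = 16
    num_speculative_tokens = 1
    target_sizes = (num_speculative_tokens + 1) * max_num_seqs  # 32

    new_original_sizes = sorted(
        set(
            list(range(1, min(10, max_num_seqs + 1), 2)) +
            list(range(8, max_num_seqs + 1, 4))))   # [1, 3, 5, 7, 8, 9, 12, 16]
    enlarged_sizes = [(num_speculative_tokens + 1) * s for s in new_original_sizes]
    if not enlarged_sizes or enlarged_sizes[-1] < target_sizes:
        enlarged_sizes.append(target_sizes)
    print(sorted(set(enlarged_sizes)))  # [2, 6, 10, 14, 16, 18, 24, 32]

Every resulting size is a multiple of (num_speculative_tokens + 1), and the largest one covers the full max_num_seqs batch.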
07a3814
Compare
Signed-off-by: lilinsiman <[email protected]>
07a3814 to
25fef14
Compare
|
Please check vllm-project/vllm#28315

target_sizes = (num_speculative_tokens + 1) * max_num_seqs
original_sizes, vllm_config.compilation_config.cudagraph_capture_sizes = \
    vllm_config.compilation_config.cudagraph_capture_sizes, None
assert len(original_sizes) > 0
No clear error message is provided when this assertion fails.
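A possible shape for such a message, purely as a sketch (the wording and the hypothetical original_sizes value are illustrative):

    # Illustrative only: attaching a message makes the failure self-explanatory.
    original_sizes = [1, 2, 4]  # hypothetical; an empty list would trigger the error
    assert len(original_sizes) > 0, (
        "cudagraph_capture_sizes must not be empty when adjusting ACL graph "
        "sizes for speculative decoding")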
if original_sizes[0] < (num_speculative_tokens + 1) * max_num_seqs:
    new_original_sizes = sorted(
        set(
            list(range(1, min(10, max_num_seqs + 1), 2)) +
The hardcoded range parameters (1, 10, 2 and 8, 4) lack an explanation.
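One way to document them would be named constants, sketched below; the constant names and the max_num_seqs value are hypothetical, only the numeric values mirror the reviewed code:

    # Illustrative sketch only: naming the magic numbers documents the intent.
    max_num_seqs = 32       # hypothetical
    SMALL_SIZE_STOP = 10    # fine-grained small sizes: 1, 3, 5, 7, 9
    SMALL_SIZE_STEP = 2
    LARGE_SIZE_START = 8    # coarser large sizes: 8, 12, 16, ...
    LARGE_SIZE_STEP = 4

    new_original_sizes = sorted(
        set(
            list(range(1, min(SMALL_SIZE_STOP, max_num_seqs + 1), SMALL_SIZE_STEP)) +
            list(range(LARGE_SIZE_START, max_num_seqs + 1, LARGE_SIZE_STEP))))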
new_original_sizes = sorted(
    set(
        list(range(1, min(10, max_num_seqs + 1), 2)) +
        list(range(8, max_num_seqs + 1, 4))))
The case where max_num_seqs < 8 was not considered.
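To make the concern concrete, this is what the two ranges evaluate to for a hypothetical max_num_seqs below 8:

    # Illustrative only: with max_num_seqs < 8 the second range is empty,
    # so only the small odd sizes from the first range remain.
    max_num_seqs = 4
    small = list(range(1, min(10, max_num_seqs + 1), 2))  # [1, 3]
    large = list(range(8, max_num_seqs + 1, 4))           # []
    print(sorted(set(small + large)))                     # [1, 3]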
update_cudagraph_capture_sizes(vllm_config, padding_sizes)
logger.info(
    "Adjusted ACL full graphs: %s → %s for speculative decoding",
    original_sizes, padding_sizes)
The before/after comparison in this log message is not intuitive; the message should describe what the sizes represent before and after the adjustment.
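A sketch of a more descriptive message, with hypothetical wording and placeholder values rather than the PR's actual variables:

    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical values standing in for the variables in the reviewed code.
    original_sizes = [1, 2, 4, 8]
    padding_sizes = [2, 4, 8, 16]

    # Illustrative wording only; not the exact message used in the PR.
    logger.info(
        "Adjusted ACLgraph capture sizes for speculative decoding: "
        "default sizes %s -> sizes padded to multiples of "
        "(num_speculative_tokens + 1): %s",
        original_sizes, padding_sizes)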
What this PR does / why we need it?
Adds support for running MTP together with all default ACLgraph capture shapes.
Does this PR introduce any user-facing change?
No.
How was this patch tested?
Unit tests.