[0.11.0][Bug Fix] Fixes ngram spec decode bug introduced by vllm #3817
Conversation
Signed-off-by: Icey <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request aims to fix a bug in the n-gram speculative decoding by refactoring to a batched approach, aligning with an upstream vLLM change. However, the implementation introduces a critical bug where it uses stale sequence lengths when proposing draft tokens, which will likely prevent n-gram matching from working correctly. I've provided a suggestion to fix this by tracking and using the updated sequence lengths, which aligns with the correct implementation in upstream vLLM.
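To illustrate the failure mode, here is a minimal, self-contained sketch (the `find_ngram_match` helper and the toy token IDs are hypothetical, not the project's proposer code): an n-gram lookup that only sees the sequence up to the stale length matches against a tail that is one token behind the true end of the sequence, so its draft no longer lines up with the position actually being decoded.

import numpy as np

def find_ngram_match(tokens: np.ndarray, n: int, k: int) -> list[int]:
    """Toy n-gram lookup: find an earlier occurrence of the last `n` tokens
    and propose the `k` tokens that followed that occurrence."""
    if len(tokens) < n + 1:
        return []
    tail = tokens[-n:]
    for start in range(len(tokens) - n - 1, -1, -1):
        if np.array_equal(tokens[start:start + n], tail):
            return tokens[start + n:start + n + k].tolist()
    return []

# Context "... 7 8 9 ... 7 8", and the model has just sampled 9 again.
token_ids_cpu = np.array([1, 2, 7, 8, 9, 3, 4, 7, 8, 9])
stale_len = 9     # length before the newly sampled token was appended
updated_len = 10  # length after appending it

print(find_ngram_match(token_ids_cpu[:stale_len], n=2, k=2))
# -> [9, 3]: the draft begins with the token that was already sampled,
#    because the matcher never saw it.
print(find_ngram_match(token_ids_cpu[:updated_len], n=2, k=2))
# -> [3, 4]: with the updated length, the draft continues past the newest token.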
+ valid_ngram_requests = []
  for i, sampled_ids in enumerate(valid_sampled_token_ids):
      num_sampled_ids = len(sampled_ids)
      if not num_sampled_ids:
          # Skip speculative decoding.
          draft_token_ids.append([])
          continue

      # Skip requests that require top-p, top-k, etc.
      req_id = self.runner.input_batch.req_ids[i]
      if req_id in self.runner.input_batch.spec_decode_unsupported_reqs:
          draft_token_ids.append([])
          continue

      # Add sampled_token_ids to token_ids_cpu.
      num_tokens = self.runner.input_batch.num_tokens_no_spec[i]
      if num_tokens >= self.runner.input_batch.max_model_len:
          # Skip requests that have already reached the max model length.
          continue

      start_idx = self.runner.input_batch.num_tokens_no_spec[i]
      end_idx = start_idx + num_sampled_ids
      self.runner.input_batch.token_ids_cpu[
          i, start_idx:end_idx] = sampled_ids
-     drafter_output = self.propose(
-         self.runner.input_batch.token_ids_cpu[i, :end_idx])
-     if drafter_output is None or len(drafter_output) == 0:
-         draft_token_ids.append([])
-     else:
-         draft_token_ids.append(drafter_output.tolist())
- return draft_token_ids
+     valid_ngram_requests.append(i)

+ draft_token_ids = self.batch_propose(
+     len(valid_sampled_token_ids),
+     valid_ngram_requests,
+     self.runner.input_batch.num_tokens_no_spec,
+     self.runner.input_batch.token_ids_cpu,
+ )

+ return draft_token_ids
The batch_propose method is called with self.runner.input_batch.num_tokens_no_spec, which contains the sequence lengths before the newly sampled tokens are appended. This causes the propose method to look for n-grams in an outdated sequence, missing the most recent token. This will likely cause the n-gram matching to fail and prevent any speculative tokens from being proposed.
The fix is to calculate the new sequence lengths after appending the sampled tokens and pass these new lengths to batch_propose. This aligns with the corrected implementation in upstream vLLM.
valid_ngram_requests = []
new_num_tokens = self.runner.input_batch.num_tokens_no_spec.copy()
for i, sampled_ids in enumerate(valid_sampled_token_ids):
    num_sampled_ids = len(sampled_ids)
    if not num_sampled_ids:
        continue
    req_id = self.runner.input_batch.req_ids[i]
    if req_id in self.runner.input_batch.spec_decode_unsupported_reqs:
        continue
    num_tokens = self.runner.input_batch.num_tokens_no_spec[i]
    if num_tokens >= self.runner.input_batch.max_model_len:
        # Skip requests that have already reached the max model length.
        continue
    start_idx = self.runner.input_batch.num_tokens_no_spec[i]
    end_idx = start_idx + num_sampled_ids
    self.runner.input_batch.token_ids_cpu[
        i, start_idx:end_idx] = sampled_ids
    new_num_tokens[i] = end_idx
    valid_ngram_requests.append(i)
draft_token_ids = self.batch_propose(
    len(valid_sampled_token_ids),
    valid_ngram_requests,
    new_num_tokens,
    self.runner.input_batch.token_ids_cpu,
)
return draft_token_ids
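For context, here is a rough sketch of what a batched proposer along these lines might do with the lengths it receives; this is an assumption for illustration, not vllm-ascend's actual batch_propose implementation. Each valid request's token buffer is sliced up to the supplied length before the n-gram search, which is why passing the pre-append lengths silently drops the newest token.

import numpy as np

def batch_propose_sketch(num_requests, valid_ngram_requests,
                         num_tokens, token_ids_cpu, propose_fn):
    """Hypothetical batched wrapper: requests not in `valid_ngram_requests`
    get an empty draft; valid requests are proposed from the prefix
    token_ids_cpu[i, :num_tokens[i]]."""
    draft_token_ids = [[] for _ in range(num_requests)]
    for i in valid_ngram_requests:
        # If num_tokens[i] was captured before the sampled tokens were
        # appended, the newest token(s) never reach the n-gram matcher.
        prefix = token_ids_cpu[i, :num_tokens[i]]
        drafted = propose_fn(prefix)
        draft_token_ids[i] = [int(t) for t in drafted] if drafted is not None else []
    return draft_token_ids

# Toy usage: one valid request whose buffer holds 5 real tokens.
tokens = np.zeros((1, 8), dtype=np.int64)
tokens[0, :5] = [1, 2, 3, 1, 2]
print(batch_propose_sketch(1, [0], [5], tokens, lambda p: p[-1:]))  # [[2]]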
Is this a cherry-pick? If yes, please mention it in the commit message or title.
This PR fixes the issue introduced upstream, but ngram still has unresolved accuracy issues, so I'm unsure whether this PR can be merged.
What this PR does / why we need it?
[Bug Fix] Fixes the ngram spec decode bug introduced by upstream vLLM in vllm-project/vllm#24986
Does this PR introduce any user-facing change?
N/A
How was this patch tested?