10 changes: 3 additions & 7 deletions vllm/v1/spec_decode/eagle.py
@@ -522,13 +522,9 @@ def prepare_next_token_ids_padded(
         )

         # Generate a mask for all valid tokens within those requests
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
-        else:
-            valid_mask = (valid_sampled_token_ids_gpu != -1) & (
-                valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
-            )
+        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
+            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
+        )
Comment on lines 553 to 555

Contributor (severity: high):
This simplification is a good improvement. By removing the special case for `max_gen_len == 1`, the code is more robust: the previous logic didn't account for discarded requests when `max_gen_len == 1`, which could lead to using the invalid placeholder token ID `-1`. The unified approach handles all cases correctly.
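
To make the failure mode concrete, here is a minimal standalone sketch (not vLLM code; the tensor contents and `vocab_size` are illustrative) showing how the old `torch.ones_like` shortcut treats a discarded request's `-1` placeholder as valid, while the unified mask filters it out:

```python
import torch

vocab_size = 32000
# max_gen_len == 1: one sampled token per request. Suppose the second
# request was discarded, so its slot holds the placeholder token ID -1.
valid_sampled_token_ids = torch.tensor([[123], [-1], [456]])

# Old special case: every token assumed valid, including the -1 placeholder.
old_mask = torch.ones_like(valid_sampled_token_ids, dtype=torch.bool)

# Unified mask: -1 placeholders (and out-of-vocab IDs) are excluded.
new_mask = (valid_sampled_token_ids != -1) & (
    valid_sampled_token_ids < vocab_size
)

print(old_mask.squeeze(-1).tolist())  # [True, True, True]  -> -1 counted as valid
print(new_mask.squeeze(-1).tolist())  # [True, False, True] -> discarded request masked
```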


         # Count the number of valid tokens in each request
         valid_sampled_tokens_count = valid_mask.sum(dim=1)