[Spec Decode] (1/2) Remove batch expansion #8839
Changes from 12 commits
New file added in this PR (`@@ -0,0 +1,65 @@`): a test checking that the batch expansion scorer and the new MQA scorer produce identical scores.

```python
import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.MQA_scorer import MQAScorer
from vllm.worker.worker import Worker

from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
                    device: str) -> SpeculativeProposals:
    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
                                device=device)
    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
    return SpeculativeProposals(proposal_token_ids, proposal_probs,
                                proposal_lens)


def assert_score_equal(score1: SpeculativeScores,
                       score2: SpeculativeScores) -> None:
    assert torch.allclose(score1.probs, score2.probs)
    assert torch.allclose(score1.logprobs, score2.logprobs)
    assert torch.equal(score1.token_ids, score2.token_ids)


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
                device: str) -> None:
    """
    Verify that the batch expansion scorer and the MQA scorer return the
    same scores.
    """
    seed = 0
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
    scorer_worker.model_runner.model.sampler.\
        should_modify_greedy_probs_inplace = True

    vocab_size = scorer_worker.vocab_size
    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
    seq_group_metadatalist, _, _ = create_batch(batch_size,
                                                propose_len,
                                                block_size=block_size,
                                                num_gpu_blocks=num_gpu_blocks)
    requests = ExecuteModelRequest(seq_group_metadatalist,
                                   num_lookahead_slots=propose_len)

    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                      vocab_size)
    batch_expansion_score = batch_expansion_scorer.score_proposals(
        requests, proposals)

    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
    mqa_score = mqa_scorer.score_proposals(requests, proposals)

    assert_score_equal(batch_expansion_score, mqa_score)
```
Changes to `BlocksparseFlashAttentionMetadata`:

```diff
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
     use_cuda_graph: bool

+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int] = None
+
     _cached_prefill_metadata: Optional[
         "BlocksparseFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional[
```
Changes to the FlashAttention backend (`FlashAttentionMetadata`, its metadata builder, and `forward`):

```diff
@@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
     # |-------------------- seq_len ---------------------|
     #                                   |-- query_len ---|

-    # Maximum query length in the batch. None for decoding.
+    # Maximum query length in the batch.
     max_query_len: Optional[int]

+    # Number of query tokens for each request in the batch.
+    # Currently, we require that all requests have the same number of query
+    # tokens during the decoding phase. When speculative decoding is enabled,
+    # decode_query_len might be greater than 1. In all other cases, it is 1.
+    decode_query_len: Optional[int]
+
     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
     max_prefill_seq_len: int

@@ -303,6 +310,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,

@@ -331,7 +339,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            max_query_len=None,
+            decode_query_len=self.decode_query_len,
+            max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
             query_start_loc=None,

@@ -441,9 +450,6 @@ def _add_seq_group(
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
             else:
-                assert query_len == 1, (
-                    "seq_len: {}, context_len: {}, query_len: {}".format(
-                        seq_len, context_len, query_len))
                 self.num_decode_tokens += query_len
                 self.curr_seq_lens.append(curr_seq_len)

@@ -498,6 +504,11 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         use_captured_graph = cuda_graph_pad_size != -1

         max_query_len = max(query_lens)
+        decode_query_lens = query_lens[self.num_prefills:]
+        if len(decode_query_lens) > 0:
+            decode_query_len = max(decode_query_lens)
```
**Collaborator:** For my knowledge, even if this is > 1, does `decode_query_len` need to be the same for all the sequences?

**Author:** Yeah, it requires that all requests in the batch have the same number of decode tokens.
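To illustrate the answer above, here is a minimal standalone sketch (plain PyTorch, not the vLLM code path; names and sizes are made up) of why a uniform per-request query length is required: the decode branch reshapes the flattened query tensor by `decode_query_len`, which only works when every decode request contributed the same number of query tokens.

```python
import torch

# Toy sizes for illustration only.
num_heads, head_dim = 4, 8
batch_size, decode_query_len = 3, 2   # every request has 2 decode query tokens

# The model runner hands the attention backend one row per query token.
flat_query = torch.randn(batch_size * decode_query_len, num_heads, head_dim)

# Uniform lengths: the reshape recovers a [batch, decode_query_len, heads, dim]
# view, which is what the multi-query decode kernel consumes.
batched_query = flat_query.reshape(-1, decode_query_len, num_heads, head_dim)
assert batched_query.shape == (batch_size, decode_query_len, num_heads, head_dim)

# Ragged lengths (e.g. 2 + 1 + 2 = 5 tokens) cannot be expressed this way.
ragged_query = torch.randn(5, num_heads, head_dim)
try:
    ragged_query.reshape(-1, decode_query_len, num_heads, head_dim)
except RuntimeError as err:
    print("ragged decode batch cannot be reshaped:", err)
```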
The `build()` hunk continues:

```diff
+        else:
+            decode_query_len = 0
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens

@@ -566,6 +577,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
+            decode_query_len=decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,

@@ -762,8 +774,12 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
+            _, num_head, head_dim = decode_query.shape
+            decode_query = decode_query.reshape(-1,
+                                                decode_meta.decode_query_len,
+                                                num_head, head_dim)
             decode_output = torch.ops.vllm.flash_attn_with_kvcache(
-                decode_query.unsqueeze(1),
+                decode_query,
                 key_cache,
                 value_cache,
                 block_table=decode_meta.block_tables,

@@ -772,13 +788,17 @@
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
                 softcap=self.logits_soft_cap,
-            ).squeeze(1)
+            )

         if prefill_output is None:
             assert decode_output is not None
             return decode_output.view(num_decode_tokens, hidden_size)
         if decode_output is None:
             assert prefill_output is not None
             return prefill_output.view(num_prefill_tokens, hidden_size)

+        assert decode_meta is not None
+        assert decode_meta.decode_query_len == 1
```
**Collaborator:** We have this assert here because this case (chunked prefill) won't happen for speculative decoding; is my understanding correct? Please consider adding a comment for this assert explaining why it works for speculative decoding, where `decode_query_len` might be > 1.

**Author:** Yes, exactly. Currently chunked prefill does not work with speculative decoding, so in the chunked-prefill case we require `decode_query_len` to be 1. Will add a comment here.
The remainder of the `forward()` hunk:

```diff
+        decode_output = decode_output.squeeze(1)
         output = torch.cat([prefill_output, decode_output], dim=0)
         return output.view(num_tokens, hidden_size)
```
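Taken together, these `forward()` changes are what let the scorer skip batch expansion. Below is a standalone sketch (plain PyTorch, not vLLM internals; the sizes are assumptions) of the shapes involved when scoring k proposal tokens plus one bonus position per sequence: batch expansion turns every scored position into its own single-token query, while the MQA path keeps one request per sequence with k + 1 query tokens, matching the reshape above.

```python
import torch

batch_size, k = 2, 3              # 2 requests, 3 proposal tokens each
num_heads, head_dim = 4, 8
positions_per_seq = k + 1         # k proposals + 1 bonus position to score

# Batch expansion (old path): the decode kernel sees a larger batch of
# single-token queries, one per scored position.
expanded_query = torch.randn(batch_size * positions_per_seq, 1,
                             num_heads, head_dim)
print(expanded_query.shape)       # torch.Size([8, 1, 4, 8])

# MQA scoring (new path): the batch size is unchanged and each request
# carries all of its query tokens in one decode pass.
flat_query = torch.randn(batch_size * positions_per_seq, num_heads, head_dim)
mqa_query = flat_query.reshape(-1, positions_per_seq, num_heads, head_dim)
print(mqa_query.shape)            # torch.Size([2, 4, 4, 8])
```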
On a separate review thread about the MLPSpeculator test:

**Comment:** Do we need to specify `num_speculative_tokens` here?
**Reply:** The test verifies the correctness of MLPSpeculator, which uses additional heads to make proposals. We don't need to specify `num_speculative_tokens` here because it will read the number of tokens from the model config.
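For context, here is a hedged sketch of how such an MLPSpeculator setup is typically wired up; the model names and engine arguments below are assumptions rather than anything taken from this PR, and `num_speculative_tokens` is deliberately omitted because, as noted above, it is read from the speculator's model config.

```python
# Hedged sketch, not code from this PR: running vLLM with an MLPSpeculator draft.
# num_speculative_tokens is intentionally omitted; the number of speculative
# tokens comes from the speculator checkpoint's own config.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",             # assumed target model
    speculative_model="ibm-fms/llama-13b-accelerator",  # assumed MLPSpeculator checkpoint
    use_v2_block_manager=True,  # assumed to be required for spec decode on vLLM of that era
)

prompts = ["Speculative decoding lets the target model verify draft tokens"]
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```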