diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b12bf7b382d0..f678436dd05e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -208,7 +208,7 @@ steps: - tests/spec_decode commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 3342a336a4ef..9d4932dd1f5b 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -434,7 +434,7 @@ def run_test_case(*, expected_penalization: List[bool], sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, - query_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else [1] * batch_size, device=device, pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index 4a427d4c3e28..d04e312689bc 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, max_output_len=32, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 8c90e147df23..0b36e712a11b 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) + + if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 7f3180befaff..52b48a33c309 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": SPEC_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 850114eb7f5a..586245938316 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_scorer(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 6fa386ffab12..e6f7f480eebb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -173,7 +173,6 @@ def test_same_output_for_multi_step(): block_size, num_gpu_blocks, seed, - model_runner_cls=TP1DraftModelRunner, ) worker = create_worker( diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py new file mode 100644 index 000000000000..5f703b03ab7f --- /dev/null +++ b/tests/spec_decode/test_scorer.py @@ -0,0 +1,65 @@ +import pytest +import torch + +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores +from vllm.spec_decode.mqa_scorer import MQAScorer +from vllm.worker.worker import Worker + +from .utils import create_batch, create_worker + + +def create_proposal(batch_size: int, propose_len: int, vocab_size: int, + device: str) -> SpeculativeProposals: + proposal_probs = torch.rand((batch_size, propose_len, vocab_size), + device=device) + proposal_token_ids = torch.argmax(proposal_probs, dim=-1) + proposal_lens = torch.tensor([propose_len] * batch_size, device=device) + return SpeculativeProposals(proposal_token_ids, proposal_probs, + proposal_lens) + + +def assert_score_equal(score1: SpeculativeScores, + score2: SpeculativeScores) -> None: + assert torch.allclose(score1.probs, score2.probs) + assert torch.allclose(score1.logprobs, score2.logprobs) + assert torch.equal(score1.token_ids, score2.token_ids) + + +@pytest.mark.parametrize('model_name', ['facebook/opt-125m']) +@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16]) +@pytest.mark.parametrize('propose_len', [1, 3, 5]) +@pytest.mark.parametrize('device', ['cuda']) +def test_scoroer(model_name: str, batch_size: int, propose_len: int, + device: str) -> None: + """ + Compare the batch expansion scorer and mqa scorer return the same score + """ + seed = 0 + block_size = 32 + num_gpu_blocks = 2048 // block_size + scorer_worker = create_worker(Worker, model_name, block_size, + num_gpu_blocks, seed) + scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True + scorer_worker.model_runner.model.sampler.\ + should_modify_greedy_probs_inplace = True + + vocab_size = scorer_worker.vocab_size + proposals = create_proposal(batch_size, propose_len, vocab_size, device) + 
seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 propose_len,
+                                                 block_size=block_size,
+                                                 num_gpu_blocks=num_gpu_blocks)
+    requests = ExecuteModelRequest(seq_group_metadata_list,
+                                   num_lookahead_slots=propose_len)
+
+    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
+                                                      vocab_size)
+    batch_expansion_score = batch_expansion_scorer.score_proposals(
+        requests, proposals)
+
+    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
+    mqa_score = mqa_scorer.score_proposals(requests, proposals)
+
+    assert_score_equal(batch_expansion_score, mqa_score)
diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py
index 501d05756e01..e0b7b7d47f1f 100644
--- a/tests/spec_decode/test_spec_decode_worker.py
+++ b/tests/spec_decode/test_spec_decode_worker.py
@@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
 @pytest.mark.parametrize("acceptance_sampler_method",
                          ["rejection_sampler", "typical_acceptance_sampler"])
 @torch.inference_mode()
-def test_correctly_calls_target_model(k: int, batch_size: int,
-                                      acceptance_sampler_method: str):
+def test_batch_expansion_correctly_calls_target_model(
+        k: int, batch_size: int, acceptance_sampler_method: str):
     """Verify SpecDecodeWorker calls the target model with correct
-    inputs. Everything else is mocked out.
+    inputs with batch expansion. Everything else is mocked out.
     """
     draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
     target_worker = mock_worker(use_spec=False)
@@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
         target_worker,
         mock_spec_decode_sampler(acceptance_sampler_method),
         disable_logprobs=False,
-        metrics_collector=metrics_collector)
+        metrics_collector=metrics_collector,
+        disable_mqa_scorer=True)
     worker.init_device()
 
     vocab_size = 32_000
diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py
index f17e87288163..f683942a5854 100644
--- a/tests/spec_decode/utils.py
+++ b/tests/spec_decode/utils.py
@@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
         for i, final_len in enumerate(final_prompt_lens)
     }
 
-    return [
-        SequenceGroupMetadata(
-            request_id=str(i),
-            is_prompt=len(cont_token_ids) == 0,
-            seq_data={
-                i: SequenceData.from_seqs(prompt_token_ids[:],
-                                          cont_token_ids[:]),
-            },
-            sampling_params=SamplingParams(temperature=0.0, ),
-            block_tables={i: block_allocations[i][:]},
-        ) for i, (prompt_token_ids,
-                  cont_token_ids) in enumerate(zip(prompts, continuations))
-    ]
+    seq_group_metadata_list = []
+    for i, (prompt_token_ids,
+            cont_token_ids) in enumerate(zip(prompts, continuations)):
+        data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
+        data.update_num_computed_tokens(
+            len(prompt_token_ids) + len(cont_token_ids) - 1)
+        seq_data = {i: data}
+        seq_group_metadata_list.append(
+            SequenceGroupMetadata(
+                request_id=str(i),
+                is_prompt=len(cont_token_ids) == 0,
+                seq_data=seq_data,
+                sampling_params=SamplingParams(temperature=0.0),
+                block_tables={i: block_allocations[i][:]},
+            ))
+    return seq_group_metadata_list
 
 
 def assert_logprobs_dict_allclose(
diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 656cfd124ab4..57ac152d9edb 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + _cached_prefill_metadata: Optional[ "BlocksparseFlashAttentionMetadata"] = None _cached_decode_metadata: Optional[ diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 43ca6c9ff160..e27702336719 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. + # Maximum query length in the batch. max_query_len: Optional[int] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int @@ -303,6 +310,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -331,7 +339,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + decode_query_len=self.decode_query_len, + max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, query_start_loc=None, @@ -461,9 +470,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -518,6 +524,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + decode_query_len = max(decode_query_lens) + else: + decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -586,6 +597,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + decode_query_len=decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, @@ -786,8 +798,12 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
+ _, num_head, head_dim = decode_query.shape + decode_query = decode_query.reshape(-1, + decode_meta.decode_query_len, + num_head, head_dim) decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), + decode_query, key_cache, value_cache, block_table=decode_meta.block_tables, @@ -796,7 +812,7 @@ def forward( causal=True, alibi_slopes=self.alibi_slopes, softcap=self.logits_soft_cap, - ).squeeze(1) + ) if prefill_output is None: assert decode_output is not None @@ -804,5 +820,11 @@ def forward( if decode_output is None: assert prefill_output is not None return prefill_output.view(num_prefill_tokens, hidden_size) + + # Chunked prefill does not work with speculative decoding. + # Therefore, the query length for decode should be 1 in chunked prefill. + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) output = torch.cat([prefill_output, decode_output], dim=0) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a64bf34596f9..96d37b99f201 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -595,7 +595,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -634,7 +633,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int, device=device, ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 5ee3c3b69cf3..fb5cd11ec033 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. 
+ decode_query_len: Optional[int] = None + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 49fbb25f4547..2b8c373178ab 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -312,7 +312,8 @@ def graph_capture_get_metadata_for_batch( slot_mapping=self._graph_slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=None, + max_query_len=1, + decode_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, query_start_loc=None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 143fa6ee7dea..a3f9ff64f8b8 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. diff --git a/vllm/config.py b/vllm/config.py index 3139c5a08bfb..1310c07ade48 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1116,6 +1116,7 @@ def maybe_create_spec_config( speculative_model_quantization: Optional[str], speculative_draft_tensor_parallel_size: Optional[int], num_speculative_tokens: Optional[int], + speculative_disable_mqa_scorer: Optional[bool], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, @@ -1150,6 +1151,9 @@ def maybe_create_spec_config( num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. Will default to the number in the draft model config if present, otherwise is required. + speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA + scorer for the speculative model and fall back to batch + expansion for scoring. speculative_max_model_len (Optional[int]): The maximum model len of the speculative model. Used when testing the ability to skip speculation for some sequences. 
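The option documented above is also surfaced as the --speculative-disable-mqa-scorer CLI flag (added in vllm/engine/arg_utils.py below). A minimal offline usage sketch, assuming the keyword argument is forwarded from LLM to EngineArgs like the other speculative-decoding options; the model names are placeholders mirroring the tests above:

from vllm import LLM, SamplingParams

# Sketch only: run speculative decoding but force the batch-expansion scorer.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    use_v2_block_manager=True,   # required for spec decode
    enforce_eager=True,
    speculative_disable_mqa_scorer=True,  # new option introduced by this change
)

params = SamplingParams(temperature=0.0, max_tokens=32)
print(llm.generate(["The future of AI is"], params)[0].outputs[0].text)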
@@ -1304,6 +1308,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_mqa_scorer, speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, @@ -1400,6 +1405,7 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + speculative_disable_mqa_scorer: Optional[bool], speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1446,6 +1452,7 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer self.speculative_disable_by_batch_size = \ speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 208766a18e99..64fa7360b95b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -162,6 +162,7 @@ class EngineArgs: speculative_model_quantization: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None num_speculative_tokens: Optional[int] = None + speculative_disable_mqa_scorer: Optional[bool] = False speculative_max_model_len: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None @@ -640,6 +641,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-disable-mqa-scorer', + action='store_true', + help= + 'If set to True, the MQA scorer will be disabled in speculative ' + ' and fall back to batch expansion') parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', @@ -970,6 +977,7 @@ def create_engine_config(self) -> EngineConfig: speculative_draft_tensor_parallel_size = \ self.speculative_draft_tensor_parallel_size, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, speculative_disable_by_batch_size=self. speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3550759f85dd..d6258c6413d8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1110,6 +1110,8 @@ def update_prefill_num_computed_tokens( update_prefill_num_computed_tokens(seq_group, seq_group_meta, len(output), is_first_step_output) + elif not is_async: + seq_group.update_num_computed_tokens(1) if outputs: for o in outputs: @@ -1133,8 +1135,16 @@ def update_prefill_num_computed_tokens( else: self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs( + output_token_num = self.output_processor.process_outputs( seq_group, output, is_async) + if self.speculative_config: + # We -1 here because we always + # (w/o speculative decoding) add the number of + # computed tokens by one in the decoding phase. + # Therefore, we remove that one token that + # is already added. 
+ seq_group.update_num_computed_tokens(output_token_num - + 1) if seq_group.is_finished(): finished_now.append(i) @@ -1251,11 +1261,12 @@ def _advance_to_next_step( # decodes after the very first step. Therefore, # we skip the update to the num_computed_tokens # here. - pass + seq_group.update_num_computed_tokens(1) else: seq_group.update_num_computed_tokens( seq_group_metadata.token_chunk_size) - + else: + seq_group.update_num_computed_tokens(1) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( "Async output processor expects a single sample" @@ -1266,7 +1277,6 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) - seq_group.update_num_computed_tokens(1) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 50adaf4e5918..554880a3cc43 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable, List, Optional from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -58,10 +58,14 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool) -> None: + is_async: bool) -> Optional[int]: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. + + Return the number of new tokens generated in the sequence group. + The returned value is optional because it is only used for + speculative decoding mqa scorer. """ pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6dac3619580b..f35b1ba9c2bd 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List +from typing import Callable, List, Optional from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -69,7 +69,7 @@ def _log_prompt_logprob_unsupported_warning_once(): def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool = False) -> None: + is_async: bool = False) -> Optional[int]: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than @@ -84,6 +84,10 @@ def process_outputs(self, tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + + Returns: + The number of tokens appended to the sequence. This is optional + because only speculative decode uses this return value. 
""" # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED @@ -106,6 +110,7 @@ def process_outputs(self, # was already appended, so we only need to do the rest of the # postprocessor: Detokenization + stopping logic self._process_decode_and_stop(seq, sequence_group.sampling_params) + return None else: # Standard multi-step case @@ -121,8 +126,8 @@ def process_outputs(self, ] assert valid_samples - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + return self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: @@ -140,7 +145,7 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> None: + sampling_params: SamplingParams) -> int: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] @@ -148,7 +153,6 @@ def _process_seq_outputs(self, seq: Sequence, remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + len(output_token_ids)) if remaining_tokens < 0: - valid_samples = valid_samples[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens] # Truncate any tokens after EOS. This is required as spec decode @@ -162,7 +166,6 @@ def _process_seq_outputs(self, seq: Sequence, for i in range(len(output_token_ids)): if output_token_ids[i] == eos_token_id: output_token_ids = output_token_ids[:i + 1] - valid_samples = valid_samples[:i + 1] break # Incrementally append tokens to the sequence, as if we had only one new @@ -173,9 +176,9 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) - seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) if seq.is_finished(): break + return len(output_token_ids) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 583bb02dcb5b..cfa857b8f960 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -912,7 +912,7 @@ def get_logprobs( sampling_metadata: SamplingMetadata, sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: - """Return sample lobprobs and prompt logprobs. + """Return sample logprobs and prompt logprobs. The logic consists of 3 parts. 
- Select indices to compute logprob from, ranks of token ids, and diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b1..ee02368bec8a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -146,7 +146,7 @@ def __init__( def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, pin_memory: bool, generators: Optional[Dict[str, torch.Generator]] = None, @@ -194,7 +194,7 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, @@ -284,7 +284,8 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - sample_len = len(seq_ids) if do_sample else 0 + query_len = query_lens[i] if query_lens is not None else 1 + sample_len = len(seq_ids) * query_len if do_sample else 0 if sampling_params.seed is not None and generators is not None: generator = generators.get(seq_group_metadata.request_id) @@ -440,14 +441,14 @@ def from_sampling_metadata( if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) - assert sample_lens == len(seq_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) + assert sample_lens >= len(seq_ids) + temperatures += [temperature] * sample_lens + top_ps += [top_p] * sample_lens + top_ks += [top_k] * sample_lens + min_ps += [min_p] * sample_lens + presence_penalties += [p] * sample_lens + frequency_penalties += [f] * sample_lens + repetition_penalties += [r] * sample_lens if do_penalties: for seq_group in sampling_metadata.seq_groups: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 9eb8bbfc5407..59e71cc8deb4 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -12,7 +12,6 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len -from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. 
""" - def __init__(self, scorer_worker: WorkerBase, device: str, - vocab_size: int): - self._scorer_worker = scorer_worker - self._device = device - self._vocab_size = vocab_size - @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cf64af72a14a..71cba5dd25f6 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -94,8 +94,6 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, assert seq_group.is_prompt is False # No prompt assert seq_group.prompt_logprob_indices == [] # No prompt assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 11ab09f10c1f..029f56460f5c 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -5,6 +5,7 @@ import torch from vllm.sequence import ExecuteModelRequest +from vllm.worker.worker_base import WorkerBase @dataclass @@ -74,6 +75,12 @@ def get_spec_proposals( class SpeculativeScorer(ABC): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size + @abstractmethod def score_proposals( self, diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py new file mode 100644 index 000000000000..59f2a4191a8b --- /dev/null +++ b/vllm/spec_decode/mqa_scorer.py @@ -0,0 +1,80 @@ +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = all_proposal_tokens[i] + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = target_seq_id_start + i + new_seq_data = SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + + # Ensure that the new sequence has at least one token + # because we only use mqa scorer in the decoding stage. 
+ assert len(output_token_ids) >= 1 + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + target_sampler_output = target_sampler_output[0] + + bs, k = proposals.proposal_token_ids.shape + all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1) + + all_probs = target_sampler_output.sampled_token_probs.reshape( + bs, k + 1, self._vocab_size) + all_logprobs = target_sampler_output.logprobs.reshape( + bs, k + 1, self._vocab_size) + + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=hidden_states) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index dbf880a8f475..a67715290a51 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,6 +1,6 @@ from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch @@ -24,6 +24,7 @@ from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker +from vllm.spec_decode.mqa_scorer import MQAScorer from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, draft_token_acceptance_method=speculative_config. 
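To make the scoring difference concrete, the following illustrative sketch (not part of the change itself) compares how many target-model decode queries each scorer issues per step; the MQA scorer keeps one request per sequence and scores all proposal tokens in a single multi-token query, which is what the new decode_query_len metadata field above describes:

# Illustrative layout for batch_size = 2 sequences and k = 3 proposal tokens each.
batch_size, k = 2, 3

# Batch expansion: every proposal prefix becomes its own single-token decode
# request, so the target model sees batch_size * (k + 1) rows of query length 1.
batch_expansion_rows = batch_size * (k + 1)   # 8 rows, decode_query_len == 1

# MQA scorer: the k proposal tokens are appended to each sequence's output, so
# the target model sees one row per sequence covering the last accepted token
# plus the k proposals, i.e. decode_query_len == k + 1.
mqa_rows = batch_size                         # 2 rows
mqa_decode_query_len = k + 1                  # 4

# Both layouts score the same number of token positions per step.
assert batch_expansion_rows == mqa_rows * mqa_decode_query_len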
@@ -116,6 +118,7 @@ def create_worker( cls, scorer_worker: Worker, draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, disable_by_batch_size: Optional[int], draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, @@ -173,12 +176,43 @@ def create_worker( typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) - logger.info("Configuring SpecDecodeWorker with sampler=%s", - type(spec_decode_sampler)) + logger.info( + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "flash-attn": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if ngram_prompt_lookup_max > 0: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "NGramWorker does not support MQA scorer.") + + if "model_config" in draft_worker_kwargs and \ + draft_worker_kwargs["model_config"].max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") return SpecDecodeWorker( proposer_worker, scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, disable_logprobs=disable_logprobs, disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, @@ -190,6 +224,7 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, + disable_mqa_scorer: bool = False, disable_logprobs: bool = False, disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -211,6 +246,8 @@ def __init__( types of sampler namely RejectionSampler and TypicalAcceptanceSampler. 'spec_decode_sampler' is either an instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. @@ -248,6 +285,7 @@ def __init__( self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initialization. self.scorer: SpeculativeScorer + self.disable_mqa_scorer = disable_mqa_scorer # Hidden states from target model to pass to proposer # in the subsequent step. 
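Since the constructor now lives on the SpeculativeScorer base class (see vllm/spec_decode/interfaces.py above), both scorers share the same construction signature and differ only in score_proposals(). A minimal sketch of that contract; DummyScorer is a hypothetical example and not part of this change:

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeScorer, SpeculativeScores)


class DummyScorer(SpeculativeScorer):
    """Hypothetical scorer illustrating the interface only."""

    def score_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        proposals: SpeculativeProposals,
    ) -> SpeculativeScores:
        # A real scorer runs the target model through self._scorer_worker
        # (set by the base-class constructor) and returns probs, logprobs and
        # token_ids shaped (batch_size, k + 1, ...), as MQAScorer and
        # BatchExpansionTop1Scorer do.
        raise NotImplementedError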
@@ -270,10 +308,19 @@ def init_device(self) -> None: self._metrics.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank) - self.scorer = BatchExpansionTop1Scorer( - scorer_worker=self.scorer_worker, - device=self.device, - vocab_size=self._vocab_size) + scorer_cls: Type[SpeculativeScorer] + if self.disable_mqa_scorer: + scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) self._configure_model_sampler_for_spec_decode() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 76c04ce66fc2..bd92abdb945d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -468,43 +468,26 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute context length (the number of tokens that are # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. + seq_len = min(seq_len, context_len + token_chunk_size) + elif self.runner.scheduler_config.is_multi_step or \ + self.runner.model_config.is_encoder_decoder_model: context_len = seq_len - 1 - seq_len = min(seq_len, context_len + token_chunk_size) + else: + context_len = seq_data.get_num_computed_tokens() # Compute tokens. - if inter_data.is_prompt: - tokens = seq_data.get_token_ids() - if context_len != 0 or seq_len < len(tokens): - tokens = tokens[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. - tokens = seq_data.get_last_token_id() + tokens = seq_data.get_token_ids()[context_len:seq_len] inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - - if isinstance(tokens, list): - inter_data.input_tokens[seq_idx].extend(tokens) - else: - inter_data.input_tokens[seq_idx].append(tokens) - - if (seq_len - context_len) == 1: - inter_data.input_positions[seq_idx].append(seq_len - 1) - else: - inter_data.input_positions[seq_idx].extend( - range(context_len, seq_len)) - - inter_data.query_lens[ - seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None:
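The _compute_lens() change above is what allows multi-token decode queries to flow through the model runner: for decode requests, query_len is now derived as seq_len - context_len rather than being hard-coded to 1. A worked example with illustrative numbers, following how MQAScorer sets num_computed_tokens:

# Illustrative numbers only.
prompt_len, output_len, k = 16, 4, 3

# MQAScorer appends the k proposal tokens to the output and marks everything up
# to (but not including) the last generated token as already computed.
seq_len = prompt_len + output_len + k       # 23 tokens in the scored sequence
context_len = prompt_len + output_len - 1   # 19, via update_num_computed_tokens()
query_len = seq_len - context_len           # 4 == k + 1, the decode_query_len

assert query_len == k + 1
# With no speculation (k == 0, num_computed == seq_len - 1) this reduces to the
# previous single-token decode query.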