diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 5af232cb6af6..618d500eb7c4 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -69,6 +71,10 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument( + "--enable-draft-probs", action=argparse.BooleanOptionalAction, default=True + ) + parser.add_argument("--request-id-prefix", type=str, default="") return parser.parse_args() @@ -110,6 +116,7 @@ def main(): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "enable_draft_probs": args.enable_draft_probs, } elif args.method == "ngram": speculative_config = { diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 46e3a611c6d2..35cf0334dba3 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -28,6 +28,7 @@ def _create_proposer( method: str, num_speculative_tokens: int, speculative_token_tree: Optional[list[tuple[int]]] = None, + enable_probs: bool = True, ) -> EagleProposer: model_config = ModelConfig(model=model_dir, runner="generate", @@ -48,6 +49,7 @@ def _create_proposer( method=method, num_speculative_tokens=num_speculative_tokens, speculative_token_tree=spec_token_tree_str, + enable_draft_probs=enable_probs, ) vllm_config = VllmConfig( @@ -228,7 +230,9 @@ class _TargetModelStub(LlamaForCausalLM): @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) -def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): +@pytest.mark.parametrize("enable_probs", [True, False]) +def test_propose_deterministic(method, attn_backend, num_speculative_tokens, + enable_probs, monkeypatch): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) @@ -256,7 +260,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): seq_lens = [seq_len_1, seq_len_2] # Create proposer first so we can use its actual hidden_size - proposer = _create_proposer("eagle", num_speculative_tokens) + proposer = _create_proposer("eagle", + num_speculative_tokens, + enable_probs=enable_probs) # Get the hidden_size from the proposer to ensure consistency hidden_size = proposer.hidden_size @@ -341,6 +347,9 @@ def create_deterministic_logits(token_ids): dtype=torch.int32, device=device) sampling_metadata = mock.MagicMock() + # Simulate mixed greedy and non-greedy requests + sampling_metadata.all_greedy = False + sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device) if attn_backend == "FLASH_ATTN_VLLM_V1": attn_metadata_builder_cls, _ = get_attention_backend( @@ -366,33 +375,247 @@ def create_deterministic_logits(token_ids): proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups[0][0].metadata_builder = attn_metadata_builder - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - common_attn_metadata=common_attn_metadata, - 
sampling_metadata=sampling_metadata) + result, result_probs = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata) + + # Example for num_speculative_tokens=1: + # [[42], [60]] + # Example for num_speculative_tokens=3: + # [[42, 43, 44], [60, 61, 62]] + expected_tokens = torch.zeros((batch_size, num_speculative_tokens), + dtype=torch.int64, + device=device) + expected_probs = torch.zeros( + (batch_size, num_speculative_tokens, vocab_size), device=device) + for i in range(batch_size): + for j in range(num_speculative_tokens): + expected_tokens[i, j] = base_token_ids[i] + j + expected_probs[i, j, base_token_ids[i] + j] = 1.0 + # Verify all tokens match our expectations assert result.shape == (batch_size, num_speculative_tokens) + assert torch.equal(result, expected_tokens) - # Create expected tokens based on our token pattern - if num_speculative_tokens == 1: - # Example for num_speculative_tokens=1: - # [[42], [60]] - expected_tokens = torch.tensor( - [[base_token_ids[0]], [base_token_ids[1]]], device=device) + if enable_probs: + assert result_probs is not None + assert result_probs.shape == (batch_size, num_speculative_tokens, + vocab_size) + torch.testing.assert_close(result_probs, expected_probs) else: - # Example for num_speculative_tokens=3: - # [[42, 43, 44], [60, 61, 62]] - expected_tokens = torch.zeros((batch_size, num_speculative_tokens), - dtype=torch.int64, - device=device) - for i in range(batch_size): - for j in range(num_speculative_tokens): - expected_tokens[i, j] = base_token_ids[i] + j + assert result_probs is None - # Verify all tokens match our expectations - assert torch.equal(result, expected_tokens) + +@pytest.mark.parametrize("method", ["eagle", "eagle3"]) +@pytest.mark.parametrize("attn_backend", + get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) +@pytest.mark.parametrize("enable_probs", [True, False]) +def test_propose_random(method, attn_backend, num_speculative_tokens, + enable_probs, monkeypatch): + + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) + + if (attn_backend == "TRITON_ATTN_VLLM_V1" + and not current_platform.is_rocm()): + pytest.skip("TRITON_ATTN_VLLM_V1 does not support " + "multi-token eagle spec decode on current platform") + + if (attn_backend == "TREE_ATTN"): + pytest.skip("TREE_ATTN is tested separately in test_propose_tree" + "because it requires special input mocking.") + + if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + + # Use GPU device + device = torch.device(current_platform.device_type) + + # Setup test parameters + batch_size = 2 + seq_len_1 = 5 + seq_len_2 = 3 + total_tokens = seq_len_1 + seq_len_2 + vocab_size = 3 + seq_lens = [seq_len_1, seq_len_2] + + # Create proposer first so we can use its actual hidden_size + proposer = _create_proposer("eagle", + num_speculative_tokens, + enable_probs=enable_probs) + # Get the hidden_size from the proposer to ensure consistency + hidden_size = proposer.hidden_size + + # We mock a model that returns constant logits + # Sequence 1: [P(0) = 0.5, P(1) = 0.3, P(2) = 0.2] * num_speculative_tokens + # Sequence 2: [P(0) = 0.2, P(1) = 0.4, P(2) = 0.4] * num_speculative_tokens + token_probs = torch.tensor([ + [0.5, 0.3, 0.2], + [0.2, 0.4, 0.4], + ], + device=device) + 
+ def sample_once(): + # Skip loading the model and replace it with a mock directly + # Create the mock model with deterministic outputs + model_mock = mock.MagicMock() + + # Setup for model forward calls + forward_returns = [] + for i in range(num_speculative_tokens): + if i == 0: + # First call uses all tokens + h_logits = torch.zeros(total_tokens, + hidden_size, + device=device) + h_states = torch.zeros(total_tokens, + hidden_size, + device=device) + else: + # Subsequent calls use batch_size tokens + h_logits = torch.zeros(batch_size, hidden_size, device=device) + h_states = torch.zeros(batch_size, hidden_size, device=device) + forward_returns.append((h_logits, h_states)) + + model_mock.side_effect = forward_returns + + # Setup for compute_logits calls + logits_returns = [] + for i in range(num_speculative_tokens): + # Subtracting a constant doesn't change the logits + logits = torch.log(token_probs) - torch.randn( + (batch_size, 1), device=device) + logits_returns.append(logits) + + model_mock.compute_logits.side_effect = logits_returns + + # Assign the mock to the proposer + proposer.model = model_mock + + # Assign draft attn_layer_names since load_model is not invoked + proposer.attn_layer_names = ["layer.0"] + + # Create input tensors + batch_spec = BatchSpec( + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) + + target_token_ids = torch.randint(0, + vocab_size, (total_tokens, ), + device=device) + target_positions = torch.cat([ + torch.arange(seq_len_1, device=device), + torch.arange(seq_len_2, device=device) + ]) + target_hidden_states = torch.randn(total_tokens, + hidden_size, + device=device) + next_token_ids = torch.randint(0, + vocab_size, (batch_size, ), + dtype=torch.int32, + device=device) + sampling_metadata = mock.MagicMock() + # Simulate mixed greedy and non-greedy requests + sampling_metadata.all_greedy = False + # Greedy sampling for seq 1, standard sampling for seq 2 + sampling_metadata.temperature = torch.tensor([-1, 0.7], device=device) + + if attn_backend == "FLASH_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.FLASH_ATTN_VLLM_V1) + elif attn_backend == "TRITON_ATTN_VLLM_V1": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TRITON_ATTN_VLLM_V1) + elif attn_backend == "TREE_ATTN": + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.TREE_ATTN) + else: + raise ValueError(f"Unsupported attention backend: {attn_backend}") + + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + layer_names=proposer.attn_layer_names, + vllm_config=proposer.vllm_config, + device=device, + ) + + # Mock runner for attention metadata building + proposer.runner = mock.MagicMock() + proposer.runner.attn_groups.append([mock.MagicMock()]) + proposer.runner.attn_groups[0][ + 0].metadata_builder = attn_metadata_builder + + result, result_prob = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata) + + return result, result_prob + + results = [] + result_probs = [] + + # Run N times and check distribution + N = 1000 + for _ in range(N): + result, result_prob = sample_once() + results.append(result) + result_probs.append(result_prob) + + # Count the number of times 
each token appears + counts = torch.zeros((batch_size, num_speculative_tokens, vocab_size), + device=device, + dtype=torch.int64) + for result in results: + assert result.shape == (batch_size, num_speculative_tokens) + counts.scatter_add_(2, result.unsqueeze(-1), + torch.ones_like(result.unsqueeze(-1))) + sample_dist = counts / len(results) + + token_probs_after_temp = torch.tensor( + [ + [1, 0, 0], + [0.1567, 0.4217, 0.4217] if enable_probs else + [0, 1, 0], # argmax tie-breaks on first occurrence + ], + device=device) + + # Verify that the observed distribution is within 4 standard deviations + std = torch.sqrt(token_probs_after_temp * (1 - token_probs_after_temp) / N) + assert torch.all(std <= 0.02), f"Bounds {std=} are too loose, increase N" + lower_bound = token_probs_after_temp - 4 * std + upper_bound = token_probs_after_temp + 4 * std + assert torch.all(sample_dist >= lower_bound.unsqueeze(1)), ( + f"Sampled too many unlikely tokens: {sample_dist} < {lower_bound}") + assert torch.all(sample_dist <= upper_bound.unsqueeze(1)), ( + f"Sampled too few likely tokens: {sample_dist} > {upper_bound}") + + if enable_probs: + for result_prob in result_probs: + assert result_prob is not None + assert result_prob.shape == (batch_size, num_speculative_tokens, + vocab_size) + # only check sequence 2, since sequence 1 is greedy, so the probs + # are allowed to be anything + assert torch.allclose(result_prob[1], + token_probs_after_temp[1].unsqueeze(0), + atol=1e-3) + else: + assert all(result_prob is None for result_prob in result_probs) @pytest.mark.parametrize( @@ -517,13 +740,15 @@ def create_deterministic_logits(token_ids, k: int): sampling_metadata = mock.MagicMock() # Propose draft tokens. - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + result, draft_probs = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) + assert draft_probs is None # The tokens are expected to be consecutive integers starting # from the base token IDs. 
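The constants used in the distribution check of test_propose_random above can be reproduced independently: the hard-coded row [0.1567, 0.4217, 0.4217] is the base draft distribution [0.2, 0.4, 0.4] rescaled by temperature 0.7 via softmax(log(p) / T), and the tolerance comes from the binomial standard deviation sqrt(p * (1 - p) / N) with a 4-sigma band. A minimal standalone sketch of that arithmetic (illustrative only, not part of the patch; assumes nothing beyond PyTorch):

import torch

# Base draft distribution for sequence 2 and the temperature applied to it.
# Sequence 1 uses temperature == -1 (greedy), so its expected row is one-hot.
token_probs = torch.tensor([0.2, 0.4, 0.4])
temperature = 0.7

# Temperature scaling of a categorical distribution: softmax(log(p) / T).
probs_after_temp = torch.softmax(torch.log(token_probs) / temperature, dim=-1)
print(probs_after_temp)  # ~[0.1567, 0.4217, 0.4217]

# Binomial standard deviation of each token's sample frequency over N draws;
# the test accepts any empirical frequency within 4 standard deviations.
N = 1000
std = torch.sqrt(probs_after_temp * (1 - probs_after_temp) / N)
print(probs_after_temp - 4 * std, probs_after_temp + 4 * std)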
diff --git a/tests/v1/spec_decode/test_scheduling.py b/tests/v1/spec_decode/test_scheduling.py new file mode 100644 index 000000000000..8e5f1f8df529 --- /dev/null +++ b/tests/v1/spec_decode/test_scheduling.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile + +import pytest +import torch + +from tests.v1.worker.test_gpu_model_runner import _schedule_new_request +from vllm.config import VllmConfig +from vllm.distributed import (cleanup_dist_env_and_memory, + init_distributed_environment, + initialize_model_parallel) +from vllm.engine.arg_utils import EngineArgs +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.engine.core import get_kv_cache_config +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +model_dir = "meta-llama/Llama-3.1-8B-Instruct" +eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + + +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + # So we can share the DraftModelProposer between tests + return False + + +@pytest.fixture(scope="class") +def monkeyclass(): + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest.fixture(scope="class") +def spec_decode_vllm_config_and_env_setup(monkeyclass: pytest.MonkeyPatch): + with monkeyclass.context() as m: + m.setenv("VLLM_USE_V1", "1") + vllm_config = EngineArgs(model=model_dir, + max_model_len=256, + cuda_graph_sizes=[1, 2, 4], + gpu_memory_utilization=0.8, + speculative_config={ + "model": eagle_dir, + "method": "eagle", + "num_speculative_tokens": 2, + }).create_engine_config() + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield vllm_config + cleanup_dist_env_and_memory() + + +@pytest.fixture(scope="class") +def mock_spec_decode_model_runner( + spec_decode_vllm_config_and_env_setup: VllmConfig): + model_runner = GPUModelRunner(spec_decode_vllm_config_and_env_setup, + torch.device("cuda")) + model_runner.load_model() + kv_cache_spec = model_runner.get_kv_cache_spec() + + kv_cache_config = get_kv_cache_config( + spec_decode_vllm_config_and_env_setup, kv_cache_spec, 1024**3) # 1GB + model_runner.initialize_kv_cache(kv_cache_config) + yield model_runner + + +class TestSpecDecodeScheduling: + + def test_spec_decode_partial_scheduling( + self, mock_spec_decode_model_runner: GPUModelRunner): + """Make sure we don't crash when the scheduler schedules only a subset + of the requests. + + Four iterations: + 1. Schedule both req1 (w/ 0 draft) and req2 (w/ 0 draft) + 2. Schedule only req1 (w/ 1 draft) + 3. Schedule both req1 (w/ 1 draft) and req2 (w/ 2 draft) + 4. 
Terminate req1 and req2 + """ + # Schedule both req1 and req2 on the first iteration + scheduler_output = _schedule_new_request("req1", "req2") + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Only schedule req1 on the second iteration + cached_req_data = CachedRequestData( + req_ids=["req1"], + resumed_from_preemption=[False], + new_token_ids=[[3]], + new_block_ids=[([], )], + num_computed_tokens=[3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 2}, + total_num_scheduled_tokens=2, + scheduled_spec_decode_tokens={"req1": [1001]}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Schedule both req1 and req2 on the third iteration + cached_req_data = CachedRequestData( + req_ids=["req1", "req2"], + resumed_from_preemption=[False, False], + new_token_ids=[[10], [11]], + new_block_ids=[([], ), ([], )], + num_computed_tokens=[4, 3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={ + "req1": 2, + "req2": 3 + }, + total_num_scheduled_tokens=5, + scheduled_spec_decode_tokens={ + "req1": [1001], + "req2": [2001, 2002] + }, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Terminate both req1 and req2 + cached_req_data = CachedRequestData( + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids={"req1", "req2"}, + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + def test_spec_decode_preemption_scheduling( + self, mock_spec_decode_model_runner: GPUModelRunner): + """Make sure we don't crash when the scheduler preempts a request. + + Four iterations: + 1. Schedule req1 (w/ 0 draft) and req2 (w/ 0 draft) + 2. Schedule req1 (w/ 1 draft) and preempt req2 + 3. Schedule req1 (w/ 1 draft) and resume req2 (w/ 2 draft) + 4. 
Terminate req1 and req2 + """ + # Schedule both req1 and req2 on the first iteration + scheduler_output = _schedule_new_request("req1", "req2") + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Only schedule req1 on the second iteration + cached_req_data = CachedRequestData( + req_ids=["req1"], + resumed_from_preemption=[False], + new_token_ids=[[3]], + new_block_ids=[([], )], + num_computed_tokens=[3], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 2}, + total_num_scheduled_tokens=2, + scheduled_spec_decode_tokens={"req1": [1001]}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Schedule both req1 and req2 on the third iteration + cached_req_data = CachedRequestData( + req_ids=["req1", "req2"], + resumed_from_preemption=[False, True], + new_token_ids=[[10], [11]], + new_block_ids=[([], ), ([0], )], + num_computed_tokens=[4, 0], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={ + "req1": 2, + "req2": 6 + }, + total_num_scheduled_tokens=8, + scheduled_spec_decode_tokens={ + "req1": [1001], + "req2": [2001, 2002] + }, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) + + # Terminate both req1 and req2 + cached_req_data = CachedRequestData( + req_ids=[], + resumed_from_preemption=[], + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], + ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + scheduled_spec_decode_tokens={}, + scheduled_encoder_inputs={}, + num_common_prefix_blocks=[0], + finished_req_ids={"req1", "req2"}, + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None, + ) + mock_spec_decode_model_runner.execute_model(scheduler_output) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 6d99029e404e..6bc39606686c 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -139,7 +139,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: total_num_scheduled_tokens=total_num_scheduled_tokens, scheduled_spec_decode_tokens={}, scheduled_encoder_inputs={}, - num_common_prefix_blocks=0, + num_common_prefix_blocks=[0], finished_req_ids=set(), free_encoder_mm_hashes=[], structured_output_request_ids={}, diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index e474420e3f04..8f20a329e619 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1986,6 +1986,12 @@ class SpeculativeConfig: speculative_token_tree: Optional[str] = None """Specifies the tree structure for speculative token generation. """ + enable_draft_probs: bool = True + """Whether to use draft probs for speculative decoding. Using draft probs + always increases the acceptance rate but increases sampling overhead. + For small models and/or low temperatures payloads, it may be beneficial to + disable this. 
Disabling falls back to greedy sampling for the draft tokens. + """ # required configuration params passed from engine target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d8390..83a39d0f7194 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -54,8 +54,10 @@ def __init__( ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config + assert self.speculative_config is not None self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + self.enable_draft_probs = self.speculative_config.enable_draft_probs self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -158,7 +160,14 @@ def propose( common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, mm_embeds: Optional[list[torch.Tensor]] = None, - ) -> torch.Tensor: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Execute the draft model to generate draft tokens. + + Returns: + draft_token_ids: [batch_size, num_spec_tokens] + draft_probs (optional): [batch_size, num_spec_tokens, + vocab_size] + """ num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 @@ -238,14 +247,19 @@ def propose( common_attn_metadata=common_attn_metadata, ) # [batch_size, num_tree_tokens] - return torch.cat(draft_token_ids_list, dim=1) + return torch.cat(draft_token_ids_list, dim=1), None - draft_token_ids = logits.argmax(dim=-1) + draft_token_ids, draft_probs = self._compute_probs_and_sample( + logits, sampling_metadata) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) + return ( + # [batch_size, 1] + draft_token_ids.view(-1, 1), + # [batch_size, 1, vocab_size] + draft_probs.unsqueeze(1) if draft_probs is not None else None, + ) # TODO: Currently, MTP module released by deepseek only has # one layer. Adapt this code to support multiple layers once @@ -253,7 +267,10 @@ def propose( assert isinstance(attn_metadata, self.allowed_attn_types) # Generate the remaining draft tokens. + # Each tensor in the list has shape [batch_size]. draft_token_ids_list = [draft_token_ids] + # Each tensor in the list has shape [batch_size, vocab_size]. 
+ draft_probs_list: list[Optional[torch.Tensor]] = [draft_probs] if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: @@ -331,12 +348,25 @@ def propose( hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], None) - draft_token_ids = logits.argmax(dim=-1) + # TODO(wenlong): get more than one token for tree attention + draft_token_ids, draft_probs = self._compute_probs_and_sample( + logits, sampling_metadata) draft_token_ids_list.append(draft_token_ids) + draft_probs_list.append(draft_probs) - # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + if any(draft_probs is None for draft_probs in draft_probs_list): + draft_probs_final = None + else: + assert all(draft_probs is not None + for draft_probs in draft_probs_list) + draft_probs_final = torch.stack(draft_probs_list, dim=1) + + return ( + # [batch_size, num_spec_tokens] + torch.stack(draft_token_ids_list, dim=1), + # [batch_size, num_spec_tokens, vocab_size] + draft_probs_final, + ) def propose_tree( self, @@ -605,7 +635,7 @@ def prepare_inputs( def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + self.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) @@ -645,7 +675,7 @@ def load_model(self, target_model: nn.Module) -> None: # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM - if self.vllm_config.speculative_config.method != "eagle3" and \ + if self.speculative_config.method != "eagle3" and \ hasattr(target_language_model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head @@ -664,12 +694,19 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - self.model( + ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:num_tokens], hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) + if self.method == "deepseek_mtp": + last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + logits = self.model.compute_logits(last_hidden_states, None) + temperature = torch.ones(num_tokens, device=logits.device) + _mixed_sample(logits, temperature) def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: @@ -690,26 +727,34 @@ def validate_same_kv_cache_group(self, ]) ) == 1, "All eagle layers should belong to the same kv cache group" - -# NOTE(woosuk): Currently, the below code is not used and we always use argmax -# to sample the draft tokens. We will use this after we find a way to manage -# the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. - # Therefore, we can just return the logits. 
- probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + # FIXME(woosuk): The logic here is duplicated with the main sampling code. + # We should refactor this to reuse the same sampling implementation. + def _compute_probs_and_sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + if not self.enable_draft_probs: + return logits.argmax(dim=-1), None + elif sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection + # sampling. Therefore, we can just return the logits. + # We cannot return None because other future steps might + # contain requests that are not greedy. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + return _mixed_sample(logits, sampling_metadata.temperature) + + +@torch.compile(dynamic=True, + backend=current_platform.simple_compile_backend, + mode="max-autotune-no-cudagraphs") +def _mixed_sample( + logits: torch.Tensor, + temperature: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + is_greedy = temperature == -1 + temperature = torch.where(is_greedy, 1.0, temperature) logits.div_(temperature.view(-1, 1)) probs = logits.softmax(dim=-1, dtype=torch.float32) @@ -721,14 +766,8 @@ def compute_probs_and_sample_next_token( # TODO(woosuk): Consider seeds. q = torch.empty_like(probs) q.exponential_() + q[is_greedy, :] = 1.0 # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs # will be used later for rejection sampling. next_token_ids = probs.div(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) return next_token_ids, probs diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index bf9b16575e60..ac75087eabb7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -46,6 +46,7 @@ class CachedRequestState: mrope_position_delta: Optional[int] = None lora_request: Optional[LoRARequest] = None + draft_probs: Optional[torch.Tensor] = None def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 549c5dd2bbb2..0825cb162e02 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1701,9 +1701,11 @@ def _sample( # separate storage from the original `logits` tensor. Therefore, # it is safe to update `target_logits` in place. 
target_logits = logits[spec_decode_metadata.target_logits_indices] + + draft_probs = self._collect_draft_probs(spec_decode_metadata) output_token_ids = self.rejection_sampler( spec_decode_metadata, - None, # draft_probs + draft_probs, target_logits, bonus_token_ids, sampling_metadata, @@ -1973,16 +1975,18 @@ def execute_model( if self.speculative_config: assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("Draft"): - self._draft_token_ids = self.propose_draft_token_ids( - scheduler_output, - valid_sampled_token_ids, - self.input_batch.sampling_metadata, - hidden_states, - sample_hidden_states, - aux_hidden_states, - spec_decode_metadata, - spec_decode_common_attn_metadata, - ) + (self._draft_token_ids, + draft_probs) = self.propose_draft_token_ids( + scheduler_output, + valid_sampled_token_ids, + self.input_batch.sampling_metadata, + hidden_states, + sample_hidden_states, + aux_hidden_states, + spec_decode_metadata, + spec_decode_common_attn_metadata, + ) + self._store_draft_probs(draft_probs) with record_function_or_nullcontext("EPLB"): self.eplb_step() @@ -2008,6 +2012,52 @@ def execute_model( async_output_copy_stream=self.async_output_copy_stream, ) + def _store_draft_probs(self, draft_probs: Optional[torch.Tensor]) -> None: + """Store the draft probs for future use, usually the next step. + + Args: + draft_probs: [num_reqs, num_draft_tokens, vocab_size]. + It is assumed that draft_probs are in the same order as the + requests in the input batch. + """ + if draft_probs is None: + return + for i, spec_prob in enumerate(draft_probs): + req_id = self.input_batch.req_ids[i] + self.requests[req_id].draft_probs = spec_prob + + def _collect_draft_probs( + self, spec_decode_metadata: SpecDecodeMetadata + ) -> Optional[torch.Tensor]: + """Collect the draft probs for the requests in the current input batch. + + Args: + spec_decode_metadata: The metadata for speculative decoding for + the current step. + + Returns: + draft_probs: None if no draft probs are available. + Otherwise, a packed sequence of draft probs with shape + [total_num_draft_tokens, vocab_size]. + """ + draft_probs_list: list[torch.Tensor] = [] + has_draft_probs: list[bool] = [] + for i, req_id in enumerate(self.input_batch.req_ids): + draft_length = spec_decode_metadata.num_draft_tokens[i] + if draft_length > 0: + draft_probs = self.requests[req_id].draft_probs + has_draft_probs.append(draft_probs is not None) + if draft_probs is not None: + # <= since not every draft token is necessarily scheduled + assert draft_length <= draft_probs.shape[0] + draft_probs_list.append(draft_probs[:draft_length]) + assert all(has_draft_probs) or not any(has_draft_probs), ( + "Some requests have draft logits while others do not.") + + draft_probs = (torch.cat(draft_probs_list, dim=0) + if len(draft_probs_list) > 0 else None) + return draft_probs + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: if self._draft_token_ids is None: return None @@ -2029,12 +2079,19 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> Union[list[list[int]], torch.Tensor]: + ) -> tuple[Union[list[list[int]], torch.Tensor], Optional[torch.Tensor]]: + """Generate the draft for the next step. 
+ + Returns: + draft_token_ids: [num_requests, num_draft_tokens] + draft_probs (optional): [num_requests, num_draft_tokens, vocab_size] + """ num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) + draft_probs = None elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) if sample_hidden_states.shape[0] == len(sampled_token_ids): @@ -2055,6 +2112,7 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) + draft_probs = None elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) # TODO(woosuk): Refactor the loop. @@ -2114,7 +2172,7 @@ def propose_draft_token_ids( mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) - draft_token_ids = self.drafter.propose( + draft_token_ids, draft_probs = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -2123,7 +2181,10 @@ def propose_draft_token_ids( common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, ) - return draft_token_ids + else: + raise ValueError(f"Unsupported speculative decoding method: " + f"{self.speculative_config.method}") + return draft_token_ids, draft_probs def propose_ngram_draft_token_ids( self, @@ -2703,10 +2764,10 @@ def _dummy_sampler_run( draft_token_ids, self.device) num_tokens = sum(len(ids) for ids in draft_token_ids) - # draft_probs = torch.randn( - # num_tokens, logits.shape[-1], device=self.device, - # dtype=logits.dtype) - draft_probs = None + draft_probs = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) target_logits = torch.randn(num_tokens, logits.shape[-1], device=self.device, @@ -2714,7 +2775,7 @@ def _dummy_sampler_run( # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. - bonus_token_ids = torch.zeros(num_reqs, + bonus_token_ids = torch.zeros((num_reqs, 1), device=self.device, dtype=torch.int32) self.rejection_sampler(