diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 184c30891eca..f01a3613d85b 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import json
+
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
@@ -68,6 +70,7 @@ def parse_args():
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
     parser.add_argument("--custom-mm-prompts", action="store_true")
+    parser.add_argument("--compilation-config", type=str, default="")
     return parser.parse_args()
 
 
@@ -132,6 +135,9 @@ def main():
             max_model_len=16384,
             limit_mm_per_prompt={"image": 5},
             disable_chunked_mm_input=True,
+            compilation_config=(
+                json.loads(args.compilation_config) if args.compilation_config else None
+            ),
         )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index dde95fbe590b..eb7f0be1f884 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from __future__ import annotations
 
+import math
 import random
 from typing import Any, Union
 
@@ -14,6 +15,7 @@
 from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.platforms import current_platform
+from vllm.v1.metrics.reader import Counter, Metric, Vector
 
 
 def get_test_prompts(mm_enabled: bool):
@@ -67,6 +69,23 @@ def get_test_prompts(mm_enabled: bool):
     return prompts
 
 
+def get_acceptance_rate(metrics: list[Metric]):
+    num_drafts = num_accepted = 0
+    acceptance_counts = [0] * 3
+    for metric in metrics:
+        if metric.name == "vllm:spec_decode_num_drafts":
+            assert isinstance(metric, Counter)
+            num_drafts += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens":
+            assert isinstance(metric, Counter)
+            num_accepted += metric.value
+        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
+            assert isinstance(metric, Vector)
+            for pos in range(len(metric.values)):
+                acceptance_counts[pos] += metric.values[pos]
+    return 1.0 * num_accepted / num_drafts + 1
+
+
 @pytest.fixture
 def sampling_config():
     return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@@ -220,3 +239,95 @@ def test_eagle_correctness(
     del spec_llm
     torch.cuda.empty_cache()
     cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize(
+    "model_setup",
+    [
+        ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+         "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ],
+)
+def test_full_vs_piecewise_cudagraph(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_setup: tuple[str, str, str, int],
+):
+    '''
+    Compare eagle speculative decoding between piecewise and full
+    cudagraph modes: the outputs and the acceptance rates should match.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
+    '''
+    test_prompts = get_test_prompts(mm_enabled=False)
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_GPU_MEMORY_UTILIZATION", "0.8")
+        if current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+        else:
+            m.setenv("VLLM_FLASH_ATTN_VERSION", "3")
+
+        method, model_name, spec_model_name, tp_size = model_setup
+
+        spec_llm_piecewise_cudagraph = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+            compilation_config={"full_cuda_graph": False},
+            disable_log_stats=False,
+        )
+        piecewise_cudagraph_outputs = spec_llm_piecewise_cudagraph.chat(
+            test_prompts, sampling_config)
+        piecewise_cudagraph_metrics = spec_llm_piecewise_cudagraph.get_metrics(
+        )
+        del spec_llm_piecewise_cudagraph
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        spec_llm_full_cudagraph = LLM(
+            model=model_name,
+            trust_remote_code=True,
+            tensor_parallel_size=tp_size,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": 2048,
+            },
+            max_model_len=2048,
+            compilation_config={"full_cuda_graph": True},
+            disable_log_stats=False,
+        )
+        full_cudagraph_outputs = spec_llm_full_cudagraph.chat(
+            test_prompts, sampling_config)
+        full_cudagraph_metrics = spec_llm_full_cudagraph.get_metrics()
+        matches = 0
+        misses = 0
+        for piecewise, full in zip(piecewise_cudagraph_outputs,
+                                   full_cudagraph_outputs):
+            if piecewise.outputs[0].text == full.outputs[0].text:
+                matches += 1
+            else:
+                misses += 1
+                print(
+                    f"piecewise_cudagraph_output: {piecewise.outputs[0].text}")
+                print(f"full_cudagraph_output: {full.outputs[0].text}")
+
+        # Heuristic: expect at least 66% of the prompts to match exactly
+        # between piecewise and full cudagraph mode, and the acceptance
+        # rates to agree within an absolute tolerance of 0.1.
+        # Upon failure, inspect the outputs to check for inaccuracy.
+        assert matches > int(0.66 * len(piecewise_cudagraph_outputs))
+        piecewise_acceptance = get_acceptance_rate(piecewise_cudagraph_metrics)
+        full_acceptance = get_acceptance_rate(full_cudagraph_metrics)
+        assert math.isclose(piecewise_acceptance, full_acceptance, abs_tol=0.1)
+        del spec_llm_full_cudagraph
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 5d10e9e26082..38d44182f276 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -228,6 +228,7 @@ def build_for_drafting(
         self,
         common_attn_metadata: CommonAttentionMetadata,
         draft_index: int,
+        fast_build: bool = True,
     ) -> TreeAttentionMetadata:
         # Cache the original tree attention bias.
         orig_tree_attn_bias = self.tree_attn_bias
@@ -243,7 +244,9 @@ def build_for_drafting(
                 start:end].contiguous()
 
         # Build attention bias.
-        attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
+        attn_metadata = self.build(0,
+                                   common_attn_metadata,
+                                   fast_build=fast_build)
 
         # Reset the tree attention bias to the original value.
         self.tree_attn_bias = orig_tree_attn_bias
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 91eb84245ac0..dfe33a87707c 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -36,7 +36,7 @@ class CommonAttentionMetadata:
     """
     Per-batch attention metadata, shared across layers and backends.
     AttentionMetadataBuilder instances use it to construct per-layer metadata.
-    
+
     For many of the tensors we keep both GPU and CPU versions.
     """
 
@@ -76,7 +76,7 @@ def slice_query_start_locs(
     request_slice: slice,
 ) -> torch.Tensor:
     """
-    Creates a new query_start_loc that corresponds to the requests in 
+    Creates a new query_start_loc that corresponds to the requests in
     request_slice.
 
     Note: This function creates a new tensor to hold the new query_start_locs.
@@ -90,7 +90,7 @@ def _make_metadata_with_slice(
         ubatch_slice: UbatchSlice,
         attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
     """
-    This function creates a new CommonAttentionMetadata that corresponds to 
+    This function creates a new CommonAttentionMetadata that corresponds to
     the requests included in ubatch_slice
     """
 
@@ -138,7 +138,7 @@ def split_attn_metadata(
     common_attn_metadata: CommonAttentionMetadata,
 ) -> list[CommonAttentionMetadata]:
     """
-    Creates a new CommonAttentionMetadata instance that corresponds to the 
+    Creates a new CommonAttentionMetadata instance that corresponds to the
     requests for each UbatchSlice in ubatch_slices.
 
     Note: This function does not modify common_attn_metadata
@@ -189,7 +189,7 @@ def build(self,
         """
         Central method that builds attention metadata.
         Some builders (MLA) require reorder_batch to be called prior to build.
-        
+
         Args:
             common_prefix_len: The length of the common prefix of the batch.
             common_attn_metadata: The common attention metadata.
@@ -220,10 +220,11 @@ def build_for_drafting(
         self,
         common_attn_metadata: CommonAttentionMetadata,
         draft_index: int,
+        fast_build: bool = True,
     ) -> M:
         """
         Build attention metadata for draft model. Uses build by default.
-        
+
         Args:
             common_attn_metadata: The common attention metadata.
             draft_index: The index of the current draft operation.
@@ -234,7 +235,7 @@ def build_for_drafting(
         """
         return self.build(common_prefix_len=0,
                           common_attn_metadata=common_attn_metadata,
-                          fast_build=True)
+                          fast_build=fast_build)
 
     def use_cascade_attention(
         self,
@@ -629,7 +630,7 @@ def reorder_batch_to_split_decodes_and_prefills(
     """
     Reorders the batch to split into prefill and decode requests; places all
     requests with <= decode_threshold tokens at the front of the batch.
-    
+
     Returns:
         True if the batch was modified, False otherwise.
     """
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index a8a160a0f995..9f270e731a99 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import ast
 from dataclasses import replace
-from typing import Optional
+from typing import Any, Optional
 
 import numpy as np
 import torch
@@ -67,10 +67,12 @@ def __init__(
         self.use_cuda_graph = (self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
+        self.use_full_cuda_graph = (
+            self.use_cuda_graph
+            and vllm_config.compilation_config.full_cuda_graph)
         self.cudagraph_batch_sizes = list(
             reversed(
                 self.vllm_config.compilation_config.cudagraph_capture_sizes))
-
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int32, device=device)
@@ -120,6 +122,8 @@ def __init__(
             device=device,
             dtype=torch.int32,
         ).repeat(max_batch_size, 1)
+        # attention metadata captured in full cudagraph mode
+        self.attn_metadata_cudagraph = None
 
     def propose(
         self,
@@ -157,7 +161,8 @@ def propose(
         # FIXME: need to consider multiple kv_cache_groups
         attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
             .build_for_drafting(common_attn_metadata=common_attn_metadata,
-                                draft_index=0)
+                                draft_index=0,
+                                fast_build=not self.use_full_cuda_graph)
 
         # At this moment, we assume all eagle layers belong to the same KV
         # cache group, thus using the same attention metadata.
@@ -185,6 +190,18 @@
             inputs_embeds = None
             input_ids = self.input_ids[:num_input_tokens]
 
+        if (self.use_full_cuda_graph
+                and num_tokens <= self.cudagraph_batch_sizes[-1]):
+            assert self.attn_metadata_cudagraph
+            self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
+                attn_metadata.seq_lens)
+            self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
+                attn_metadata.slot_mapping)
+            self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
+                attn_metadata.query_start_loc)
+            self.attn_metadata_cudagraph.block_table[:batch_size] = (
+                attn_metadata.block_table)
+
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -245,11 +262,17 @@ def propose(
         if self.use_cuda_graph and \
             batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            if self.use_full_cuda_graph:
+                assert self.attn_metadata_cudagraph
+                self.attn_metadata_cudagraph.block_table[:batch_size] = (
+                    attn_metadata.block_table)
+                attn_metadata = self.attn_metadata_cudagraph
         else:
             input_batch_size = batch_size
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
-        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
+        attn_metadata.query_start_loc[:batch_size +
+                                      1] = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
             # cast to int32 is crucial when eagle model is compiled.
@@ -277,20 +300,27 @@ def propose(
                 self.max_model_len)
             # For the requests that exceed the max model length, we set the
             # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
+            attn_metadata.seq_lens[:batch_size].masked_fill_(
+                exceeds_max_model_len, 1)
 
             # Compute the slot mapping.
             block_numbers = clamped_positions // self.block_size
             block_ids = attn_metadata.block_table.gather(
                 dim=1, index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
+            slot_mapping = (block_ids * self.block_size +
+                            clamped_positions % self.block_size)
+            if self.use_full_cuda_graph:
+                attn_metadata.slot_mapping[:batch_size] = slot_mapping
+            else:
+                # In eager mode attention, slot_mapping's shape is used to
+                # determine the number of actual tokens.
+                attn_metadata.slot_mapping = slot_mapping
             # Mask out the slot mappings that exceed the max model length.
             # Otherwise, the KV cache will be inadvertently updated with the
             # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
-                                                    PADDING_SLOT_ID)
+            attn_metadata.slot_mapping[:batch_size].masked_fill_(
+                exceeds_max_model_len, PADDING_SLOT_ID)
 
             # copy inputs to buffer for cudagraph
             self.input_ids[:batch_size] = input_ids
@@ -642,8 +672,14 @@ def load_model(self, target_model: nn.Module) -> None:
     def dummy_run(
         self,
         num_tokens: int,
+        attn_metadata: Optional[dict[str, Any]] = None,
     ) -> None:
-        with set_forward_context(None, self.vllm_config,
+        if attn_metadata is not None and self.attn_metadata_cudagraph is None:
+            # attn_metadata is shared across all draft layers
+            self.attn_metadata_cudagraph = attn_metadata[
+                self.attn_layer_names[0]]
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
                                  num_tokens=num_tokens):
             if self.is_multimodal_model:
                 input_ids = None
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a03e860a91c7..e1a6e37ebb4c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -2320,7 +2320,7 @@ def _dummy_run(
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            self.drafter.dummy_run(num_tokens, attn_metadata)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
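
Usage sketch (not part of the diff; the model pair, sampling settings, and metric names below are taken from the new test, and any EAGLE-compatible target/drafter pair should work the same way): with these changes applied, full CUDA graph mode for the EAGLE drafter is selected purely through compilation_config, either as compilation_config={"full_cuda_graph": True} on the LLM constructor or as --compilation-config '{"full_cuda_graph": true}' on examples/offline_inference/spec_decode.py, which json.loads the string into the same dict.

    # Minimal sketch, mirroring the test added in this diff.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        speculative_config={
            "method": "eagle",
            "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            "num_speculative_tokens": 3,
            "max_model_len": 2048,
        },
        max_model_len=2048,
        compilation_config={"full_cuda_graph": True},  # False selects piecewise cudagraphs
        disable_log_stats=False,  # keep stats so get_metrics() exposes spec-decode counters
    )
    outputs = llm.chat(
        [[{"role": "user", "content": "Briefly explain speculative decoding."}]],
        SamplingParams(temperature=0, max_tokens=32))
    print(outputs[0].outputs[0].text)

    # The same counters the test reads: number of drafts vs. accepted draft tokens.
    for metric in llm.get_metrics():
        if metric.name in ("vllm:spec_decode_num_drafts",
                           "vllm:spec_decode_num_accepted_tokens"):
            print(metric.name, metric.value)

Comparing such a run against an otherwise identical run with "full_cuda_graph": False is exactly what test_full_vs_piecewise_cudagraph automates: the generated text should mostly match and the acceptance statistic derived from these counters should agree within 0.1.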