6 changes: 6 additions & 0 deletions examples/offline_inference/spec_decode.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -68,6 +70,7 @@ def parse_args():
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--compilation-config", type=str, default="")
return parser.parse_args()


@@ -132,6 +135,9 @@ def main():
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
compilation_config=(
json.loads(args.compilation_config) if args.compilation_config else None
),
)

sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
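Note: the new --compilation-config flag takes a JSON string that is parsed with json.loads and forwarded to LLM(compilation_config=...). A minimal sketch of that path, using an illustrative payload (the full_cuda_graph key mirrors the compilation_config dicts used in the test below):

import json

# Illustrative value; in the script it arrives via --compilation-config.
raw = '{"full_cuda_graph": true}'

# Mirrors the hunk above: an empty string means "no override" (None).
compilation_config = json.loads(raw) if raw else None
assert compilation_config == {"full_cuda_graph": True}

# The script then passes it through unchanged, e.g.
# LLM(model=..., compilation_config=compilation_config)
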
111 changes: 111 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import math
import random
from typing import Any, Union

@@ -14,6 +15,7 @@
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
from vllm.v1.metrics.reader import Counter, Metric, Vector


def get_test_prompts(mm_enabled: bool):
@@ -67,6 +69,23 @@ def get_test_prompts(mm_enabled: bool):
return prompts


def get_acceptance_rate(metrics: list[Metric]) -> float:
num_drafts = num_accepted = 0
# Per-position counts (num_speculative_tokens == 3); accumulated for
# debugging but not part of the returned value.
acceptance_counts = [0] * 3
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
assert isinstance(metric, Counter)
num_drafts += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens":
assert isinstance(metric, Counter)
num_accepted += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
assert isinstance(metric, Vector)
for pos, value in enumerate(metric.values):
acceptance_counts[pos] += value
# Mean acceptance length: one verified token plus the accepted draft
# tokens per drafting step.
return 1 + num_accepted / num_drafts


@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@@ -220,3 +239,95 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
"model_setup",
[
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
],
)
def test_full_vs_piecewise_cudagraph(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
):
'''
Check that EAGLE speculative decoding outputs and acceptance rates
match between piecewise and full cudagraph modes.

model_setup: (method, model_name, eagle_model_name, tp_size)
'''
test_prompts = get_test_prompts(mm_enabled=False)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_GPU_MEMORY_UTILIZATION", "0.8")
if current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
else:
m.setenv("VLLM_FLASH_ATTN_VERSION", "3")

method, model_name, spec_model_name, tp_size = model_setup

spec_llm_piecewise_cudagraph = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
max_model_len=2048,
compilation_config={"full_cuda_graph": False},
disable_log_stats=False,
)
piecewise_cudagraph_outputs = spec_llm_piecewise_cudagraph.chat(
test_prompts, sampling_config)
piecewise_cudagraph_metrics = spec_llm_piecewise_cudagraph.get_metrics()
del spec_llm_piecewise_cudagraph
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

spec_llm_full_cudagraph = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": 2048,
},
max_model_len=2048,
compilation_config={"full_cuda_graph": True},
disable_log_stats=False,
)
full_cudagraph_outputs = spec_llm_full_cudagraph.chat(
test_prompts, sampling_config)
full_cudagraph_metrics = spec_llm_full_cudagraph.get_metrics()
matches = 0
misses = 0
for piecewise, full in zip(piecewise_cudagraph_outputs,
full_cudagraph_outputs):
if piecewise.outputs[0].text == full.outputs[0].text:
matches += 1
else:
misses += 1
print(
f"piecewise_cudagraph_output: {piecewise.outputs[0].text}")
print(f"full_cudagraph_output: {full.outputs[0].text}")

# Heuristic: expect more than 66% of the prompts to match exactly between
# piecewise and full cudagraph modes, and the acceptance rates to agree
# within an absolute tolerance of 0.1.
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(piecewise_cudagraph_outputs))
piecewise_acceptance = get_acceptance_rate(piecewise_cudagraph_metrics)
full_acceptance = get_acceptance_rate(full_cudagraph_metrics)
assert math.isclose(piecewise_acceptance, full_acceptance, abs_tol=0.1)
del spec_llm_full_cudagraph
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
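
The get_acceptance_rate helper above also accumulates the vllm:spec_decode_num_accepted_tokens_per_pos vector, but only the aggregate ratio is returned. A minimal sketch, assuming num_speculative_tokens == 3 as configured in this test, of deriving per-position acceptance rates from the same metrics list (the helper name is illustrative):

from vllm.v1.metrics.reader import Counter, Metric, Vector


def per_position_acceptance(metrics: list[Metric],
                            num_spec_tokens: int = 3) -> list[float]:
    num_drafts = 0
    counts = [0] * num_spec_tokens
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_drafts":
            assert isinstance(metric, Counter)
            num_drafts += metric.value
        elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
            assert isinstance(metric, Vector)
            for pos, value in enumerate(metric.values):
                counts[pos] += value
    if not num_drafts:
        return [0.0] * num_spec_tokens
    # Fraction of drafts whose token at each speculative position was accepted.
    return [count / num_drafts for count in counts]
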
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/tree_attn.py
@@ -228,6 +228,7 @@ def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
fast_build: bool = True,
) -> TreeAttentionMetadata:
# Cache the original tree attention bias.
orig_tree_attn_bias = self.tree_attn_bias
@@ -243,7 +244,9 @@
start:end].contiguous()

# Build attention bias.
attn_metadata = self.build(0, common_attn_metadata, fast_build=True)
attn_metadata = self.build(0,
common_attn_metadata,
fast_build=fast_build)

# Reset the tree attention bias to the original value.
self.tree_attn_bias = orig_tree_attn_bias
17 changes: 9 additions & 8 deletions vllm/v1/attention/backends/utils.py
@@ -36,7 +36,7 @@ class CommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.

For many of the tensors we keep both GPU and CPU versions.
"""

@@ -76,7 +76,7 @@ def slice_query_start_locs(
request_slice: slice,
) -> torch.Tensor:
"""
Creates a new query_start_loc that corresponds to the requests in
Creates a new query_start_loc that corresponds to the requests in
request_slice.

Note: This function creates a new tensor to hold the new query_start_locs.
@@ -90,7 +90,7 @@ def _make_metadata_with_slice(
ubatch_slice: UbatchSlice,
attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
"""
This function creates a new CommonAttentionMetadata that corresponds to
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""

@@ -138,7 +138,7 @@ def split_attn_metadata(
common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UbatchSlice in ubatch_slices.

Note: This function does not modify common_attn_metadata
@@ -189,7 +189,7 @@ def build(self,
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.

Args:
common_prefix_len: The length of the common prefix of the batch.
common_attn_metadata: The common attention metadata.
@@ -220,10 +220,11 @@ def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
fast_build: bool = True,
) -> M:
"""
Build attention metadata for draft model. Uses build by default.

Args:
common_attn_metadata: The common attention metadata.
draft_index: The index of the current draft operation.
@@ -234,7 +235,7 @@
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
fast_build=True)
fast_build=fast_build)

def use_cascade_attention(
self,
@@ -629,7 +630,7 @@ def reorder_batch_to_split_decodes_and_prefills(
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.

Returns:
True if the batch was modified, False otherwise.
"""
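The base build_for_drafting now forwards fast_build instead of pinning it to True, which lets the EAGLE proposer request a full (non-fast) build when the drafting metadata will be captured into a full CUDA graph (see the eagle.py hunk below). A self-contained toy sketch of the forwarding; ToyBuilder and ToyMetadata are illustrative, not the vLLM classes:

from dataclasses import dataclass


@dataclass
class ToyMetadata:
    fast_build: bool


class ToyBuilder:
    def build(self, common_prefix_len: int,
              fast_build: bool = False) -> ToyMetadata:
        return ToyMetadata(fast_build=fast_build)

    def build_for_drafting(self, draft_index: int,
                           fast_build: bool = True) -> ToyMetadata:
        # Previously pinned to fast_build=True; now forwarded to build().
        return self.build(common_prefix_len=0, fast_build=fast_build)


builder = ToyBuilder()
# Full-cudagraph drafting builds metadata once at capture time, so the
# slower, more complete build is a one-time cost:
assert builder.build_for_drafting(draft_index=0, fast_build=False).fast_build is False
# Eager/piecewise drafting keeps the fast path by default:
assert builder.build_for_drafting(draft_index=0).fast_build is True
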
56 changes: 46 additions & 10 deletions vllm/v1/spec_decode/eagle.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from dataclasses import replace
from typing import Optional
from typing import Any, Optional

import numpy as np
import torch
@@ -67,10 +67,12 @@ def __init__(
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.use_full_cuda_graph = (
self.use_cuda_graph
and vllm_config.compilation_config.full_cuda_graph)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))

# persistent buffers for cuda graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
@@ -120,6 +122,8 @@ def __init__(
device=device,
dtype=torch.int32,
).repeat(max_batch_size, 1)
# attention metadata captured in full cudagraph mode
self.attn_metadata_cudagraph = None

def propose(
self,
@@ -157,7 +161,8 @@ def propose(
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
draft_index=0,
fast_build=not self.use_full_cuda_graph)

# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
@@ -185,6 +190,18 @@
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]

if (self.use_full_cuda_graph
and num_tokens <= self.cudagraph_batch_sizes[-1]):
assert self.attn_metadata_cudagraph
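# Copy the freshly built drafting metadata into the persistent metadata
# object cached from dummy_run and captured by the full cuda graph.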
self.attn_metadata_cudagraph.seq_lens[:batch_size] = (
attn_metadata.seq_lens)
self.attn_metadata_cudagraph.slot_mapping[:num_tokens] = (
attn_metadata.slot_mapping)
self.attn_metadata_cudagraph.query_start_loc[:batch_size + 1] = (
attn_metadata.query_start_loc)
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)

with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
@@ -245,11 +262,17 @@ def propose(
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
if self.use_full_cuda_graph:
assert self.attn_metadata_cudagraph
self.attn_metadata_cudagraph.block_table[:batch_size] = (
attn_metadata.block_table)
attn_metadata = self.attn_metadata_cudagraph
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
attn_metadata.query_start_loc[:batch_size + 1] = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
@@ -277,20 +300,27 @@
self.max_model_len)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
attn_metadata.seq_lens[:batch_size].masked_fill_(
exceeds_max_model_len, 1)

# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = attn_metadata.block_table.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
if self.use_full_cuda_graph:
attn_metadata.slot_mapping[:batch_size] = slot_mapping
else:
# In eager mode attention, slot_mapping's shape is used to
# determine the number of actual tokens.
attn_metadata.slot_mapping = slot_mapping
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
attn_metadata.slot_mapping[:batch_size].masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)

# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
@@ -642,8 +672,14 @@ def load_model(self, target_model: nn.Module) -> None:
def dummy_run(
self,
num_tokens: int,
attn_metadata: Optional[dict[str, Any]] = None,
) -> None:
with set_forward_context(None, self.vllm_config,
if attn_metadata is not None and self.attn_metadata_cudagraph is None:
# attn_metadata is shared across all draft layers
self.attn_metadata_cudagraph = attn_metadata[
self.attn_layer_names[0]]
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_tokens):
if self.is_multimodal_model:
input_ids = None
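
The propose() changes above consistently swap tensor rebinding (attn_metadata.slot_mapping = ...) for in-place slice copies whenever full CUDA graph is enabled. A minimal PyTorch-only sketch of why that distinction matters for graph replay (requires a GPU; the buffer name is illustrative, not vLLM API):

import torch

# Persistent buffer, analogous to the fields of attn_metadata_cudagraph.
buf = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    out = buf * 2
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = buf * 2  # the graph records reads from buf's current storage

buf[:8] = torch.arange(8.0, device="cuda")  # in-place copy: replay sees the new values
g.replay()
print(out)  # tensor([ 0., 2., 4., ...])

buf = torch.arange(8.0, device="cuda") + 1  # rebinding: the graph still reads the old storage
g.replay()
print(out)  # unchanged by the rebinding

Rebinding creates a new tensor the captured graph never sees, which is why the persistent attn_metadata_cudagraph fields are only ever written through slice assignments.
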
2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -2320,7 +2320,7 @@ def _dummy_run(

if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
self.drafter.dummy_run(num_tokens, attn_metadata)

# This is necessary to avoid blocking DP.
# For dummy runs, we typically skip EPLB since we don't have any real