15 changes: 15 additions & 0 deletions tests/utils.py
@@ -974,3 +974,18 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]


def get_attn_backend_list_based_on_platform() -> list[str]:
    if current_platform.is_cuda():
        return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1"]
    elif current_platform.is_rocm():
        attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
        try:
            import aiter  # noqa: F401  # FLASH_ATTN_VLLM_V1 on ROCm needs aiter
            attn_backend_list.append("FLASH_ATTN_VLLM_V1")
        except Exception:
            print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
        return attn_backend_list
    else:
        raise ValueError("Unsupported platform")
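
A minimal usage sketch (hypothetical test name, mirroring how the tests in this PR consume the helper): each parametrized case pins one attention backend through the VLLM_ATTENTION_BACKEND environment variable.

import os

import pytest

from tests.utils import get_attn_backend_list_based_on_platform


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_attn_backend_env_is_set(monkeypatch, attn_backend):
    # Pin vLLM to the selected backend for the duration of this test case.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert os.environ["VLLM_ATTENTION_BACKEND"] == attn_backend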
9 changes: 9 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -8,8 +8,10 @@
import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform


@pytest.fixture
@@ -114,11 +116,14 @@ def test_ngram_correctness(
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
attn_backend: str,
):
'''
Compare the outputs of an original LLM and a speculative LLM
@@ -127,6 +132,10 @@ '''
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")
method, model_name, spec_model_name, tp_size = model_setup

ref_llm = LLM(model=model_name,
45 changes: 36 additions & 9 deletions tests/v1/spec_decode/test_eagle.py
@@ -6,6 +6,7 @@
import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
@@ -125,17 +126,27 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)


@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
])
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, pp_size, use_distinct_embed_tokens):
attn_backend, pp_size, use_distinct_embed_tokens,
monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if attn_backend == "TRITON_ATTN_VLLM_V1" and current_platform.is_cuda():
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on CUDA")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
@@ -182,7 +193,7 @@ class _TargetModelStub(LlamaForCausalLM):
target_model.lm_head = mock.MagicMock()

# Create proposer using the helper function
proposer = proposer_helper(k=8)
proposer = _create_proposer(method, k=8)

# Call the method under test
proposer.load_model(target_model)
@@ -206,8 +217,17 @@ class _TargetModelStub(LlamaForCausalLM):
target_model.model.embed_tokens


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Use GPU device
device = torch.device(current_platform.device_type)

@@ -306,8 +326,15 @@ def create_deterministic_logits(token_ids):
device=device)
sampling_metadata = mock.MagicMock()

attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
elif attn_backend == "TRITON_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN_VLLM_V1)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
vllm_config=proposer.vllm_config,
54 changes: 30 additions & 24 deletions tests/v1/spec_decode/test_max_len.py
@@ -4,7 +4,9 @@

import pytest

from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -14,36 +16,40 @@


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)
def test_ngram_max_len(num_speculative_tokens: int):
llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int, attn_backend: str):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if attn_backend == "TRITON_ATTN_VLLM_V1" and current_platform.is_cuda(
):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on CUDA")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
12 changes: 8 additions & 4 deletions vllm/v1/spec_decode/eagle.py
@@ -13,6 +13,7 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -154,10 +155,13 @@ def propose(
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.

# Currently FlashAttention is the only backend that supports
# multi-token eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)
        if not current_platform.is_rocm():
            # On ROCm, both AiterFlashAttention and TritonAttention support
            # multi-token eagle spec decode, so this check is skipped there.
            # Elsewhere, FlashAttention is currently the only backend that
            # supports multi-token eagle spec decode, because the code below
            # makes assumptions about which attn_metadata attributes exist.
            assert isinstance(attn_metadata, FlashAttentionMetadata)

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
Expand Down