16 changes: 16 additions & 0 deletions tests/utils.py
@@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name):
return hasattr(module, attribute_name)
except ImportError:
return False


def get_attn_backend_list_based_on_platform() -> list[str]:
    """Return the attention backend names to test on the current platform."""
if current_platform.is_cuda():
return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1"]
elif current_platform.is_rocm():
attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
try:
import aiter # noqa: F401
attn_backend_list.append("FLASH_ATTN_VLLM_V1")
except Exception:
print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")

return attn_backend_list
else:
raise ValueError("Unsupported platform")
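
For orientation, a minimal usage sketch (not part of this diff) of how the new helper is consumed; the test name and body here are hypothetical, mirroring the parametrization added in the test files below:

# Hypothetical usage sketch of the helper above (not part of this PR).
import pytest

from tests.utils import get_attn_backend_list_based_on_platform


@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_with_each_backend(attn_backend: str,
                           monkeypatch: pytest.MonkeyPatch):
    # Pin the attention backend for this test case via the environment.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    # A real test would exercise the feature here; this just checks the value.
    assert attn_backend in ("FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1")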
15 changes: 15 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -8,10 +8,12 @@
import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform


def get_test_prompts(mm_enabled: bool):
@@ -141,11 +143,14 @@ def test_ngram_correctness(
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
attn_backend: str,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
@@ -156,6 +161,16 @@ def test_eagle_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

method, model_name, spec_model_name, tp_size = model_setup

ref_llm = LLM(model=model_name,
51 changes: 42 additions & 9 deletions tests/v1/spec_decode/test_eagle.py
@@ -6,6 +6,7 @@
import pytest
import torch

from tests.utils import get_attn_backend_list_based_on_platform
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
@@ -120,17 +121,28 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)


@pytest.mark.parametrize("method,proposer_helper", [
("eagle", lambda k: _create_proposer("eagle", k)),
("eagle3", lambda k: _create_proposer("eagle3", k)),
])
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("pp_size", [1, 2])
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, pp_size, use_distinct_embed_tokens):
attn_backend, pp_size, use_distinct_embed_tokens,
monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Setup draft model mock
mock_model = mock.MagicMock()
if use_distinct_embed_tokens:
@@ -177,7 +189,7 @@ class _TargetModelStub(LlamaForCausalLM):
target_model.lm_head = mock.MagicMock()

# Create proposer using the helper function
proposer = proposer_helper(k=8)
proposer = _create_proposer(method, k=8)

# Call the method under test
proposer.load_model(target_model)
@@ -201,8 +213,22 @@ class _TargetModelStub(LlamaForCausalLM):
target_model.model.embed_tokens


@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# Use GPU device
device = torch.device(current_platform.device_type)

@@ -301,8 +327,15 @@ def create_deterministic_logits(token_ids):
device=device)
sampling_metadata = mock.MagicMock()

attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
if attn_backend == "FLASH_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
elif attn_backend == "TRITON_ATTN_VLLM_V1":
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.TRITON_ATTN_VLLM_V1)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
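
A side note on the backend selection in test_propose above: since the backend strings match _Backend enum member names, an equivalent selection could use a direct enum lookup. This is only a hypothetical sketch, not part of the PR, and it assumes get_attention_backend is importable from tests.v1.attention.utils alongside _Backend:

# Hypothetical alternative to the if/elif chain in test_propose above.
from tests.v1.attention.utils import _Backend, get_attention_backend


def select_metadata_builder(attn_backend: str):
    # Enum lookup by name, e.g. _Backend["TRITON_ATTN_VLLM_V1"].
    backend_enum = _Backend[attn_backend]
    attn_metadata_builder_cls, _ = get_attention_backend(backend_enum)
    return attn_metadata_builder_cls

The explicit if/elif in the diff does make the supported backends more obvious at a glance, so this is an observation rather than a suggested change.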
54 changes: 30 additions & 24 deletions tests/v1/spec_decode/test_max_len.py
@@ -4,7 +4,9 @@

import pytest

from tests.utils import get_attn_backend_list_based_on_platform
from vllm import LLM, SamplingParams
from vllm.platforms import current_platform

_PROMPTS = [
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -14,36 +16,40 @@


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_ngram_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)
def test_ngram_max_len(num_speculative_tokens: int):
llm = LLM(
model="facebook/opt-125m",
max_model_len=100,
enforce_eager=True, # For faster initialization.
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": num_speculative_tokens,
},
)
sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
llm.generate(_PROMPTS, sampling_params)


@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
def test_eagle_max_len(
monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int,
):
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
num_speculative_tokens: int, attn_backend: str):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

if (attn_backend == "TRITON_ATTN_VLLM_V1"
and not current_platform.is_rocm()):
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
"multi-token eagle spec decode on current platform")

if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
m.setenv("VLLM_ROCM_USE_AITER", "1")

llm = LLM(
model="meta-llama/Meta-Llama-3-8B-Instruct",
enforce_eager=True, # For faster initialization.
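
The skip-and-AITER pattern above is now repeated verbatim in three test files; a hypothetical shared helper (not part of this PR) could consolidate it, for example:

# Hypothetical consolidation of the repeated skip/env logic (not in this PR).
import pytest

from vllm.platforms import current_platform


def configure_attn_backend(attn_backend: str,
                           monkeypatch: pytest.MonkeyPatch) -> None:
    """Skip unsupported combinations and set backend-related env vars."""
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    if (attn_backend == "TRITON_ATTN_VLLM_V1"
            and not current_platform.is_rocm()):
        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
                    "multi-token eagle spec decode on current platform")
    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")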
20 changes: 16 additions & 4 deletions vllm/v1/spec_decode/eagle.py
@@ -15,8 +15,12 @@
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
@@ -178,10 +182,18 @@ def propose(
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.

# Currently FlashAttention is the only backend that supports
# multi-token eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)
# On ROCm, both AiterFlashAttention and TritonAttention
# support multi-token eagle spec decode.
if current_platform.is_rocm():
assert isinstance(
attn_metadata,
(TritonAttentionMetadata, AiterFlashAttentionMetadata,
FlashAttentionMetadata))
else:
# Currently FlashAttention is the only backend that supports
# multi-token eagle spec decode. This is because the code below
# makes assumptions about attn_metadata attributes available.
assert isinstance(attn_metadata, FlashAttentionMetadata)

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
Expand Down