
Commit a7f53d4

tjtanaa and amd-xiaoyu12 authored and committed
[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (vllm-project#21496)
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: Xiao Yu <[email protected]>
1 parent 748c9d8 commit a7f53d4

File tree

6 files changed (+128, -41 lines)

tests/utils.py

Lines changed: 16 additions & 0 deletions
@@ -986,3 +986,19 @@ def has_module_attribute(module_name, attribute_name):
         return hasattr(module, attribute_name)
     except ImportError:
         return False
+
+
+def get_attn_backend_list_based_on_platform() -> list[str]:
+    if current_platform.is_cuda():
+        return ["FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"]
+    elif current_platform.is_rocm():
+        attn_backend_list = ["TRITON_ATTN_VLLM_V1"]
+        try:
+            import aiter  # noqa: F401
+            attn_backend_list.append("FLASH_ATTN_VLLM_V1")
+        except Exception:
+            print("Skip FLASH_ATTN_VLLM_V1 on ROCm as aiter is not installed")
+
+        return attn_backend_list
+    else:
+        raise ValueError("Unsupported platform")
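Note: the new helper is consumed through pytest parametrization in the test files below. As a quick illustration of the intended usage pattern (the test name here is hypothetical and not part of this commit):

import pytest

from tests.utils import get_attn_backend_list_based_on_platform


# The helper runs at collection time, so each backend supported on the
# current platform becomes its own test case; on ROCm without aiter
# installed, only TRITON_ATTN_VLLM_V1 is exercised.
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_attn_backend_env(attn_backend: str, monkeypatch):
    # Mirrors the pattern used in the tests touched by this commit.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert attn_backend in {
        "FLASH_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1", "TREE_ATTN"
    }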

tests/v1/attention/utils.py

Lines changed: 5 additions & 2 deletions
@@ -11,7 +11,7 @@
 from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                          LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                          SchedulerConfig, VllmConfig)
-from vllm.platforms import _Backend
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import resolve_obj_by_qualname
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import FullAttentionSpec
@@ -119,7 +119,10 @@ def get_attention_backend(backend_name: _Backend):
     """
     backend_map = {
         _Backend.FLASH_ATTN_VLLM_V1:
-        "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend",
+        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+         if current_platform.is_cuda() else
+         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
+         ),
         _Backend.FLASHINFER_VLLM_V1:
         "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
         _Backend.FLEX_ATTENTION:

tests/v1/e2e/test_spec_decode.py

Lines changed: 15 additions & 0 deletions
@@ -8,10 +8,12 @@
 import pytest
 import torch
 
+from tests.utils import get_attn_backend_list_based_on_platform
 from vllm import LLM, SamplingParams
 from vllm.assets.base import VLLM_S3_BUCKET_URL
 from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.platforms import current_platform
 
 
 def get_test_prompts(mm_enabled: bool):
@@ -141,11 +143,14 @@ def test_ngram_correctness(
              marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
     ],
     ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
         monkeypatch: pytest.MonkeyPatch,
         sampling_config: SamplingParams,
         model_setup: tuple[str, str, str, int],
         mm_enabled: bool,
+        attn_backend: str,
 ):
     # Generate test prompts inside the function instead of using fixture
     test_prompts = get_test_prompts(mm_enabled)
@@ -156,6 +161,16 @@
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+        if (attn_backend == "TRITON_ATTN_VLLM_V1"
+                and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                        "multi-token eagle spec decode on current platform")
+
+        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+
         method, model_name, spec_model_name, tp_size = model_setup
 
         ref_llm = LLM(model=model_name,

tests/v1/spec_decode/test_eagle.py

Lines changed: 45 additions & 10 deletions
@@ -6,6 +6,7 @@
 import pytest
 import torch
 
+from tests.utils import get_attn_backend_list_based_on_platform
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
@@ -120,17 +121,28 @@ def test_prepare_inputs():
     assert torch.equal(token_indices, expected_token_indices)
 
 
-@pytest.mark.parametrize("method,proposer_helper", [
-    ("eagle", lambda k: _create_proposer("eagle", k)),
-    ("eagle3", lambda k: _create_proposer("eagle3", k)),
-])
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
 @mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 @mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
 @mock.patch('vllm.v1.spec_decode.eagle.get_model')
 def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
-                    proposer_helper, pp_size, use_distinct_embed_tokens):
+                    attn_backend, pp_size, use_distinct_embed_tokens,
+                    monkeypatch):
+
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Setup draft model mock
     mock_model = mock.MagicMock()
     if use_distinct_embed_tokens:
@@ -177,7 +189,7 @@ class _TargetModelStub(LlamaForCausalLM):
     target_model.lm_head = mock.MagicMock()
 
     # Create proposer using the helper function
-    proposer = proposer_helper(k=8)
+    proposer = _create_proposer(method, k=8)
 
     # Call the method under test
     proposer.load_model(target_model)
@@ -201,10 +213,22 @@ class _TargetModelStub(LlamaForCausalLM):
         target_model.model.embed_tokens
 
 
+@pytest.mark.parametrize("method", ["eagle", "eagle3"])
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
-@pytest.mark.parametrize("backend",
-                         [_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
-def test_propose(num_speculative_tokens, backend):
+def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
+
+    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+    if (attn_backend == "TRITON_ATTN_VLLM_V1"
+            and not current_platform.is_rocm()):
+        pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                    "multi-token eagle spec decode on current platform")
+
+    if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     # Use GPU device
     device = torch.device(current_platform.device_type)
 
@@ -303,7 +327,18 @@ def create_deterministic_logits(token_ids):
                                     device=device)
     sampling_metadata = mock.MagicMock()
 
-    attn_metadata_builder_cls, _ = get_attention_backend(backend)
+    if attn_backend == "FLASH_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.FLASH_ATTN_VLLM_V1)
+    elif attn_backend == "TRITON_ATTN_VLLM_V1":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TRITON_ATTN_VLLM_V1)
+    elif attn_backend == "TREE_ATTN":
+        attn_metadata_builder_cls, _ = get_attention_backend(
+            _Backend.TREE_ATTN)
+    else:
+        raise ValueError(f"Unsupported attention backend: {attn_backend}")
+
     attn_metadata_builder = attn_metadata_builder_cls(
         kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
         layer_names=proposer.attn_layer_names,

tests/v1/spec_decode/test_max_len.py

Lines changed: 30 additions & 24 deletions
@@ -4,7 +4,9 @@
 
 import pytest
 
+from tests.utils import get_attn_backend_list_based_on_platform
 from vllm import LLM, SamplingParams
+from vllm.platforms import current_platform
 
 _PROMPTS = [
     "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -14,36 +16,40 @@
 
 
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_ngram_max_len(
-    monkeypatch: pytest.MonkeyPatch,
-    num_speculative_tokens: int,
-):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        llm = LLM(
-            model="facebook/opt-125m",
-            max_model_len=100,
-            enforce_eager=True,  # For faster initialization.
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": num_speculative_tokens,
-            },
-        )
-        sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
-        llm.generate(_PROMPTS, sampling_params)
+def test_ngram_max_len(num_speculative_tokens: int):
+    llm = LLM(
+        model="facebook/opt-125m",
+        max_model_len=100,
+        enforce_eager=True,  # For faster initialization.
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": num_speculative_tokens,
+        },
+    )
+    sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
+    llm.generate(_PROMPTS, sampling_params)
 
 
 @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10])
-def test_eagle_max_len(
-    monkeypatch: pytest.MonkeyPatch,
-    num_speculative_tokens: int,
-):
+@pytest.mark.parametrize("attn_backend",
+                         get_attn_backend_list_based_on_platform())
+def test_eagle_max_len(monkeypatch: pytest.MonkeyPatch,
+                       num_speculative_tokens: int, attn_backend: str):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
 
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
+
+        if (attn_backend == "TRITON_ATTN_VLLM_V1"
+                and not current_platform.is_rocm()):
+            pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
+                        "multi-token eagle spec decode on current platform")
+
+        if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
+            m.setenv("VLLM_ROCM_USE_AITER", "1")
+
         llm = LLM(
             model="meta-llama/Meta-Llama-3-8B-Instruct",
             enforce_eager=True,  # For faster initialization.

vllm/v1/spec_decode/eagle.py

Lines changed: 17 additions & 5 deletions
@@ -17,10 +17,14 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.rocm_aiter_fa import (
+    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                    TreeAttentionMetadataBuilder)
+from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -230,11 +234,19 @@ def propose(
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.
 
-        # Currently, only FlashAttention and TreeAttention support multi-token
-        # eagle spec decode. This is because the code below
-        # makes assumptions about attn_metadata attributes available.
-        assert isinstance(attn_metadata,
-                          (FlashAttentionMetadata, TreeAttentionMetadata))
+        # On ROCm, both AiterFlashAttention and TritonAttention
+        # support multi-token eagle spec decode.
+        if current_platform.is_rocm():
+            assert isinstance(
+                attn_metadata,
+                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
+                 FlashAttentionMetadata))
+        else:
+            # Currently, only FlashAttention and TreeAttention support
+            # multi-token eagle spec decode. This is because the code below
+            # makes assumptions about attn_metadata attributes available.
+            assert isinstance(attn_metadata,
+                              (FlashAttentionMetadata, TreeAttentionMetadata))
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
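Taken together, these changes let eagle speculative decoding run on the ROCm V1 engine with either the Triton attention backend or the AITER flash-attention backend. A minimal offline sketch of how the feature could be exercised, mirroring the environment variables and speculative_config pattern used in the tests above; the draft-model path is an illustrative placeholder, and the exact speculative_config keys should be checked against the vLLM version in use:

import os

# Pick a ROCm-supported backend; set VLLM_ROCM_USE_AITER=1 as well if using
# the AITER flash-attention path, as the tests above do.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    enforce_eager=True,  # For faster initialization.
    speculative_config={
        "method": "eagle",
        "model": "path/to/eagle-draft-model",  # placeholder draft model
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))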
