
Commit 1f045c7

rkooo567 authored and Robert Shaw committed
[Test] Test multiple attn backend for chunked prefill. (vllm-project#4023)
1 parent e8e00d2 commit 1f045c7

File tree

4 files changed: +13 -23 lines


.buildkite/test-pipeline.yaml

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,13 @@ steps:
   command: pytest -v -s async_engine
 
 - label: Basic Correctness Test
-  command: pytest -v -s basic_correctness
+  commands:
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
+  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test
   command: pytest -v -s core
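
The new pipeline entries fan the same two test files out over three attention backends, selected purely through the VLLM_ATTENTION_BACKEND environment variable so that each pytest process starts with the backend already fixed. Below is a minimal local equivalent of those commands, written as an illustrative Python loop; the loop and helper function are not part of this change, and only the backend names and test paths are taken from the diff above.

# Illustrative helper (not part of the commit): run the two basic-correctness
# test files under each attention backend, mirroring the Buildkite commands above.
# Assumes it is launched from vLLM's tests/ directory with pytest installed.
import os
import subprocess

BACKENDS = ["XFORMERS", "FLASH_ATTN", "ROCM_FLASH"]  # values from the pipeline diff
TEST_FILES = [
    "basic_correctness/test_basic_correctness.py",
    "basic_correctness/test_chunked_prefill.py",
]

def run_per_backend() -> None:
    for backend in BACKENDS:
        for test_file in TEST_FILES:
            env = dict(os.environ, VLLM_ATTENTION_BACKEND=backend)
            # Each invocation is a fresh process, so the backend chosen at
            # startup cannot leak between parametrizations.
            subprocess.run(["pytest", "-v", "-s", test_file], env=env, check=True)

if __name__ == "__main__":
    run_per_backend()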

tests/basic_correctness/test_basic_correctness.py

Lines changed: 0 additions & 6 deletions
@@ -4,8 +4,6 @@
 """
 import pytest
 
-from vllm.attention.selector import VLLM_ATTENTION_BACKEND
-
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
@@ -16,7+14,6 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
-@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -25,10 +22,7 @@ def test_models(
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
-    attn_backend: str,
-    monkeypatch,
 ) -> None:
-    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
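
With backend selection moved into the CI environment, the test no longer needs the attn_backend parametrization, the monkeypatch fixture, or the VLLM_ATTENTION_BACKEND import. The sketch below is a hedged illustration of why per-process selection is preferred; it is not vLLM's actual selector code, only an example of an environment variable that is read once around startup, which is why changing it inside an already-running test process may have no effect.

# Illustrative sketch (not vLLM's selector): an env-var-driven backend choice
# read once per process, matching how the CI exports VLLM_ATTENTION_BACKEND
# before each pytest invocation.
import os

_SUPPORTED = {"XFORMERS", "FLASH_ATTN", "ROCM_FLASH"}  # names from the pipeline diff

def pick_backend(default: str = "XFORMERS") -> str:
    backend = os.environ.get("VLLM_ATTENTION_BACKEND", default)
    if backend not in _SUPPORTED:
        raise ValueError(f"Unsupported attention backend: {backend}")
    return backend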

tests/basic_correctness/test_chunked_prefill.py

Lines changed: 0 additions & 4 deletions
@@ -34,10 +34,6 @@ def test_models(
     enforce_eager: bool,
     tensor_parallel_size: int,
 ) -> None:
-    if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
-            and not enforce_eager):
-        pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
-                    "for high TP to save testing time.")
     max_num_seqs = min(chunked_prefill_token_size, 256)
     enable_chunked_prefill = False
     max_num_batched_tokens = None

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 6 additions & 12 deletions
@@ -162,7 +162,7 @@ def __init__(
             # AMD Radeon 7900 series (gfx1100) currently does not support
             # xFormers nor FlashAttention. As a temporary workaround, we use
             # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention()
+            self.attn_fuc = _naive_attention
             logger.debug("Using naive attention in ROCmBackend")
         elif self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
@@ -334,26 +334,21 @@ def _naive_attention(
     prompt_lens: List[int],
     scale: float,
 ) -> torch.Tensor:
-    num_tokens = query.shape[0]
     output = torch.empty_like(query)
     start = 0
     for _, prompt_len in enumerate(prompt_lens):
         end = start + prompt_len
         out = _naive_masked_attention(
-            query[None, start:end],
-            key[None, start:end],
-            value[None, start:end],
+            query[start:end],
+            key[start:end],
+            value[start:end],
             scale,
         )
         # TODO(woosuk): Unnecessary copy. Optimize.
         output[start:end].copy_(out)
         start += prompt_len
 
-    # Using view got RuntimeError: view size is not compatible
-    # with input tensor's size and stride (at least one
-    # dimension spans across two contiguous subspaces).
-    # Use reshape instead.
-    return output.reshape(num_tokens, -1)
+    return output
 
 
 def _naive_masked_attention(
@@ -362,14 +357,13 @@ def _naive_masked_attention(
     value: torch.Tensor,
     scale: float,
 ) -> torch.Tensor:
-    seq_len, _, _ = query.shape
+    seq_len, head_size, head_dim = query.shape
     attn_mask = torch.triu(torch.ones(seq_len,
                                       seq_len,
                                       dtype=query.dtype,
                                       device=query.device),
                            diagonal=1)
     attn_mask = attn_mask * torch.finfo(query.dtype).min
-
     attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
     attn_weights = attn_weights + attn_mask.float()
     attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
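
After this hunk, _naive_attention passes plain (seq_len, num_heads, head_size) slices into _naive_masked_attention instead of adding a batch dimension and reshaping the result. For reference, here is a standalone sketch of the causal attention that the diff's einsum computes; the function below is written for illustration rather than copied from vLLM, and the final weighted-sum einsum is an assumed completion since the hunk ends at the softmax.

# Standalone sketch of naive causal attention over inputs shaped
# (seq_len, num_heads, head_size); illustrative only, not vLLM's exact code.
import torch

def naive_masked_attention(query: torch.Tensor,
                           key: torch.Tensor,
                           value: torch.Tensor,
                           scale: float) -> torch.Tensor:
    seq_len = query.shape[0]
    # Strictly upper-triangular mask scaled to the most negative representable
    # value, so future positions receive ~zero weight after softmax.
    attn_mask = torch.triu(torch.ones(seq_len, seq_len,
                                      dtype=query.dtype,
                                      device=query.device),
                           diagonal=1)
    attn_mask = attn_mask * torch.finfo(query.dtype).min
    # "qhd,khd->hqk": per-head (h) score matrix between query (q) and key (k) positions.
    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
    attn_weights = attn_weights + attn_mask.float()
    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    # Assumed completion: weighted sum over key positions, back to (seq_len, num_heads, head_size).
    return torch.einsum("hqk,khd->qhd", attn_weights, value)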
