
Commit c61a56b

tjtanaa authored and lulmer committed
[FEAT] [ROCm] [Embedding] Add encoder-only model support into ROCm Flash Attention to enable embedding models. (vllm-project#14664)
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 0494a75 commit c61a56b
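
For context before the diff: this commit teaches the ROCm flash-attention backend to handle AttentionType.ENCODER_ONLY (bidirectional attention, no KV cache), which is what BERT-style embedding and classification models need. Below is a minimal sketch of how the feature would be exercised from the public vLLM API; the model name, dtype, and the VLLM_USE_TRITON_FLASH_ATTN toggle are taken from the tests in this commit, while the LLM(..., task="embed") / llm.embed(...) spelling is an assumption about the surrounding vLLM version rather than something this diff touches.

import os

# The tests in this commit fall back to the ROCm CK FA backend for models that
# need sliding-window attention; plain encoder-only models do not require it.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "False"

from vllm import LLM  # noqa: E402

# Encoder-only model taken from the test matrix below; "half" matches the dtype
# the ROCm tests use. task="embed" is the assumed pooling-mode switch.
llm = LLM(model="intfloat/multilingual-e5-small", task="embed", dtype="half")
outputs = llm.embed(["vLLM can now embed on ROCm."])
print(len(outputs[0].outputs.embedding))  # embedding dimension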

File tree

7 files changed (+118, -50 lines)


CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -561,6 +561,10 @@ set(VLLM_MOE_EXT_SRC
     "csrc/moe/moe_align_sum_kernels.cu"
     "csrc/moe/topk_softmax_kernels.cu")
 
+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu")
+endif()
+
 set_gencode_flags_for_srcs(
   SRCS "${VLLM_MOE_EXT_SRC}"
   CUDA_ARCHS "${CUDA_ARCHS}")

csrc/moe/torch_bindings.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "int moe_block_size, bool replicate_input, bool apply_weights)"
       " -> Tensor");
   // conditionally compiled so impl registration is in source file
+
 #endif
 }
 

tests/models/embedding/language/test_cls_models.py

Lines changed: 15 additions & 2 deletions
@@ -7,6 +7,8 @@
 import torch
 from transformers import AutoModelForSequenceClassification
 
+from vllm.platforms import current_platform
+
 
 @pytest.mark.parametrize(
     "model",
@@ -15,14 +17,21 @@
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
     ],
 )
-@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("dtype",
+                         ["half"] if current_platform.is_rocm() else ["float"])
 def test_classification_models(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     dtype: str,
+    monkeypatch,
 ) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)
 
@@ -43,4 +52,8 @@ def print_model(model):
         hf_output = torch.tensor(hf_output)
         vllm_output = torch.tensor(vllm_output)
 
-        assert torch.allclose(hf_output, vllm_output, 1e-3)
+        # the tolerance value of 1e-2 is selected based on the
+        # half datatype tests in
+        # tests/models/embedding/language/test_embedding.py
+        assert torch.allclose(hf_output, vllm_output,
+                              1e-3 if dtype == "float" else 1e-2)
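
One note on the assertion above: the third positional argument of torch.allclose is rtol, so the change loosens the relative tolerance from 1e-3 to 1e-2 for the half-precision ROCm runs. A self-contained illustration with made-up numbers:

import torch

# torch.allclose(input, other, rtol=1e-05, atol=1e-08): the third positional
# argument is rtol, the relative tolerance the test above relaxes for fp16.
hf_output = torch.tensor([0.5000, -1.2000])
vllm_output = torch.tensor([0.5030, -1.2070])  # ~0.6% relative error

print(torch.allclose(hf_output, vllm_output, 1e-3))  # False: too strict for half
print(torch.allclose(hf_output, vllm_output, 1e-2))  # True: the relaxed tolerance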

tests/models/embedding/language/test_embedding.py

Lines changed: 11 additions & 2 deletions
@@ -6,6 +6,7 @@
 import pytest
 
 from vllm.config import PoolerConfig
+from vllm.platforms import current_platform
 
 from ..utils import check_embeddings_close
 
@@ -18,15 +19,15 @@
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
         pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
         pytest.param("intfloat/multilingual-e5-small"),
+        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
         # [Decoder-only]
         pytest.param("BAAI/bge-multilingual-gemma2",
                      marks=[pytest.mark.core_model]),
         pytest.param("intfloat/e5-mistral-7b-instruct",
                      marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
         pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
-        pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
         pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
-        # [Encoder-decoder]
+        # [Cross-Encoder]
         pytest.param("sentence-transformers/stsb-roberta-base-v2"),
     ],
 )
@@ -37,11 +38,19 @@ def test_models(
     example_prompts,
     model,
     dtype: str,
+    monkeypatch,
 ) -> None:
+
+    if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
     vllm_extra_kwargs = {}
     if model == "ssmits/Qwen2-7B-Instruct-embed-base":
         vllm_extra_kwargs["override_pooler_config"] = \
             PoolerConfig(pooling_type="MEAN")
+
     if model == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
         vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
 
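
As a reminder of what the check_embeddings_close helper imported above is checking conceptually: embedding backends are usually compared by cosine similarity rather than exact elementwise values. A rough standalone sketch of that idea (an illustration only, not the repository's actual helper):

import torch
import torch.nn.functional as F


def embeddings_close(ref_embeds, test_embeds, tol=1e-2):
    # Two backends agree if each pair of vectors points in (almost)
    # the same direction, i.e. 1 - cosine_similarity is tiny.
    for ref, test in zip(ref_embeds, test_embeds):
        sim = F.cosine_similarity(torch.tensor(ref), torch.tensor(test), dim=0)
        if (1.0 - sim).item() > tol:
            return False
    return True


# e.g. embeddings_close(hf_outputs, vllm_outputs)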

tests/models/embedding/language/test_gritlm.py

Lines changed: 2 additions & 2 deletions
@@ -15,8 +15,8 @@
 from ....utils import RemoteOpenAIServer
 
 # GritLM embedding implementation is only supported by XFormers backend.
-pytest.mark.skipif(not importlib.util.find_spec("xformers"),
-                   reason="GritLM requires XFormers")
+pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
+                                reason="GritLM requires XFormers")
 
 MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
 MAX_MODEL_LEN = 4000
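
The GritLM change above is a real fix, not a style tweak: a bare pytest.mark.skipif(...) expression at module scope just builds a mark object and discards it, so the skip never applied; assigning it to the module-level pytestmark variable is what tells pytest to attach the mark to every test in the file. A minimal sketch of the working pattern (hypothetical test module):

import importlib.util

import pytest

# No effect: the mark object is created and immediately thrown away.
# pytest.mark.skipif(not importlib.util.find_spec("xformers"),
#                    reason="requires xformers")

# Applied by pytest to every test collected from this module.
pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"),
                                reason="requires xformers")


def test_xformers_is_available():
    assert importlib.util.find_spec("xformers") is not None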

tests/models/embedding/vision_language/test_llava_next.py

Lines changed: 17 additions & 0 deletions
@@ -4,10 +4,27 @@
 import torch.nn.functional as F
 from transformers import AutoModelForVision2Seq
 
+from vllm.platforms import current_platform
+
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
 from ....utils import large_gpu_test
 from ..utils import check_embeddings_close
 
+# Llava Next embedding implementation is only supported by CUDA.
+# If run on ROCm, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CUDA tensor
+# requires compiling PyTorch with MAGMA. Please use PyTorch
+# built with MAGMA support.
+# If run on CPU, hf_model.model.resize_token_embeddings will
+# cause the following error:
+# RuntimeError: Calling torch.linalg.cholesky on a CPU tensor
+# requires compiling PyTorch with LAPACK. Please use PyTorch
+# built with LAPACK support.
+pytestmark = pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="Llava Next model uses op that is only supported in CUDA")
+
 llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501
 
 HF_TEXT_PROMPTS = [

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 68 additions & 44 deletions
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Attention layer ROCm GPUs."""
+import itertools
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
 
@@ -342,28 +343,27 @@ def _get_seq_len_block_table_args(
     Decoder attn -> select entirely decoder self-attention-related fields
     Encoder/decoder cross-attn -> select encoder sequence lengths
     Encoder attn -> select encoder sequence lengths fields
+    Encoder-only attn -> select prefill sequence lengths with
+                         bidirectional attention
 
     Arguments:
 
     * attn_metadata: Attention metadata structure associated with attention op
     * attn_type: encoder attention, decoder self-attention,
-                 encoder/decoder cross-attention
+                 encoder/decoder cross-attention, encoder-only
 
     Returns:
 
     * Appropriate sequence-lengths tensors for query and key
    * Appropriate max sequence-length scalar
+    * Causal masking flag
     '''
 
-    partial_prefix_sum = 0
     if attn_type == AttentionType.ENCODER:
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.encoder_seq_lens_tensor.device,
             dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
         causal_mask = False
@@ -372,16 +372,29 @@ def _get_seq_len_block_table_args(
         return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 query_seq_start_loc, attn_metadata.max_encoder_seq_len,
                 attn_metadata.encoder_seq_lens, causal_mask)
+
+    elif attn_type == AttentionType.ENCODER_ONLY:
+        # For encoder-only models, we use the prefill sequence lengths
+        assert attn_metadata.seq_lens is not None
+        assert attn_metadata.seq_lens_tensor is not None
+        query_seq_start_loc = torch.tensor(
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
+            device=attn_metadata.seq_lens_tensor.device,
+            dtype=attn_metadata.seq_lens_tensor.dtype)
+        max_seq_len = attn_metadata.max_prefill_seq_len
+        # Encoder-only models typically use bidirectional attention
+        causal_mask = False
+
+        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
+                max_seq_len, attn_metadata.seq_lens, causal_mask)
+
     elif attn_type == AttentionType.DECODER:
         # Decoder self-attention
         # Choose max_seq_len based on whether we are in prompt_run
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         query_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         max_seq_len = attn_metadata.max_prefill_seq_len
@@ -393,21 +406,14 @@
         assert attn_metadata.seq_lens is not None
         assert attn_metadata.encoder_seq_lens_tensor is not None
         query_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.seq_lens)),
            device=attn_metadata.encoder_seq_lens_tensor.device,
            dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
 
-        partial_prefix_sum = 0
         assert attn_metadata.encoder_seq_lens is not None
         assert attn_metadata.seq_lens_tensor is not None
         key_seq_start_loc = torch.tensor(
-            [0] + [
-                partial_prefix_sum := partial_prefix_sum + i
-                for i in attn_metadata.encoder_seq_lens
-            ],
+            list(itertools.accumulate([0] + attn_metadata.encoder_seq_lens)),
             device=attn_metadata.seq_lens_tensor.device,
             dtype=attn_metadata.seq_lens_tensor.dtype)
         causal_mask = False
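
The rewrite repeated in the hunks above swaps a walrus-operator running sum for itertools.accumulate; both build the cumulative start offsets (a prefix sum with a leading zero) that the varlen attention kernels consume as *_seq_start_loc. A quick standalone check of the equivalence with made-up sequence lengths:

import itertools

seq_lens = [3, 5, 2]  # made-up per-sequence token counts

# Old form (removed above): running sum via the walrus operator.
partial_prefix_sum = 0
old = [0] + [partial_prefix_sum := partial_prefix_sum + i for i in seq_lens]

# New form: itertools.accumulate yields the same prefix sums directly.
new = list(itertools.accumulate([0] + seq_lens))

assert old == new == [0, 3, 8, 10]  # start offset of each sequence, then the total
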
@@ -584,6 +590,8 @@ def forward(
               will match encoder sequence lengths, pass encoder sequence
               attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
               max_encoder_seq_len)
+            * ENCODER_ONLY: bidirectional attention with no KV caching;
+              use prefill sequence attributes
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
@@ -608,7 +616,11 @@ def forward(
         else:
             assert value is None
 
-        if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0:
+        # Only update KV cache for decoder self-attention
+        # and encoder-decoder cross-attention
+        if self.attn_type not in [
+                AttentionType.ENCODER, AttentionType.ENCODER_ONLY
+        ] and kv_cache.numel() > 0:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
@@ -632,6 +644,9 @@ def forward(
 
         if self.attn_type != AttentionType.ENCODER:
             num_prefill_tokens = attn_metadata.num_prefill_tokens
+        elif self.attn_type == AttentionType.ENCODER_ONLY:
+            # For encoder-only models, all tokens are processed in one go
+            num_prefill_tokens = query.shape[0]
         else:
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
@@ -642,8 +657,13 @@ def forward(
         # QKV for prefill.
         query = query[:num_prefill_tokens]
 
+        # For encoder-only and encoder models,
+        # we process all tokens at once
+        # For decoder and encoder-decoder,
+        # we may need to limit key/value to prefill tokens
         if key is not None and value is not None \
-            and self.attn_type != AttentionType.ENCODER_DECODER:
+            and self.attn_type not in [AttentionType.ENCODER_DECODER,
+                                       AttentionType.ENCODER_ONLY]:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]
 
@@ -678,7 +698,7 @@ def forward(
                         self.alibi_slopes,
                         query.dtype,
                         seq_lens,
-                        make_attn_mask=False)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                 out, _ = self.attn_func(
                     query,
                     key,
@@ -703,7 +723,7 @@ def forward(
                         self.alibi_slopes,
                         query.dtype,
                         attn_metadata.seq_lens,
-                        make_attn_mask=True)  # type: ignore
+                        make_attn_mask=causal_mask)  # type: ignore
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
@@ -729,7 +749,7 @@ def forward(
                     max_seqlen_q=prefill_meta.max_prefill_seq_len,
                     max_seqlen_k=key_max_seq_len,
                     softmax_scale=self.scale,
-                    causal=True,
+                    causal=causal_mask,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
                     softcap=self.logits_soft_cap,
@@ -742,25 +762,29 @@ def forward(
                 else:
                     output = out
             else:
-                # prefix-enabled attention
-                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    self.kv_cache_dtype,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.query_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                    self.sliding_window[0],
-                    layer._k_scale,
-                    layer._v_scale,
-                )
-
-        if decode_meta := attn_metadata.decode_metadata:
+                # prefix-enabled attention -
+                # not applicable for encoder-only models
+                if self.attn_type != AttentionType.ENCODER_ONLY:
+                    output[:
+                           num_prefill_tokens] = PagedAttention.forward_prefix(
+                               query,
+                               key,
+                               value,
+                               self.kv_cache_dtype,
+                               key_cache,
+                               value_cache,
+                               prefill_meta.block_tables,
+                               prefill_meta.query_start_loc,
+                               prefill_meta.seq_lens_tensor,
+                               prefill_meta.max_query_len,
+                               self.alibi_slopes,
+                               self.sliding_window[0],
+                               layer._k_scale,
+                               layer._v_scale,
+                           )
+        # Skip decode phase for encoder-only models
+        if (decode_meta := attn_metadata.decode_metadata) and (
+                self.attn_type != AttentionType.ENCODER_ONLY):
             # Decoding run.
             # Whether to use rocm custom paged attention or not
             num_seqs, num_heads, head_size = decode_query.shape
@@ -885,4 +909,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
-            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
+            and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
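
Finally, the semantic core of the forward() changes: for ENCODER_ONLY the backend passes causal_mask=False down to the prefill kernels and skips the KV-cache update, prefix-attention and decode paths altogether. A small PyTorch sketch (a stand-in using scaled_dot_product_attention, not the ROCm kernels themselves) of what flipping the causal flag changes:

import torch
import torch.nn.functional as F

# Toy single-head attention over 4 tokens; shape (batch, heads, tokens, head_dim).
torch.manual_seed(0)
q, k, v = (torch.randn(1, 1, 4, 8) for _ in range(3))

# Decoder-style: each token attends only to itself and earlier tokens.
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Encoder-only (BERT-style): every token attends to the whole sequence,
# which is what causal_mask=False selects in the backend above.
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# The last token sees the full sequence either way, so it matches (up to
# numerical noise); earlier tokens gain future context only without the mask.
print(torch.allclose(causal[0, 0, -1], bidirectional[0, 0, -1], atol=1e-6))  # True
print(torch.allclose(causal[0, 0, 0], bidirectional[0, 0, 0], atol=1e-6))    # almost surely False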
