
Commit 994c667

LucasWilkinson authored and erictang000 committed
[V1] Enable V1 Fp8 cache for FA3 in the oracle (vllm-project#15191)
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 9d3d0c1 commit 994c667

File tree

9 files changed: +45 -23 lines changed
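This commit lets the V1 engine keep an fp8 KV cache when the FlashAttention 3 backend would be selected (FA3 on a compute-capability 9.x GPU), instead of always falling back in the oracle. A hedged usage sketch, assuming the standard vllm.LLM entry point; the model name is illustrative and not part of this change:

# Hedged usage sketch (not part of this commit): exercise the newly enabled
# V1 + fp8 KV cache path on a Hopper GPU where FlashAttention 3 is available.
import os

os.environ["VLLM_USE_V1"] = "1"  # make the V1 path explicit

from vllm import LLM  # noqa: E402

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model name
    kv_cache_dtype="fp8",  # previously forced a fallback out of V1
)
print(llm.generate("Hello, world!")[0].outputs[0].text)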

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
 /vllm/_version.py
 
 # vllm-flash-attn built from source
-vllm/vllm_flash_attn/
+vllm/vllm_flash_attn/*
+!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

vllm/attention/backends/flash_attn.py

Lines changed: 12 additions & 4 deletions
@@ -22,12 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -632,10 +633,13 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
            requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
@@ -704,6 +708,10 @@ def forward(
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
        if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
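For clarity, here is a standalone sketch of the gate this hunk introduces, with the device check inlined as parameters; the helper names below are illustrative stand-ins, not vLLM APIs:

# Standalone sketch (illustrative, not vLLM code) of the fp8 kv-cache gate above.
def fa_supports_fp8(fa_version: int, cc_major: int) -> bool:
    # Mirrors the helper added in vllm/vllm_flash_attn/fa_utils.py:
    # FA3 plus a compute-capability 9.x (Hopper) device.
    return fa_version == 3 and cc_major == 9

def check_kv_cache_dtype(kv_cache_dtype: str, fa_version: int, cc_major: int) -> None:
    quantized = kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2")
    if quantized and (not kv_cache_dtype.startswith("fp8")
                      or not fa_supports_fp8(fa_version, cc_major)):
        raise NotImplementedError(
            f"FlashAttention does not support {kv_cache_dtype} kv-cache here.")

check_kv_cache_dtype("fp8_e4m3", fa_version=3, cc_major=9)  # passes: FA3 on Hopper
# check_kv_cache_dtype("fp8", fa_version=2, cc_major=8)     # would raise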

vllm/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,6 @@
     compute_slot_mapping_start_idx,
     is_block_tables_empty)
 from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
-from vllm.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,6 +213,7 @@
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

vllm/config.py

Lines changed: 0 additions & 4 deletions
@@ -1157,10 +1157,6 @@ def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
-            if envs.VLLM_USE_V1:
-                raise NotImplementedError(
-                    "V1 does not yet support fp8 KV cache. "
-                    "Set VLLM_USE_V1=0 to enable fp8 kv cache.")
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "

vllm/engine/arg_utils.py

Lines changed: 14 additions & 3 deletions
@@ -1562,9 +1562,20 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
 
         # No Fp8 KV cache so far.
         if self.kv_cache_dtype != "auto":
-            _raise_or_fallback(feature_name="--kv-cache-dtype",
-                               recommend_to_remove=False)
-            return False
+            fp8_attention = self.kv_cache_dtype.startswith("fp8")
+            will_use_fa = (
+                current_platform.is_cuda()
+                and not envs.is_set("VLLM_ATTENTION_BACKEND")
+            ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+            supported = False
+            if fp8_attention and will_use_fa:
+                from vllm.vllm_flash_attn.fa_utils import (
+                    flash_attn_supports_fp8)
+                supported = flash_attn_supports_fp8()
+            if not supported:
+                _raise_or_fallback(feature_name="--kv-cache-dtype",
+                                   recommend_to_remove=False)
+                return False
 
         # No Prompt Adapter so far.
         if self.enable_prompt_adapter:
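The oracle change can be summarized as: keep a non-"auto" --kv-cache-dtype on V1 only when it is an fp8 variant, FlashAttention would be the V1 backend, and FlashAttention reports fp8 support. A standalone sketch of that decision with the inputs passed explicitly (illustrative only, not the vLLM API):

# Illustrative sketch of the V1 oracle decision above, not vLLM code.
from typing import Optional

def v1_allows_kv_cache_dtype(kv_cache_dtype: str,
                             is_cuda: bool,
                             attention_backend: Optional[str],
                             fa_supports_fp8: bool) -> bool:
    if kv_cache_dtype == "auto":
        return True
    fp8_attention = kv_cache_dtype.startswith("fp8")
    will_use_fa = ((is_cuda and attention_backend is None)
                   or attention_backend == "FLASH_ATTN_VLLM_V1")
    return fp8_attention and will_use_fa and fa_supports_fp8

# fp8 cache stays on V1 when FA3 on an SM90 GPU would be picked:
assert v1_allows_kv_cache_dtype("fp8", True, None, True)
# otherwise the real code calls _raise_or_fallback and returns False:
assert not v1_allows_kv_cache_dtype("fp8", True, "FLASHINFER", False)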

vllm/platforms/cuda.py

Lines changed: 3 additions & 5 deletions
@@ -14,7 +14,6 @@
 # import custom ops, trigger op registration
 import vllm._C  # noqa
 import vllm.envs as envs
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.utils import import_pynvml
 
@@ -258,7 +257,7 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             try:
                 import vllm.vllm_flash_attn  # noqa: F401
                 from vllm.attention.backends.flash_attn import (  # noqa: F401
-                    FlashAttentionBackend)
+                    FlashAttentionBackend, flash_attn_supports_fp8)
 
                 supported_sizes = \
                     FlashAttentionBackend.get_supported_head_sizes()
@@ -269,10 +268,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                     target_backend = _Backend.XFORMERS
                 fp8_kv_cache = (kv_cache_dtype is not None
                                 and kv_cache_dtype.startswith("fp8"))
-                if (fp8_kv_cache and get_flash_attn_version() != 3):
+                if (fp8_kv_cache and not flash_attn_supports_fp8()):
                     logger.info(
-                        "Cannot use FlashAttention-2 backend for FP8 KV cache."
-                    )
+                        "Cannot use FlashAttention backend for FP8 KV cache.")
                     logger.warning(
                         "Please use FlashInfer backend with FP8 KV Cache for "
                         "better performance by setting environment variable "

vllm/v1/attention/backends/flash_attn.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
AttentionMetadata, AttentionType,
1212
is_quantized_kv_cache)
1313
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
14-
from vllm.fa_utils import get_flash_attn_version
1514
from vllm.logger import init_logger
1615
from vllm.platforms import current_platform
1716
from vllm.utils import cdiv
17+
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
18+
get_flash_attn_version)
1819

1920
if TYPE_CHECKING:
2021
from vllm.v1.core.sched.output import SchedulerOutput
@@ -182,9 +183,6 @@ def __init__(
182183
else:
183184
self.sliding_window = (sliding_window - 1, 0)
184185
self.kv_cache_dtype = kv_cache_dtype
185-
if is_quantized_kv_cache(self.kv_cache_dtype):
186-
raise NotImplementedError(
187-
"FlashAttention V1 with FP8 KV cache not yet supported")
188186
if logits_soft_cap is None:
189187
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
190188
logits_soft_cap = 0
@@ -206,6 +204,10 @@ def __init__(
206204
"are not implemented for "
207205
"FlashAttentionImpl")
208206
self.vllm_flash_attn_version = get_flash_attn_version()
207+
if is_quantized_kv_cache(self.kv_cache_dtype) \
208+
and not flash_attn_supports_fp8():
209+
raise NotImplementedError(
210+
"FlashAttention does not support fp8 kv-cache on this device.")
209211

210212
def forward(
211213
self,

vllm/v1/attention/backends/mla/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,14 @@
196196
AttentionMetadata,
197197
MLAAttentionImpl)
198198
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
199-
from vllm.fa_utils import get_flash_attn_version
200199
from vllm.logger import init_logger
201200
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
202201
LinearBase, RowParallelLinear,
203202
UnquantizedLinearMethod)
204203
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
205204
from vllm.platforms import current_platform
206205
from vllm.utils import cdiv, round_down
206+
from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
207207

208208
try:
209209
from vllm.vllm_flash_attn import flash_attn_varlen_func

vllm/fa_utils.py renamed to vllm/vllm_flash_attn/fa_utils.py

Lines changed: 6 additions & 0 deletions
@@ -46,3 +46,9 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
         return fa_version
     except (ImportError, AssertionError):
         return None
+
+
+def flash_attn_supports_fp8() -> bool:
+    from vllm.platforms import current_platform
+    return get_flash_attn_version() == 3 and \
+        current_platform.get_device_capability().major == 9
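A short usage sketch of the new helper (assuming vllm and its bundled vllm_flash_attn are importable); it returns True only for FA3 on a compute-capability 9.x device:

# Hedged usage sketch of the helper added above.
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

if flash_attn_supports_fp8():
    # FA3 on a compute-capability 9.x (Hopper) GPU: fp8 kv-cache is usable.
    kv_cache_dtype = "fp8"
else:
    # FA2, non-Hopper device, or FA unavailable: keep the default behavior.
    kv_cache_dtype = "auto"

print(f"FA version: {get_flash_attn_version()}, kv_cache_dtype: {kv_cache_dtype}")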

0 commit comments