
Commit c4d5755

mgoin authored and choprahetarth committed

[CI Failure] Fix fp8 kv cache on <SM90 (vllm-project#25396)
Signed-off-by: mgoin <[email protected]>

1 parent 3af5c27 · commit c4d5755

File tree: 1 file changed (+6, -2)

vllm/platforms/cuda.py: 6 additions & 2 deletions
@@ -286,6 +286,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
         TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend"  # noqa: E501
         XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend"  # noqa: E501

+        use_fp8_kv_cache = (kv_cache_dtype is not None
+                            and kv_cache_dtype.startswith("fp8"))
+
         if selected_backend == _Backend.FLASHINFER:
             logger.info_once("Using FlashInfer backend on V1 engine.")
             if cls.has_device_capability(100):
@@ -334,10 +337,11 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,

         # FlashAttention is the default for SM 8.0+ GPUs
         if cls.has_device_capability(80):
-            if has_sink and not cls.is_device_capability(90):
+            if (has_sink or
+                    use_fp8_kv_cache) and not cls.is_device_capability(90):
                 logger.info_once("Using Triton backend on V1 engine.")
                 return TRITON_ATTN_VLLM_V1
-            if is_default_backend_supported := is_attn_backend_supported(
+            elif is_default_backend_supported := is_attn_backend_supported(
                     FLASH_ATTN_V1, head_size, dtype,
                     allow_import_error=False):
                 logger.info_once("Using Flash Attention backend on "
