26 changes: 13 additions & 13 deletions vllm/platforms/cuda.py
@@ -456,6 +456,19 @@ def stateless_init_device_torch_dist_pg(
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
+        fp8_attention = kv_cache_dtype.startswith("fp8")
+        will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
+                       ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
+        supported = False
+        if cls.is_device_capability(100):
+            supported = True
+        elif fp8_attention and will_use_fa:
+            from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
+            supported = flash_attn_supports_fp8()
+        return supported
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
@@ -583,19 +596,6 @@ def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
" not found. Assuming no NVLink available.")
return False

@classmethod
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
fp8_attention = kv_cache_dtype.startswith("fp8")
will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND")
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if cls.is_device_capability(100):
supported = True
elif fp8_attention and will_use_fa:
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
supported = flash_attn_supports_fp8()
return supported


# Autodetect either NVML-enabled or non-NVML platform
# based on whether NVML is available.
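For context, a minimal sketch of how the relocated classmethod can be queried once it sits on the shared CUDA platform base rather than only on the post-"# NVML utils" class. Using current_platform as the entry point and the sample dtype strings below are assumptions for illustration, not part of this diff.

# Minimal usage sketch (assumes a CUDA machine where
# vllm.platforms.current_platform resolves to the CUDA platform class
# that now defines is_kv_cache_dtype_supported).
from vllm.platforms import current_platform

for kv_cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
    # Per the added method: True on compute capability 10.0 devices, or,
    # for these fp8 dtypes, when the FLASH_ATTN_VLLM_V1 backend will be
    # used and flash_attn_supports_fp8() reports True.
    ok = current_platform.is_kv_cache_dtype_supported(kv_cache_dtype)
    print(f"kv_cache_dtype={kv_cache_dtype!r}: supported={ok}")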