@@ -22,12 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
-from vllm.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
+from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
+                                            get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
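
The helper imports above now come from `vllm.vllm_flash_attn.fa_utils`, which also exposes the new `flash_attn_supports_fp8()` probe used in the hunks below. Its implementation is not part of this diff; the sketch below is only an assumption of what such a capability check could look like (an FA3 build on an fp8-capable GPU), not the actual vLLM code.

```python
# Hedged sketch only: the real flash_attn_supports_fp8() lives in
# vllm/vllm_flash_attn/fa_utils.py and is not shown in this diff.
# Assumed behaviour: FP8 attention needs a FlashAttention 3 build
# running on an fp8-capable (sm90+) GPU.
from typing import Optional

def flash_attn_supports_fp8_sketch(fa_version: Optional[int],
                                   device_capability: tuple) -> bool:
    """Return True when FP8 kv-cache attention is expected to work."""
    return fa_version == 3 and device_capability >= (9, 0)

print(flash_attn_supports_fp8_sketch(3, (9, 0)))   # True
print(flash_attn_supports_fp8_sketch(2, (8, 0)))   # False
```
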
@@ -632,10 +633,13 @@ def __init__(
         self.kv_cache_dtype = kv_cache_dtype
         self.vllm_flash_attn_version = get_flash_attn_version(
             requires_alibi=self.alibi_slopes is not None)
-        if (is_quantized_kv_cache(self.kv_cache_dtype)
-                and self.vllm_flash_attn_version != 3):
+        if is_quantized_kv_cache(self.kv_cache_dtype) and (
+                not self.kv_cache_dtype.startswith("fp8")
+                or not flash_attn_supports_fp8()):
             raise NotImplementedError(
-                "Only FlashAttention3 supports FP8 KV cache")
+                f"FlashAttention does not support {self.kv_cache_dtype} "
+                "kv-cache on this device "
+                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
             logits_soft_cap = 0
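
The constructor check no longer keys off the FlashAttention version number directly; it rejects any quantized kv-cache that is either not an fp8 variant or not supported by the current FlashAttention build and device. A minimal, self-contained sketch of that predicate, with `is_quantized_kv_cache` and `flash_attn_supports_fp8` stubbed out purely for illustration:

```python
# Minimal sketch of the new constructor-time guard. is_quantized_kv_cache and
# flash_attn_supports_fp8 are stand-in stubs for illustration only.
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype != "auto"

def flash_attn_supports_fp8() -> bool:
    return False  # pretend this is an FA2-only build / non-Hopper GPU

def check_kv_cache_dtype(kv_cache_dtype: str) -> None:
    if is_quantized_kv_cache(kv_cache_dtype) and (
            not kv_cache_dtype.startswith("fp8")
            or not flash_attn_supports_fp8()):
        raise NotImplementedError(
            f"FlashAttention does not support {kv_cache_dtype} "
            "kv-cache on this device "
            f"(FA supports fp8 = {flash_attn_supports_fp8()}).")

check_kv_cache_dtype("auto")          # unquantized cache: accepted
try:
    check_kv_cache_dtype("fp8_e4m3")  # fp8 cache on an unsupported build: rejected
except NotImplementedError as exc:
    print(exc)
```
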
@@ -704,6 +708,10 @@ def forward(
         logits_soft_cap: Optional[float] = self.logits_soft_cap
         fp8_attention = kv_cache_dtype.startswith("fp8")
 
+        if fp8_attention and not flash_attn_supports_fp8():
+            raise NotImplementedError(
+                "FlashAttention does not support FP8 kv-cache on this device.")
+
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
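
`forward()` now repeats the capability check at call time, so an fp8 cache on an unsupported build fails with a clear error even if the constructor check is bypassed. For reference, `fp8_attention` is derived from the kv-cache dtype string; the concrete dtype names below are assumed for the sake of the example rather than taken from this diff:

```python
# Illustration only: which kv-cache dtype strings turn fp8_attention on.
# The exact set of accepted values comes from vLLM's config layer; the
# names below are assumptions.
for kv_cache_dtype in ("auto", "fp8", "fp8_e4m3", "fp8_e5m2"):
    fp8_attention = kv_cache_dtype.startswith("fp8")
    print(f"{kv_cache_dtype:>8} -> fp8_attention={fp8_attention}")
```
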
|