Skip to content

Commit c765f0b

Browse files
authored
[FlashInfer] Avoid FlashInfer block_size 16 + head_size 256 on blackwell (#27994)
Signed-off-by: Chen Zhang <[email protected]>
1 parent 002b07c commit c765f0b

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

vllm/model_executor/models/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import vllm.envs as envs
88
from vllm.logger import init_logger
99
from vllm.model_executor.models import ModelRegistry
10+
from vllm.platforms import current_platform
1011
from vllm.utils.math_utils import cdiv, round_up
1112
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
1213
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
@@ -356,6 +357,17 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
356357
).page_size_bytes
357358
else:
358359
kernel_block_alignment_size = 16
360+
if (
361+
current_platform.is_device_capability(100)
362+
and model_config.get_head_size() == 256
363+
and (
364+
envs.VLLM_ATTENTION_BACKEND is None
365+
or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
366+
)
367+
):
368+
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
369+
# head size 256 and block size 16 is not supported on blackwell.
370+
kernel_block_alignment_size = 32
359371
attn_page_size_1_token = FullAttentionSpec(
360372
block_size=1,
361373
num_kv_heads=model_config.get_num_kv_heads(parallel_config),

vllm/v1/attention/backends/flashinfer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,15 @@ def __init__(
402402
)
403403
self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
404404

405+
if self.head_dim == 256 and current_platform.is_device_capability(100):
406+
# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
407+
# head size 256 and block size 16 is not supported on blackwell.
408+
assert kv_cache_spec.block_size != 16, (
409+
"There is a bug in FlashInfer "
410+
"block_size 16 head size 256 support. Please avoid this combination by "
411+
"passing --block-size 32 or --block-size 64."
412+
)
413+
405414
def _get_workspace_buffer(self):
406415
if self._workspace_buffer is None:
407416
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE

0 commit comments

Comments
 (0)