[Hybrid]: Decouple Kernel Block Size from KV Page Size #24486
@@ -374,12 +374,22 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             block_size=model_config.max_model_len,
         ).page_size_bytes

-        # some attention backends (e.g. FA) only support setting
-        # block size to multiple of 16, so let's suggest a value
-        # that would work (note: FA is currently not compatible
-        # with mamba layers, use FlashInfer instead).
-        attn_block_size = 16 * cdiv(mamba_page_size,
-                                    16 * attn_page_size_1_token)
+        # Attention backend constraints:
+        # - FlashAttention (FA) requires block size to be multiple of 16
+        # - MLA (Multi-head Latent Attention) requires larger alignment:
+        #   * CUTLASS_MLA backend: 128-byte alignment
+        #   * Other MLA backends: 64-byte alignment
+        if model_config.use_mla:
+            use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
+            block_alignment_bytes = 128 if use_cutlass_mla else 64
+        else:
+            block_alignment_bytes = 16
+
+        # Calculate minimum attention block size that satisfies both:
+        # 1. Backend alignment requirements (block_alignment_bytes)
+        # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
+        attn_block_size = block_alignment_bytes * cdiv(
+            mamba_page_size, block_alignment_bytes * attn_page_size_1_token)

         # override attention block size if either (a) the
         # user has not set it or (b) the user has set it
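For intuition, here is a small self-contained sketch of the block-size calculation above. The numeric values are made up purely for illustration, and `cdiv` is reimplemented locally so the snippet runs on its own.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, matching the cdiv used in the hunk above."""
    return -(-a // b)

# Illustrative numbers only (not taken from any real model):
mamba_page_size = 131_072        # bytes needed to hold one mamba state page
attn_page_size_1_token = 576     # bytes of attention KV cache per token
block_alignment_bytes = 16       # FA case; the MLA backends use 64 or 128

# Smallest block size that is a multiple of the alignment and whose
# per-block attention page is at least as large as the mamba page.
attn_block_size = block_alignment_bytes * cdiv(
    mamba_page_size, block_alignment_bytes * attn_page_size_1_token)

assert attn_block_size % block_alignment_bytes == 0
assert attn_block_size * attn_page_size_1_token >= mamba_page_size
print(attn_block_size)  # -> 240 with the numbers above
```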
@@ -138,7 +138,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
-        if model_config is not None and model_config.use_mla:
+        # Note: block_size is initialized in
+        # HybridAttentionMambaModelConfig.verify_and_update_config
+        # and doesn't need to be reinitialized here
+        if model_config is not None and model_config.use_mla \
+                and cache_config.block_size is not None:
             # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
             # else we default to CutlassMLA. For each case, we force the

Review comment (Member) on lines 121 to 124: This statement is true for hybrid models only, right?
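The branch above guards backend-defaulting logic that is truncated in this hunk. As a rough sketch of the intent described in the context comment (FlashMLA on non-Blackwell GPUs, CUTLASS MLA on Blackwell, each forcing its own block size), with a hypothetical `is_blackwell` flag standing in for vLLM's real device-capability check and illustrative backend-name strings:

```python
import os


def pick_default_mla_backend(is_blackwell: bool) -> tuple[str, int]:
    """Hypothetical sketch only, not the actual vLLM implementation."""
    if os.environ.get("VLLM_ATTENTION_BACKEND"):
        # The user already picked a backend explicitly; nothing to default.
        raise RuntimeError("backend already chosen by the user")
    if is_blackwell:
        return "CUTLASS_MLA", 128   # CUTLASS MLA kernels want 128-token blocks
    return "FLASHMLA", 64           # FlashMLA kernels want 64-token blocks
```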
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import torch

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
+                                              MultipleOf,
                                               is_quantized_kv_cache)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states

@@ -49,6 +50,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()

Review comment (Collaborator) on get_supported_block_size: Technically FA3 would support …
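`MultipleOf(16)` declares a whole family of valid block sizes rather than an explicit list. The sketch below shows how such a declaration might be consumed; both `MultipleOf` (reimplemented here as a stand-in) and the `supports_block_size` helper are illustrative, not part of this PR.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class MultipleOf:
    """Stand-in for the MultipleOf marker imported above: any block size
    that is a multiple of `base` is supported."""
    base: int


def supports_block_size(supported: list[Union[int, MultipleOf]],
                        block_size: int) -> bool:
    """Hypothetical helper: check a candidate block size against the
    declaration returned by get_supported_block_size()."""
    for entry in supported:
        if isinstance(entry, MultipleOf):
            if block_size % entry.base == 0:
                return True
        elif entry == block_size:
            return True
    return False


# FlashAttention-style declaration from the hunk above.
assert supports_block_size([MultipleOf(16)], 32)
assert not supports_block_size([MultipleOf(16)], 1)
# The MLA backends below declare exact sizes instead.
assert supports_block_size([64], 64)
assert not supports_block_size([128], 64)
```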
@@ -2,12 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import os
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

 import vllm._custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
+                                              MultipleOf,
                                               is_quantized_kv_cache)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,

@@ -39,6 +40,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]:
     def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
         return CutlassMLAMetadataBuilder

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [128]
+

 class SM100Workspace:
@@ -6,7 +6,8 @@

 import torch

-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
+                                              MultipleOf)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)

@@ -41,6 +42,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
     def get_impl_cls() -> type["FlashMLAImpl"]:
         return FlashMLAImpl

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [64]
+

 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
Review discussion on get_supported_block_size:

"Otherwise there will be an assertion error if people set block_size=1 manually for backends that support block_size 1 but haven't updated this function yet."

"I thought 16 would be a bit more appropriate? @heheda12345 I looked it up carefully, and almost every one of them supports 16."

"Yes, maybe all of them support 16. But if people specify block_size 1, there will be some problems."
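A minimal, self-contained illustration of the concern above, using plain modulo arithmetic in place of any real vLLM check:

```python
# A backend that only declares multiples of 16 rejects a user-chosen
# block_size of 1, while "multiples of 1" would accept any positive size.
def multiple_of(base: int, block_size: int) -> bool:
    return block_size % base == 0

assert not multiple_of(16, 1)   # block_size=1 vs. MultipleOf(16): rejected
assert multiple_of(1, 1)        # block_size=1 vs. MultipleOf(1): accepted
assert multiple_of(16, 16)      # the common case both sides agree on
```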