[Hybrid]: Decouple Kernel Block Size from KV Page Size #24486
Changes from 18 commits
```diff
@@ -139,7 +139,11 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
-        if model_config is not None and model_config.use_mla:
+        # Note: block_size is initialized in
+        # HybridAttentionMambaModelConfig.verify_and_update_config
+        # and doesn't need to be reinitialized here
+        if model_config is not None and model_config.use_mla \
+                and cache_config.block_size is not None:
             # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
             # else we default to CutlassMLA. For each case, we force the
```

zhiyuan1i marked this conversation as resolved.

Comment on lines 121 to 124 (Member): This statement is true for hybrid models only, right?
```diff
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union

 import numpy as np
 import torch

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
+                                              MultipleOf,
                                               is_quantized_kv_cache)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
@@ -49,6 +50,10 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:
     def get_supported_head_sizes(cls) -> list[int]:
         return [32, 64, 96, 128, 160, 192, 224, 256]

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [MultipleOf(16)]
+
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
         supported_head_sizes = cls.get_supported_head_sizes()
```

Collaborator: Technically FA3 would support …
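The new `get_supported_block_size()` hook lets each backend declare which kernel block sizes it can run with, either as exact values or as a `MultipleOf` constraint, so the KV-cache page size no longer has to equal the kernel block size. The sketch below is only a minimal illustration of how a caller might pick a compatible kernel block size for a given page size; the simplified `MultipleOf` dataclass and the `resolve_kernel_block_size` helper are assumptions for this example, not vLLM's actual resolution code.

```python
# Minimal sketch, not vLLM's implementation.
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class MultipleOf:
    """The backend supports any kernel block size that is a multiple of `base`."""
    base: int


def resolve_kernel_block_size(
        page_size: int,
        supported: list[Union[int, MultipleOf]]) -> int:
    """Pick a kernel block size compatible with the KV-cache page size."""
    for entry in supported:
        if isinstance(entry, MultipleOf):
            # Any multiple of the base works, so the page itself can be one block.
            if page_size % entry.base == 0:
                return page_size
        elif page_size % entry == 0:
            # A fixed kernel block size must divide the page size evenly.
            return entry
    raise ValueError(
        f"no supported kernel block size is compatible with page size {page_size}")


# A FlashAttention-style backend (MultipleOf(16)) can use the page size as-is.
assert resolve_kernel_block_size(352, [MultipleOf(16)]) == 352
# A backend pinned to 64-token blocks splits a 128-token page into two blocks.
assert resolve_kernel_block_size(128, [64]) == 64
```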
```diff
@@ -2,12 +2,13 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import os
-from typing import ClassVar, Optional
+from typing import ClassVar, Optional, Union

 import torch

 import vllm._custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
+                                              MultipleOf,
                                               is_quantized_kv_cache)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -39,6 +40,10 @@ def get_impl_cls() -> type["CutlassMLAImpl"]:
     def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
         return CutlassMLAMetadataBuilder

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [128]
+

 class SM100Workspace:
```
```diff
@@ -6,7 +6,8 @@
 import torch

-from vllm.attention.backends.abstract import AttentionLayer, AttentionType
+from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
+                                              MultipleOf)
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
@@ -42,6 +43,10 @@ def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
     def get_impl_cls() -> type["FlashMLAImpl"]:
         return FlashMLAImpl

+    @staticmethod
+    def get_supported_block_size() -> list[Union[int, MultipleOf]]:
+        return [64]
+

 @dataclass
 class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
```
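With FlashMLA pinned to a 64-token kernel block size, a KV-cache page that is larger than the kernel block can be viewed by the kernel as several consecutive 64-token blocks. Purely as an illustration of that mapping (the 128-token page size and the `expand_block_table` helper are assumptions for this example, not code from this diff):

```python
# Illustrative sketch only: splitting page ids into kernel-block ids.
def expand_block_table(page_ids: list[int], page_size: int,
                       kernel_block_size: int) -> list[int]:
    """Split each KV-cache page id into page_size // kernel_block_size block ids."""
    assert page_size % kernel_block_size == 0, \
        "the kernel block size must divide the page size"
    blocks_per_page = page_size // kernel_block_size
    return [
        page_id * blocks_per_page + offset
        for page_id in page_ids
        for offset in range(blocks_per_page)
    ]


# With 128-token pages and 64-token kernel blocks, page 5 is seen by the
# kernel as blocks 10 and 11, and page 7 as blocks 14 and 15.
assert expand_block_table([5, 7], 128, 64) == [10, 11, 14, 15]
```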
Comment: I've checked this PR on vllm-ascend, and everything works well except for the value of `block_alignment_bytes`: we need to specify a different value according to our attention backend. Could we get `block_alignment_bytes` from `current_platform`? I know this is really a property of the attention backend rather than the platform, but maybe we could put the `block_alignment_bytes` details in the platform.

Reply: The attention backend in config.py has not been initialized at that point; this was discussed offline with @heheda12345 previously. However, we can obtain the supported alignment bytes from the attention backend class, so I will give it a try.

Reply: I'm OK with a hardcoded value in this PR. @MengqingCao, you can make a better abstraction of this part in a follow-up PR.

Reply: Okay, I'll do this in a follow-up PR then.
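As a rough sketch of the follow-up idea discussed above, deriving the alignment from the attention backend's declared block sizes instead of a hard-coded per-platform constant, the helper below is purely hypothetical; the function name and the byte count per token slot are assumptions, not vLLM APIs.

```python
# Hypothetical sketch only: compute the KV-cache block alignment from the
# smallest kernel block size a backend declares as supported.
def min_block_alignment_bytes(supported_block_sizes: list[int],
                              bytes_per_token_slot: int) -> int:
    """Smallest page granularity (in bytes) every listed block size allows."""
    return min(supported_block_sizes) * bytes_per_token_slot


# e.g. a backend that only supports 64-token kernel blocks, storing 576 bytes
# per token slot, would need pages aligned to 36864 bytes.
print(min_block_alignment_bytes([64], 576))  # 36864
```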