2 changes: 1 addition & 1 deletion benchmarks/routines/attention.py
@@ -1837,7 +1837,7 @@ def run_backend_wrapper(backend):
return_lse=False,
)
elif backend == "trtllm-native":
return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
return flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(
query=q.unsqueeze(1),
kv_cache=kv_cache.unsqueeze(1),
workspace_buffer=workspace_buffer,
6 changes: 6 additions & 0 deletions docs/api/attention.rst
@@ -47,6 +47,7 @@ XQA
:toctree: ../generated

xqa
xqa_mla

flashinfer.prefill
==================
@@ -98,6 +99,11 @@ and `DeepSeek-R1 <https://arxiv.org/abs/2501.12948>`_).
PageAttention for MLA
---------------------

.. autosummary::
:toctree: ../generated

trtllm_batch_decode_with_kv_cache_mla

.. autoclass:: BatchMLAPagedAttentionWrapper
:members:

338 changes: 7 additions & 331 deletions flashinfer/decode.py
@@ -22,7 +22,13 @@
import torch

from .api_logging import flashinfer_api
from .xqa import xqa, xqa_mla

## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
from .mla import (
trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
Comment on lines +27 to +31 (Contributor review, severity: medium)

The `as` clauses in these imports are redundant. You can simplify them for better readability.

Suggested change, replacing:

from .mla import (
    trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla

with:

from .mla import (
    trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla
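A note on the pattern flagged above (reviewer context, not part of the diff): the seemingly redundant `name as name` form is a recognized way to mark an explicit re-export; type checkers such as mypy under --no-implicit-reexport (or --strict) only treat names imported that way, or listed in __all__, as re-exported from the module. A minimal sketch of the two equivalent conventions, using a hypothetical compat.py shim:

# Option 1: explicit re-export via the "name as name" alias (what the diff does).
from .mla import trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla

# Option 2: plain import plus __all__, which also marks the name as re-exported.
from .mla import trtllm_batch_decode_with_kv_cache_mla

__all__ = ["trtllm_batch_decode_with_kv_cache_mla"]

Whether the aliases here are intentional for that reason is a maintainer call; dropping them as suggested only matters for tooling that relies on explicit re-exports.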
from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
from .jit import (
gen_batch_decode_mla_module,
@@ -2499,336 +2505,6 @@ def xqa_batch_decode_with_kv_cache(
return out


def _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
page_table,
page_size,
):
if query.ndim != 4:
raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
if kv_cache.ndim != 4:
raise ValueError(f"Expected kv_cache.ndim == 4, got {kv_cache.ndim}")
if qk_nope_head_dim != 128:
raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
if kv_lora_rank != 512:
raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}")
if qk_rope_head_dim != 64:
raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}")

B_q, Q_len, H, D_q = query.shape
D_ckv = kv_cache.shape[3]
# if H != 128:
# raise ValueError(f"Expected 128 heads for query, got {H}")
# todo(Yingyi): should we check num_heads == 128? Is this deepseek only?
if D_q != D_ckv or D_q != 576:
raise ValueError(
f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
)

if sparse_mla_top_k > 0:
page_table_shape = page_table.shape
if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
raise ValueError(
f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
)
else:
B_block_table, block_num = page_table.shape
block_size = page_size
if B_q != B_block_table:
raise ValueError(
f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
)
if block_num % (128 / block_size) != 0:
raise ValueError(
f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
)
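For reference, a minimal sketch of shapes that pass this checker in the dense (sparse_mla_top_k == 0) case; the concrete sizes below are illustrative assumptions only:

import torch

batch_size, q_len, num_heads = 4, 1, 128
kv_lora_rank, qk_rope_head_dim = 512, 64   # head_dim_qk = 576
page_size, blocks_per_seq = 64, 8          # 8 % (128 / 64) == 0

query = torch.randn(batch_size, q_len, num_heads, kv_lora_rank + qk_rope_head_dim)
kv_cache = torch.randn(batch_size * blocks_per_seq, 1, page_size, kv_lora_rank + qk_rope_head_dim)
page_table = torch.zeros(batch_size, blocks_per_seq, dtype=torch.int32)

_check_trtllm_gen_mla_shape(
    query, kv_cache, 128, kv_lora_rank, qk_rope_head_dim,
    0,  # sparse_mla_top_k == 0 selects the dense [batch_size, blocks_per_seq] page_table layout
    page_table, page_size,
)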


@flashinfer_api
def trtllm_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
workspace_buffer: torch.Tensor,
qk_nope_head_dim: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
sparse_mla_top_k: int = 0,
out: Optional[torch.Tensor] = None,
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: Optional[bool] = None,
backend: str = "auto",
) -> torch.Tensor:
"""
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim (the q_nope part uses the absorbed kv_lora_rank dimension); the tensor should be the concatenation of q_nope and q_rope. q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], the concatenation of ckv_cache and kpe_cache
workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: sequence length of each request in the kv_cache (number of valid kv tokens per request), [batch_size]
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale: fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks: additional value per head in the denominator of the softmax.
backend : str = "auto"
The implementation backend; one of ``auto``, ``xqa``, or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend is chosen based on the device architecture and kernel availability:
on sm_100 and sm_103 (Blackwell architecture), ``auto`` selects the ``trtllm-gen`` backend;
on sm_120 (Blackwell architecture), ``auto`` selects the ``xqa`` backend.

Note:
In MLA, the actual BMM1 and BMM2 scales applied are fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
or, equivalently, as on-device tensors:
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
bmm2_scale = torch.Tensor([v_scale * o_scale])

For CUDA graph capture, the scale factors must be static constants, i.e. provided as floats.
For on-device fused scales that may change dynamically, provide them as torch.Tensor instead;
currently only the fp8 tensor core operation supports this mode, and the tensor form is used
whenever a tensor is passed.
"""
if backend == "auto":
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
if isinstance(bmm1_scale, torch.Tensor):
assert bmm1_scale.dtype == torch.float32
bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
if backend == "xqa":
if (
get_compute_capability(query.device)[0] != 12
or query.dtype != torch.float8_e4m3fn
or kv_cache.dtype != torch.float8_e4m3fn
):
raise ValueError(
f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
)
if sinks is not None:
raise ValueError("XQA MLA does not support sinks")
if query.size(1) != 1:
raise ValueError(
f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
)
return xqa_batch_decode_with_kv_cache_mla(
query,
kv_cache,
workspace_buffer,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
block_tables,
seq_lens,
max_seq_len,
out,
bmm1_scale,
bmm2_scale,
sinks,
enable_pdl,
)
elif backend == "trtllm-gen":
enable_pdl = (
device_support_pdl(query.device) if enable_pdl is None else enable_pdl
)
run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)

block_size = kv_cache.size(-2)
if (
block_size != 32 and block_size != 64
): # todo(Yingyi): add support for more block sizes?
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")

_check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
block_tables,
block_size,
)

if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)

run_func(
out,
None, # fp4 output not supported in wrapper api yet.
query,
kv_cache,
kv_cache,
workspace_buffer,
block_tables,
seq_lens,
max_seq_len,
bmm1_scale,
bmm2_scale,
-1, # o_sf_scale
-1, # o_sf_vec_size
0, # o_sf_start_index
-1, # window_left
sparse_mla_top_k,
sm_count,
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
)

return out
else:
raise ValueError(f"Backend {backend} not supported")


@flashinfer_api
def xqa_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
workspace_buffer: torch.Tensor,
qk_nope_head_dim: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
out: Optional[torch.Tensor] = None,
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
"""
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim (the q_nope part uses the absorbed kv_lora_rank dimension); the tensor should be the concatenation of q_nope and q_rope. q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], the concatenation of ckv_cache and kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: sequence length of each request in the kv_cache (number of valid kv tokens per request), [batch_size]
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.

Note:
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale

For CUDA graph capture, the two scale factors must be static constants, i.e. provided as floats.
For on-device fused scales that may change dynamically, provide them as torch.Tensor instead;
currently only the fp8 tensor core operation supports this mode, and the tensor form is used
whenever a tensor is passed.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
sm_count = get_device_sm_count(query.device)

block_size = kv_cache.size(-2)
q_len_per_request = query.size(1)
if q_len_per_request != 1:
raise ValueError(
f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
)
if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
raise ValueError(
f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
)
if sinks is not None:
raise ValueError("XQA MLA does not support sinks")

_check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
0, # sparse_mla_top_k
block_tables,
block_size,
)

if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)

workspace_u8 = workspace_buffer.view(torch.uint8)
semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore
scratch = workspace_u8[8 * 1024 * 1024 :]
# This cannot be replaced by kv_cache.transpose(1, 2) because the resulting strides would differ
kv_cache_new = kv_cache.squeeze(1).unsqueeze(2)
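# Illustration (assuming a contiguous kv_cache of shape [num_pages, 1, page_size, 576]):
# squeeze(1).unsqueeze(2) gives shape [num_pages, page_size, 1, 576] with strides
# (page_size * 576, 576, 576, 1), whereas transpose(1, 2) gives the same shape with
# strides (page_size * 576, 576, page_size * 576, 1), i.e. a different stride on the
# singleton head dimension, which is why the comment above warns against transpose.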
seq_lens_new = seq_lens.unsqueeze(1)

xqa_mla(
query,
kv_cache_new,
kv_cache_new,
block_tables,
seq_lens_new,
out,
scratch,
semaphore,
block_size,
q_scale=bmm1_scale,
kv_scale=bmm2_scale,
sm_count=sm_count,
enable_pdl=enable_pdl,
)

return out
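A hedged call sketch for this XQA path, assuming an SM120 device and fp8 inputs (tensor sizes and scale values below are illustrative assumptions):

import torch

device = "cuda"
batch_size, num_heads, page_size, max_seq_len = 2, 128, 64, 256
num_pages = batch_size * (max_seq_len // page_size)

# XQA MLA requires fp8 query/kv_cache and q_len_per_request == 1; the output is bfloat16.
query = torch.randn(batch_size, 1, num_heads, 576, device=device).to(torch.float8_e4m3fn)
kv_cache = torch.randn(num_pages, 1, page_size, 576, device=device).to(torch.float8_e4m3fn)
# Workspace size is an assumption; the first 8MB are used as semaphores and must start zeroed.
workspace_buffer = torch.zeros(64 * 1024 * 1024, dtype=torch.uint8, device=device)
block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).reshape(batch_size, -1)
seq_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device)

out = xqa_batch_decode_with_kv_cache_mla(
    query,
    kv_cache,
    workspace_buffer,
    qk_nope_head_dim=128,
    kv_lora_rank=512,
    qk_rope_head_dim=64,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=max_seq_len,
    bmm1_scale=1.0 / (576 ** 0.5),  # placeholder fused scale (assumption)
    bmm2_scale=1.0,
)
# out: [batch_size, 1, num_heads, 512] in bfloat16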


def fast_decode_plan(
self,
indptr: torch.Tensor,