diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index e88b176f13..96b7a8f58b 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -1837,7 +1837,7 @@ def run_backend_wrapper(backend):
                 return_lse=False,
             )
         elif backend == "trtllm-native":
-            return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            return flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla(
                 query=q.unsqueeze(1),
                 kv_cache=kv_cache.unsqueeze(1),
                 workspace_buffer=workspace_buffer,
diff --git a/docs/api/attention.rst b/docs/api/attention.rst
index bb65664c83..eff9160787 100644
--- a/docs/api/attention.rst
+++ b/docs/api/attention.rst
@@ -47,6 +47,7 @@ XQA
    :toctree: ../generated

    xqa
+   xqa_mla

 flashinfer.prefill
 ==================
@@ -98,6 +99,11 @@ and `DeepSeek-R1 `_).
 PageAttention for MLA
 ---------------------

+.. autosummary::
+   :toctree: ../generated
+
+   trtllm_batch_decode_with_kv_cache_mla
+
 .. autoclass:: BatchMLAPagedAttentionWrapper
     :members:
diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index 0765f933df..53adcf41a8 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -22,7 +22,13 @@
 import torch

 from .api_logging import flashinfer_api
-from .xqa import xqa, xqa_mla
+
+## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility.
+from .mla import (
+    trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
+    xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
+)
+from .xqa import xqa, xqa_mla as xqa_mla
 from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache
 from .jit import (
     gen_batch_decode_mla_module,
@@ -2499,336 +2505,6 @@ def xqa_batch_decode_with_kv_cache(
     return out

-
-def _check_trtllm_gen_mla_shape(
-    query,
-    kv_cache,
-    qk_nope_head_dim,
-    kv_lora_rank,
-    qk_rope_head_dim,
-    sparse_mla_top_k,
-    page_table,
-    page_size,
-):
-    if query.ndim != 4:
-        raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
-    if kv_cache.ndim != 4:
-        raise ValueError(f"Expected kv_cache.ndim == 4, got {kv_cache.ndim}")
-    if qk_nope_head_dim != 128:
-        raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
-    if kv_lora_rank != 512:
-        raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}")
-    if qk_rope_head_dim != 64:
-        raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}")
-
-    B_q, Q_len, H, D_q = query.shape
-    D_ckv = kv_cache.shape[3]
-    # if H != 128:
-    #     raise ValueError(f"Expected 128 heads for query, got {H}")
-    # todo(Yingyi): should we check num_heads == 128? Is this deepseek only?
- if D_q != D_ckv or D_q != 576: - raise ValueError( - f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" - ) - - if sparse_mla_top_k > 0: - page_table_shape = page_table.shape - if page_table_shape != (B_q, Q_len, sparse_mla_top_k): - raise ValueError( - f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}" - ) - else: - B_block_table, block_num = page_table.shape - block_size = page_size - if B_q != B_block_table: - raise ValueError( - f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" - ) - if block_num % (128 / block_size) != 0: - raise ValueError( - f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" - ) - - -@flashinfer_api -def trtllm_batch_decode_with_kv_cache_mla( - query: torch.Tensor, - kv_cache: torch.Tensor, - workspace_buffer: torch.Tensor, - qk_nope_head_dim: int, - kv_lora_rank: int, - qk_rope_head_dim: int, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - max_seq_len: int, - sparse_mla_top_k: int = 0, - out: Optional[torch.Tensor] = None, - bmm1_scale: Union[float, torch.Tensor] = 1.0, - bmm2_scale: Union[float, torch.Tensor] = 1.0, - sinks: Optional[List[torch.Tensor]] = None, - enable_pdl: bool = None, - backend: str = "auto", -) -> torch.Tensor: - """ - Parameters: - query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. - kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache - workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use. - qk_nope_head_dim: qk_nope_head_dim, must be 128 - kv_lora_rank: kv_lora_rank, must be 512 - qk_rope_head_dim: qk_rope_head_dim, must be 64 - sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. - block_tables: page_table of kv cache, [batch_size, num_pages] - seq_lens: query_len - max_seq_len: max sequence length for kv_cache - out: output tensor, if not provided, will be allocated internally - bmm1_scale: fused scale for mla bmm1 input. - when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. - bmm2_scale: fused scale for mla bmm2 input. - when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. - sinks: additional value per head in the denominator of the softmax. - backend : str = "auto" - The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. - When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. - For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. - For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. - - Note: - In MLA, the actual BMM1 and BMM2 scales applied would be fused as: - bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) - bmm2_scale = v_scale * o_scale - or, - bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)) - bmm2_scale = torch.Tensor([v_scale * o_scale]) - - The two scale factors should be static constant for cuda graph capture. - Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. - - For static constant scale factors, the scale factors should be provided as float. 
- - (bmm1_scale, bmm2_scale) - For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. - - (bmm1_scale_log2_tensor, bmm2_scale_tensor) - - Currently, only fp8 tensor core operation supports this mode. - When both are provided, the dynamic scale factor tensors will be used. - """ - if backend == "auto": - backend = ( - "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" - ) - if isinstance(bmm1_scale, torch.Tensor): - assert bmm1_scale.dtype == torch.float32 - bmm1_scale = bmm1_scale * log2e - if isinstance(bmm2_scale, torch.Tensor): - assert bmm2_scale.dtype == torch.float32 - if backend == "xqa": - if ( - get_compute_capability(query.device)[0] != 12 - or query.dtype != torch.float8_e4m3fn - or kv_cache.dtype != torch.float8_e4m3fn - ): - raise ValueError( - f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}" - ) - if sinks is not None: - raise ValueError("XQA MLA does not support sinks") - if query.size(1) != 1: - raise ValueError( - f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" - ) - return xqa_batch_decode_with_kv_cache_mla( - query, - kv_cache, - workspace_buffer, - qk_nope_head_dim, - kv_lora_rank, - qk_rope_head_dim, - block_tables, - seq_lens, - max_seq_len, - out, - bmm1_scale, - bmm2_scale, - sinks, - enable_pdl, - ) - elif backend == "trtllm-gen": - enable_pdl = ( - device_support_pdl(query.device) if enable_pdl is None else enable_pdl - ) - run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode - sm_count = get_device_sm_count(query.device) - - block_size = kv_cache.size(-2) - if ( - block_size != 32 and block_size != 64 - ): # todo(Yingyi): add support for more block sizes? - raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") - - _check_trtllm_gen_mla_shape( - query, - kv_cache, - qk_nope_head_dim, - kv_lora_rank, - qk_rope_head_dim, - sparse_mla_top_k, - block_tables, - block_size, - ) - - if out is None: - out_shape = query.shape[:-1] + (kv_lora_rank,) - out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) - else: - batch_size, _, num_q_heads, _ = query.shape - check_shape_dtype_device( - out, - [batch_size, num_q_heads, kv_lora_rank], - torch.bfloat16, - query.device, - "out", - ) - - run_func( - out, - None, # fp4 output not supported in wrapper api yet. 
- query, - kv_cache, - kv_cache, - workspace_buffer, - block_tables, - seq_lens, - max_seq_len, - bmm1_scale, - bmm2_scale, - -1, # o_sf_scale - -1, # o_sf_vec_size - 0, # o_sf_start_index - -1, # window_left - sparse_mla_top_k, - sm_count, - enable_pdl, - workspace_buffer.numel() * workspace_buffer.element_size(), - sinks, - ) - - return out - else: - raise ValueError(f"Backend {backend} not supported") - - -@flashinfer_api -def xqa_batch_decode_with_kv_cache_mla( - query: torch.Tensor, - kv_cache: torch.Tensor, - workspace_buffer: torch.Tensor, - qk_nope_head_dim: int, - kv_lora_rank: int, - qk_rope_head_dim: int, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - max_seq_len: int, - out: Optional[torch.Tensor] = None, - bmm1_scale: Union[float, torch.Tensor] = 1.0, - bmm2_scale: Union[float, torch.Tensor] = 1.0, - sinks: Optional[List[torch.Tensor]] = None, - enable_pdl: bool = None, -) -> torch.Tensor: - """ - Parameters: - query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. - kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache - workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use. - qk_nope_head_dim: qk_nope_head_dim, must be 128 - kv_lora_rank: kv_lora_rank, must be 512 - qk_rope_head_dim: qk_rope_head_dim, must be 64 - block_tables: page_table of kv cache, [batch_size, num_pages] - seq_lens: query_len - max_seq_len: max sequence length for kv_cache - out: output tensor, if not provided, will be allocated internally - bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor. - bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor. - sinks: additional value per head in the denominator of the softmax. - - Note: - In MLA, the actual BMM1 and BMM2 scales applied would be fused as: - bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) - bmm2_scale = v_scale * o_scale - - The two scale factors should be static constant for cuda graph capture. - Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. - - For static constant scale factors, the scale factors should be provided as float. - - (bmm1_scale, bmm2_scale) - For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. - - (bmm1_scale_log2_tensor, bmm2_scale_tensor) - - Currently, only fp8 tensor core operation supports this mode. - When both are provided, the dynamic scale factor tensors will be used. 
- """ - enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl - sm_count = get_device_sm_count(query.device) - - block_size = kv_cache.size(-2) - q_len_per_request = query.size(1) - if q_len_per_request != 1: - raise ValueError( - f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}" - ) - if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: - raise ValueError( - f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}" - ) - if sinks is not None: - raise ValueError("XQA MLA does not support sinks") - - _check_trtllm_gen_mla_shape( - query, - kv_cache, - qk_nope_head_dim, - kv_lora_rank, - qk_rope_head_dim, - 0, # sparse_mla_top_k - block_tables, - block_size, - ) - - if out is None: - out_shape = query.shape[:-1] + (kv_lora_rank,) - out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) - else: - batch_size, _, num_q_heads, _ = query.shape - check_shape_dtype_device( - out, - [batch_size, num_q_heads, kv_lora_rank], - torch.bfloat16, - query.device, - "out", - ) - - workspace_u8 = workspace_buffer.view(torch.uint8) - semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore - scratch = workspace_u8[8 * 1024 * 1024 :] - # This can not be replaced by kv_cache.transpose(1, 2) because the stride is not the same - kv_cache_new = kv_cache.squeeze(1).unsqueeze(2) - seq_lens_new = seq_lens.unsqueeze(1) - - xqa_mla( - query, - kv_cache_new, - kv_cache_new, - block_tables, - seq_lens_new, - out, - scratch, - semaphore, - block_size, - q_scale=bmm1_scale, - kv_scale=bmm2_scale, - sm_count=sm_count, - enable_pdl=enable_pdl, - ) - - return out - - def fast_decode_plan( self, indptr: torch.Tensor, diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 22cf029a2e..3c4560aba9 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -15,14 +15,23 @@ """ import functools -from typing import Literal, Optional, Tuple, Union, overload +from typing import List, Literal, Optional, Tuple, Union, overload import torch from .api_logging import flashinfer_api -from .jit import gen_batch_mla_module +from .jit import gen_batch_mla_module, gen_trtllm_gen_fmha_module, setup_cubin_loader from .jit.mla import gen_mla_module -from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend +from .utils import ( + MaskMode, + check_shape_dtype_device, + determine_mla_backend, + device_support_pdl, + get_compute_capability, + get_device_sm_count, + log2e, +) +from .xqa import xqa_mla def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): @@ -54,6 +63,64 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table): ) +def _check_trtllm_gen_mla_shape( + query, + kv_cache, + qk_nope_head_dim, + kv_lora_rank, + qk_rope_head_dim, + sparse_mla_top_k, + page_table, + page_size, +): + if query.ndim != 4: + raise ValueError(f"Expected query.ndim == 4, got {query.ndim}") + if kv_cache.ndim != 4: + raise ValueError(f"Expected kv_cache.ndim == 4, got {kv_cache.ndim}") + if qk_nope_head_dim != 128: + raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}") + if kv_lora_rank != 512: + raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}") + if qk_rope_head_dim != 64: + raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}") + + B_q, Q_len, H, D_q = query.shape + D_ckv = kv_cache.shape[3] + # if H != 128: + # raise ValueError(f"Expected 128 heads for query, got {H}") + # todo(Yingyi): should we check 
num_heads == 128? Is this deepseek only?
+    if D_q != D_ckv or D_q != 576:
+        raise ValueError(
+            f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
+        )
+
+    if sparse_mla_top_k > 0:
+        page_table_shape = page_table.shape
+        if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
+            raise ValueError(
+                f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
+            )
+    else:
+        B_block_table, block_num = page_table.shape
+        block_size = page_size
+        if B_q != B_block_table:
+            raise ValueError(
+                f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
+            )
+        if block_num % (128 / block_size) != 0:
+            raise ValueError(
+                f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}"
+            )
+
+
+@functools.cache
+def get_trtllm_gen_fmha_module():
+    mod = gen_trtllm_gen_fmha_module()
+    op = mod.build_and_load()
+    setup_cubin_loader(mod.get_library_path())
+    return op
+
+
 @functools.cache
 def get_mla_module():
     return gen_mla_module().build_and_load()
@@ -453,3 +520,285 @@ def run(
         )

         return (out, lse) if return_lse else out
+
+
+@flashinfer_api
+def trtllm_batch_decode_with_kv_cache_mla(
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    qk_nope_head_dim: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_seq_len: int,
+    sparse_mla_top_k: int = 0,
+    out: Optional[torch.Tensor] = None,
+    bmm1_scale: Union[float, torch.Tensor] = 1.0,
+    bmm2_scale: Union[float, torch.Tensor] = 1.0,
+    sinks: Optional[List[torch.Tensor]] = None,
+    enable_pdl: Optional[bool] = None,
+    backend: str = "auto",
+) -> torch.Tensor:
+    """
+    Parameters
+    ----------
+    query: [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim; the concatenation of q_nope (absorbed into the kv_lora_rank dimensions) and q_rope. q_len_per_request is the MTP query length.
+    kv_cache: [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]; the concatenation of ckv_cache and kpe_cache.
+    workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
+    qk_nope_head_dim: qk_nope_head_dim, must be 128
+    kv_lora_rank: kv_lora_rank, must be 512
+    qk_rope_head_dim: qk_rope_head_dim, must be 64
+    sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
+    block_tables: page table of the kv cache, [batch_size, num_pages]
+    seq_lens: kv sequence length of each request, [batch_size]
+    max_seq_len: max sequence length of the kv_cache
+    out: output tensor; if not provided, it will be allocated internally
+    bmm1_scale: fused scale for the mla bmm1 input.
+        when using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
+    bmm2_scale: fused scale for the mla bmm2 input.
+        when using the trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
+    sinks: additional value per head in the denominator of the softmax.
+    enable_pdl: whether to enable PDL (Programmatic Dependent Launch); detected from the device when None.
+    backend : str = "auto"
+        The implementation backend; can be ``auto``, ``xqa``, or ``trtllm-gen``. Defaults to ``auto``.
+        When set to ``auto``, the backend is chosen based on the device architecture and kernel availability:
+        for sm_100 and sm_103 (Blackwell), ``auto`` chooses the ``trtllm-gen`` backend;
+        for sm_120 (Blackwell), ``auto`` chooses the ``xqa`` backend.
+
+    Note
+    ----
+    In MLA, the BMM1 and BMM2 scales that are actually applied are fused as:
+        bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
+        bmm2_scale = v_scale * o_scale
+    or, equivalently, as on-device tensors:
+        bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
+        bmm2_scale = torch.Tensor([v_scale * o_scale])
+
+    For cuda graph capture, the two scale factors should be static constants provided as floats.
+    For on-device fused scales, which may change dynamically, provide them as torch.Tensor instead;
+    currently only the fp8 tensor core operation supports tensor scale factors.
+    """
+    if backend == "auto":
+        backend = (
+            "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
+        )
+    if isinstance(bmm1_scale, torch.Tensor):
+        assert bmm1_scale.dtype == torch.float32
+        bmm1_scale = bmm1_scale * log2e
+    if isinstance(bmm2_scale, torch.Tensor):
+        assert bmm2_scale.dtype == torch.float32
+    if backend == "xqa":
+        if (
+            get_compute_capability(query.device)[0] != 12
+            or query.dtype != torch.float8_e4m3fn
+            or kv_cache.dtype != torch.float8_e4m3fn
+        ):
+            raise ValueError(
+                f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}"
+            )
+        if sinks is not None:
+            raise ValueError("XQA MLA does not support sinks")
+        if query.size(1) != 1:
+            raise ValueError(
+                f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
+            )
+        return xqa_batch_decode_with_kv_cache_mla(
+            query,
+            kv_cache,
+            workspace_buffer,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            out,
+            bmm1_scale,
+            bmm2_scale,
+            sinks,
+            enable_pdl,
+        )
+    elif backend == "trtllm-gen":
+        enable_pdl = (
+            device_support_pdl(query.device) if enable_pdl is None else enable_pdl
+        )
+        run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
+        sm_count = get_device_sm_count(query.device)
+
+        block_size = kv_cache.size(-2)
+        if (
+            block_size != 32 and block_size != 64
+        ):  # todo(Yingyi): add support for more block sizes?
+            raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")
+
+        _check_trtllm_gen_mla_shape(
+            query,
+            kv_cache,
+            qk_nope_head_dim,
+            kv_lora_rank,
+            qk_rope_head_dim,
+            sparse_mla_top_k,
+            block_tables,
+            block_size,
+        )
+
+        if out is None:
+            out_shape = query.shape[:-1] + (kv_lora_rank,)
+            out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
+        else:
+            batch_size, _, num_q_heads, _ = query.shape
+            check_shape_dtype_device(
+                out,
+                [batch_size, num_q_heads, kv_lora_rank],
+                torch.bfloat16,
+                query.device,
+                "out",
+            )
+
+        run_func(
+            out,
+            None,  # fp4 output not supported in wrapper api yet.
+            query,
+            kv_cache,
+            kv_cache,
+            workspace_buffer,
+            block_tables,
+            seq_lens,
+            max_seq_len,
+            bmm1_scale,
+            bmm2_scale,
+            -1,  # o_sf_scale
+            -1,  # o_sf_vec_size
+            0,  # o_sf_start_index
+            -1,  # window_left
+            sparse_mla_top_k,
+            sm_count,
+            enable_pdl,
+            workspace_buffer.numel() * workspace_buffer.element_size(),
+            sinks,
+        )
+
+        return out
+    else:
+        raise ValueError(f"Backend {backend} not supported")
+
+
+@flashinfer_api
+def xqa_batch_decode_with_kv_cache_mla(
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    workspace_buffer: torch.Tensor,
+    qk_nope_head_dim: int,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    block_tables: torch.Tensor,
+    seq_lens: torch.Tensor,
+    max_seq_len: int,
+    out: Optional[torch.Tensor] = None,
+    bmm1_scale: Union[float, torch.Tensor] = 1.0,
+    bmm2_scale: Union[float, torch.Tensor] = 1.0,
+    sinks: Optional[List[torch.Tensor]] = None,
+    enable_pdl: Optional[bool] = None,
+) -> torch.Tensor:
+    """
+    Parameters:
+    query: [batch_size, q_len_per_request, num_heads, head_dim_qk], where head_dim_qk = kv_lora_rank + qk_rope_head_dim; the concatenation of q_nope (absorbed into the kv_lora_rank dimensions) and q_rope. q_len_per_request is the MTP query length.
+    kv_cache: [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe]; the concatenation of ckv_cache and kpe_cache.
+    workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
+    qk_nope_head_dim: qk_nope_head_dim, must be 128
+    kv_lora_rank: kv_lora_rank, must be 512
+    qk_rope_head_dim: qk_rope_head_dim, must be 64
+    block_tables: page table of the kv cache, [batch_size, num_pages]
+    seq_lens: kv sequence length of each request, [batch_size]
+    max_seq_len: max sequence length of the kv_cache
+    out: output tensor; if not provided, it will be allocated internally
+    bmm1_scale: fused scale for the mla bmm1 input. Can be a float or a torch.Tensor.
+    bmm2_scale: fused scale for the mla bmm2 input. Can be a float or a torch.Tensor.
+    sinks: additional value per head in the denominator of the softmax.
+    enable_pdl: whether to enable PDL (Programmatic Dependent Launch); detected from the device when None.
+
+    Note:
+    In MLA, the BMM1 and BMM2 scales that are actually applied are fused as:
+        bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
+        bmm2_scale = v_scale * o_scale
+
+    For cuda graph capture, the two scale factors should be static constants provided as floats.
+    For on-device fused scales, which may change dynamically, provide them as torch.Tensor instead;
+    currently only the fp8 tensor core operation supports tensor scale factors.
+    """
+    enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl
+    sm_count = get_device_sm_count(query.device)
+
+    block_size = kv_cache.size(-2)
+    q_len_per_request = query.size(1)
+    if q_len_per_request != 1:
+        raise ValueError(
+            f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}"
+        )
+    if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn:
+        raise ValueError(
+            f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}"
+        )
+    if sinks is not None:
+        raise ValueError("XQA MLA does not support sinks")
+
+    _check_trtllm_gen_mla_shape(
+        query,
+        kv_cache,
+        qk_nope_head_dim,
+        kv_lora_rank,
+        qk_rope_head_dim,
+        0,  # sparse_mla_top_k
+        block_tables,
+        block_size,
+    )
+
+    if out is None:
+        out_shape = query.shape[:-1] + (kv_lora_rank,)
+        out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
+    else:
+        batch_size, _, num_q_heads, _ = query.shape
+        check_shape_dtype_device(
+            out,
+            [batch_size, num_q_heads, kv_lora_rank],
+            torch.bfloat16,
+            query.device,
+            "out",
+        )
+
+    workspace_u8 = workspace_buffer.view(torch.uint8)
+    semaphore = workspace_u8[: 8 * 1024 * 1024]  # reserve 8MB for semaphores
+    scratch = workspace_u8[8 * 1024 * 1024 :]
+    # This cannot be replaced by kv_cache.transpose(1, 2) because the strides differ.
+    kv_cache_new = kv_cache.squeeze(1).unsqueeze(2)
+    seq_lens_new = seq_lens.unsqueeze(1)
+
+    xqa_mla(
+        query,
+        kv_cache_new,
+        kv_cache_new,
+        block_tables,
+        seq_lens_new,
+        out,
+        scratch,
+        semaphore,
+        block_size,
+        q_scale=bmm1_scale,
+        kv_scale=bmm2_scale,
+        sm_count=sm_count,
+        enable_pdl=enable_pdl,
+    )
+
+    return out