-
Notifications
You must be signed in to change notification settings - Fork 589
refactor: Move mla code from decode.py to mla.py and add to documentation #2163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
bkryu
wants to merge
1
commit into
flashinfer-ai:main
Choose a base branch
from
bkryu:mla_refactor
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+366
β335
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -22,7 +22,13 @@ | |||||||||||||||||||||
| import torch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from .api_logging import flashinfer_api | ||||||||||||||||||||||
| from .xqa import xqa, xqa_mla | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| ## NOTE: MLA functions have been moved to mla.py, but we keep the aliases here for backward compatibility. | ||||||||||||||||||||||
| from .mla import ( | ||||||||||||||||||||||
| trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla, | ||||||||||||||||||||||
| xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| from .xqa import xqa, xqa_mla as xqa_mla | ||||||||||||||||||||||
|
Comment on lines
+27
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The aliases for the imported functions are redundant. You can simplify these imports for better readability.
Suggested change
|
||||||||||||||||||||||
| from .cudnn import cudnn_batch_decode_with_kv_cache as cudnn_batch_decode_with_kv_cache | ||||||||||||||||||||||
bkryu marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
| from .jit import ( | ||||||||||||||||||||||
| gen_batch_decode_mla_module, | ||||||||||||||||||||||
|
|
@@ -2499,336 +2505,6 @@ def xqa_batch_decode_with_kv_cache( | |||||||||||||||||||||
| return out | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _check_trtllm_gen_mla_shape( | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| qk_nope_head_dim, | ||||||||||||||||||||||
| kv_lora_rank, | ||||||||||||||||||||||
| qk_rope_head_dim, | ||||||||||||||||||||||
| sparse_mla_top_k, | ||||||||||||||||||||||
| page_table, | ||||||||||||||||||||||
| page_size, | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| if query.ndim != 4: | ||||||||||||||||||||||
| raise ValueError(f"Expected query.ndim == 4, got {query.ndim}") | ||||||||||||||||||||||
| if kv_cache.ndim != 4: | ||||||||||||||||||||||
| raise ValueError(f"Expected kv_cache.ndim == 4, got {kv_cache.ndim}") | ||||||||||||||||||||||
| if qk_nope_head_dim != 128: | ||||||||||||||||||||||
| raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}") | ||||||||||||||||||||||
| if kv_lora_rank != 512: | ||||||||||||||||||||||
| raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}") | ||||||||||||||||||||||
| if qk_rope_head_dim != 64: | ||||||||||||||||||||||
| raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| B_q, Q_len, H, D_q = query.shape | ||||||||||||||||||||||
| D_ckv = kv_cache.shape[3] | ||||||||||||||||||||||
| # if H != 128: | ||||||||||||||||||||||
| # raise ValueError(f"Expected 128 heads for query, got {H}") | ||||||||||||||||||||||
| # todo(Yingyi): should we check num_heads == 128? Is this deepseek only? | ||||||||||||||||||||||
| if D_q != D_ckv or D_q != 576: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if sparse_mla_top_k > 0: | ||||||||||||||||||||||
| page_table_shape = page_table.shape | ||||||||||||||||||||||
| if page_table_shape != (B_q, Q_len, sparse_mla_top_k): | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| B_block_table, block_num = page_table.shape | ||||||||||||||||||||||
| block_size = page_size | ||||||||||||||||||||||
| if B_q != B_block_table: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if block_num % (128 / block_size) != 0: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"Expected block_num % (128 / block_size) == 0, got {block_num=} and {block_size=}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @flashinfer_api | ||||||||||||||||||||||
| def trtllm_batch_decode_with_kv_cache_mla( | ||||||||||||||||||||||
| query: torch.Tensor, | ||||||||||||||||||||||
| kv_cache: torch.Tensor, | ||||||||||||||||||||||
| workspace_buffer: torch.Tensor, | ||||||||||||||||||||||
| qk_nope_head_dim: int, | ||||||||||||||||||||||
| kv_lora_rank: int, | ||||||||||||||||||||||
| qk_rope_head_dim: int, | ||||||||||||||||||||||
| block_tables: torch.Tensor, | ||||||||||||||||||||||
| seq_lens: torch.Tensor, | ||||||||||||||||||||||
| max_seq_len: int, | ||||||||||||||||||||||
| sparse_mla_top_k: int = 0, | ||||||||||||||||||||||
| out: Optional[torch.Tensor] = None, | ||||||||||||||||||||||
| bmm1_scale: Union[float, torch.Tensor] = 1.0, | ||||||||||||||||||||||
| bmm2_scale: Union[float, torch.Tensor] = 1.0, | ||||||||||||||||||||||
| sinks: Optional[List[torch.Tensor]] = None, | ||||||||||||||||||||||
| enable_pdl: bool = None, | ||||||||||||||||||||||
| backend: str = "auto", | ||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Parameters: | ||||||||||||||||||||||
| query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. | ||||||||||||||||||||||
| kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache | ||||||||||||||||||||||
| workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use. | ||||||||||||||||||||||
| qk_nope_head_dim: qk_nope_head_dim, must be 128 | ||||||||||||||||||||||
| kv_lora_rank: kv_lora_rank, must be 512 | ||||||||||||||||||||||
| qk_rope_head_dim: qk_rope_head_dim, must be 64 | ||||||||||||||||||||||
| sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. | ||||||||||||||||||||||
| block_tables: page_table of kv cache, [batch_size, num_pages] | ||||||||||||||||||||||
| seq_lens: query_len | ||||||||||||||||||||||
| max_seq_len: max sequence length for kv_cache | ||||||||||||||||||||||
| out: output tensor, if not provided, will be allocated internally | ||||||||||||||||||||||
| bmm1_scale: fused scale for mla bmm1 input. | ||||||||||||||||||||||
| when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. | ||||||||||||||||||||||
| bmm2_scale: fused scale for mla bmm2 input. | ||||||||||||||||||||||
| when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. | ||||||||||||||||||||||
| sinks: additional value per head in the denominator of the softmax. | ||||||||||||||||||||||
| backend : str = "auto" | ||||||||||||||||||||||
| The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. | ||||||||||||||||||||||
| When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. | ||||||||||||||||||||||
| For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. | ||||||||||||||||||||||
| For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Note: | ||||||||||||||||||||||
| In MLA, the actual BMM1 and BMM2 scales applied would be fused as: | ||||||||||||||||||||||
| bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) | ||||||||||||||||||||||
| bmm2_scale = v_scale * o_scale | ||||||||||||||||||||||
| or, | ||||||||||||||||||||||
| bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)) | ||||||||||||||||||||||
| bmm2_scale = torch.Tensor([v_scale * o_scale]) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| The two scale factors should be static constant for cuda graph capture. | ||||||||||||||||||||||
| Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| For static constant scale factors, the scale factors should be provided as float. | ||||||||||||||||||||||
| - (bmm1_scale, bmm2_scale) | ||||||||||||||||||||||
| For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. | ||||||||||||||||||||||
| - (bmm1_scale_log2_tensor, bmm2_scale_tensor) | ||||||||||||||||||||||
| - Currently, only fp8 tensor core operation supports this mode. | ||||||||||||||||||||||
| When both are provided, the dynamic scale factor tensors will be used. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if backend == "auto": | ||||||||||||||||||||||
| backend = ( | ||||||||||||||||||||||
| "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if isinstance(bmm1_scale, torch.Tensor): | ||||||||||||||||||||||
| assert bmm1_scale.dtype == torch.float32 | ||||||||||||||||||||||
| bmm1_scale = bmm1_scale * log2e | ||||||||||||||||||||||
| if isinstance(bmm2_scale, torch.Tensor): | ||||||||||||||||||||||
| assert bmm2_scale.dtype == torch.float32 | ||||||||||||||||||||||
| if backend == "xqa": | ||||||||||||||||||||||
| if ( | ||||||||||||||||||||||
| get_compute_capability(query.device)[0] != 12 | ||||||||||||||||||||||
| or query.dtype != torch.float8_e4m3fn | ||||||||||||||||||||||
| or kv_cache.dtype != torch.float8_e4m3fn | ||||||||||||||||||||||
| ): | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"XQA MLA only supports fp8 operation on SM120 GPUs, got {query.dtype} and {kv_cache.dtype}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if sinks is not None: | ||||||||||||||||||||||
| raise ValueError("XQA MLA does not support sinks") | ||||||||||||||||||||||
| if query.size(1) != 1: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return xqa_batch_decode_with_kv_cache_mla( | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| workspace_buffer, | ||||||||||||||||||||||
| qk_nope_head_dim, | ||||||||||||||||||||||
| kv_lora_rank, | ||||||||||||||||||||||
| qk_rope_head_dim, | ||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||
| seq_lens, | ||||||||||||||||||||||
| max_seq_len, | ||||||||||||||||||||||
| out, | ||||||||||||||||||||||
| bmm1_scale, | ||||||||||||||||||||||
| bmm2_scale, | ||||||||||||||||||||||
| sinks, | ||||||||||||||||||||||
| enable_pdl, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| elif backend == "trtllm-gen": | ||||||||||||||||||||||
| enable_pdl = ( | ||||||||||||||||||||||
| device_support_pdl(query.device) if enable_pdl is None else enable_pdl | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode | ||||||||||||||||||||||
| sm_count = get_device_sm_count(query.device) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| block_size = kv_cache.size(-2) | ||||||||||||||||||||||
| if ( | ||||||||||||||||||||||
| block_size != 32 and block_size != 64 | ||||||||||||||||||||||
| ): # todo(Yingyi): add support for more block sizes? | ||||||||||||||||||||||
| raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _check_trtllm_gen_mla_shape( | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| qk_nope_head_dim, | ||||||||||||||||||||||
| kv_lora_rank, | ||||||||||||||||||||||
| qk_rope_head_dim, | ||||||||||||||||||||||
| sparse_mla_top_k, | ||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||
| block_size, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if out is None: | ||||||||||||||||||||||
| out_shape = query.shape[:-1] + (kv_lora_rank,) | ||||||||||||||||||||||
| out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| batch_size, _, num_q_heads, _ = query.shape | ||||||||||||||||||||||
| check_shape_dtype_device( | ||||||||||||||||||||||
| out, | ||||||||||||||||||||||
| [batch_size, num_q_heads, kv_lora_rank], | ||||||||||||||||||||||
| torch.bfloat16, | ||||||||||||||||||||||
| query.device, | ||||||||||||||||||||||
| "out", | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| run_func( | ||||||||||||||||||||||
| out, | ||||||||||||||||||||||
| None, # fp4 output not supported in wrapper api yet. | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| workspace_buffer, | ||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||
| seq_lens, | ||||||||||||||||||||||
| max_seq_len, | ||||||||||||||||||||||
| bmm1_scale, | ||||||||||||||||||||||
| bmm2_scale, | ||||||||||||||||||||||
| -1, # o_sf_scale | ||||||||||||||||||||||
| -1, # o_sf_vec_size | ||||||||||||||||||||||
| 0, # o_sf_start_index | ||||||||||||||||||||||
| -1, # window_left | ||||||||||||||||||||||
| sparse_mla_top_k, | ||||||||||||||||||||||
| sm_count, | ||||||||||||||||||||||
| enable_pdl, | ||||||||||||||||||||||
| workspace_buffer.numel() * workspace_buffer.element_size(), | ||||||||||||||||||||||
| sinks, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return out | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| raise ValueError(f"Backend {backend} not supported") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| @flashinfer_api | ||||||||||||||||||||||
| def xqa_batch_decode_with_kv_cache_mla( | ||||||||||||||||||||||
| query: torch.Tensor, | ||||||||||||||||||||||
| kv_cache: torch.Tensor, | ||||||||||||||||||||||
| workspace_buffer: torch.Tensor, | ||||||||||||||||||||||
| qk_nope_head_dim: int, | ||||||||||||||||||||||
| kv_lora_rank: int, | ||||||||||||||||||||||
| qk_rope_head_dim: int, | ||||||||||||||||||||||
| block_tables: torch.Tensor, | ||||||||||||||||||||||
| seq_lens: torch.Tensor, | ||||||||||||||||||||||
| max_seq_len: int, | ||||||||||||||||||||||
| out: Optional[torch.Tensor] = None, | ||||||||||||||||||||||
| bmm1_scale: Union[float, torch.Tensor] = 1.0, | ||||||||||||||||||||||
| bmm2_scale: Union[float, torch.Tensor] = 1.0, | ||||||||||||||||||||||
| sinks: Optional[List[torch.Tensor]] = None, | ||||||||||||||||||||||
| enable_pdl: bool = None, | ||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Parameters: | ||||||||||||||||||||||
| query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. | ||||||||||||||||||||||
| kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache | ||||||||||||||||||||||
| workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use. | ||||||||||||||||||||||
| qk_nope_head_dim: qk_nope_head_dim, must be 128 | ||||||||||||||||||||||
| kv_lora_rank: kv_lora_rank, must be 512 | ||||||||||||||||||||||
| qk_rope_head_dim: qk_rope_head_dim, must be 64 | ||||||||||||||||||||||
| block_tables: page_table of kv cache, [batch_size, num_pages] | ||||||||||||||||||||||
| seq_lens: query_len | ||||||||||||||||||||||
| max_seq_len: max sequence length for kv_cache | ||||||||||||||||||||||
| out: output tensor, if not provided, will be allocated internally | ||||||||||||||||||||||
| bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor. | ||||||||||||||||||||||
| bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor. | ||||||||||||||||||||||
| sinks: additional value per head in the denominator of the softmax. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Note: | ||||||||||||||||||||||
| In MLA, the actual BMM1 and BMM2 scales applied would be fused as: | ||||||||||||||||||||||
| bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) | ||||||||||||||||||||||
| bmm2_scale = v_scale * o_scale | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| The two scale factors should be static constant for cuda graph capture. | ||||||||||||||||||||||
| Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| For static constant scale factors, the scale factors should be provided as float. | ||||||||||||||||||||||
| - (bmm1_scale, bmm2_scale) | ||||||||||||||||||||||
| For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. | ||||||||||||||||||||||
| - (bmm1_scale_log2_tensor, bmm2_scale_tensor) | ||||||||||||||||||||||
| - Currently, only fp8 tensor core operation supports this mode. | ||||||||||||||||||||||
| When both are provided, the dynamic scale factor tensors will be used. | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl | ||||||||||||||||||||||
| sm_count = get_device_sm_count(query.device) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| block_size = kv_cache.size(-2) | ||||||||||||||||||||||
| q_len_per_request = query.size(1) | ||||||||||||||||||||||
| if q_len_per_request != 1: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"XQA MLA only supports q_len_per_request == 1, got {q_len_per_request}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if query.dtype != torch.float8_e4m3fn or kv_cache.dtype != torch.float8_e4m3fn: | ||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||
| f"XQA MLA only supports fp8 tensor core operation, got {query.dtype} and {kv_cache.dtype}" | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if sinks is not None: | ||||||||||||||||||||||
| raise ValueError("XQA MLA does not support sinks") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| _check_trtllm_gen_mla_shape( | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache, | ||||||||||||||||||||||
| qk_nope_head_dim, | ||||||||||||||||||||||
| kv_lora_rank, | ||||||||||||||||||||||
| qk_rope_head_dim, | ||||||||||||||||||||||
| 0, # sparse_mla_top_k | ||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||
| block_size, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if out is None: | ||||||||||||||||||||||
| out_shape = query.shape[:-1] + (kv_lora_rank,) | ||||||||||||||||||||||
| out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device) | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| batch_size, _, num_q_heads, _ = query.shape | ||||||||||||||||||||||
| check_shape_dtype_device( | ||||||||||||||||||||||
| out, | ||||||||||||||||||||||
| [batch_size, num_q_heads, kv_lora_rank], | ||||||||||||||||||||||
| torch.bfloat16, | ||||||||||||||||||||||
| query.device, | ||||||||||||||||||||||
| "out", | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| workspace_u8 = workspace_buffer.view(torch.uint8) | ||||||||||||||||||||||
| semaphore = workspace_u8[: 8 * 1024 * 1024] # reserve 8MB for semaphore | ||||||||||||||||||||||
| scratch = workspace_u8[8 * 1024 * 1024 :] | ||||||||||||||||||||||
| # This can not be replaced by kv_cache.transpose(1, 2) because the stride is not the same | ||||||||||||||||||||||
| kv_cache_new = kv_cache.squeeze(1).unsqueeze(2) | ||||||||||||||||||||||
| seq_lens_new = seq_lens.unsqueeze(1) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| xqa_mla( | ||||||||||||||||||||||
| query, | ||||||||||||||||||||||
| kv_cache_new, | ||||||||||||||||||||||
| kv_cache_new, | ||||||||||||||||||||||
| block_tables, | ||||||||||||||||||||||
| seq_lens_new, | ||||||||||||||||||||||
| out, | ||||||||||||||||||||||
| scratch, | ||||||||||||||||||||||
| semaphore, | ||||||||||||||||||||||
| block_size, | ||||||||||||||||||||||
| q_scale=bmm1_scale, | ||||||||||||||||||||||
| kv_scale=bmm2_scale, | ||||||||||||||||||||||
| sm_count=sm_count, | ||||||||||||||||||||||
| enable_pdl=enable_pdl, | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| return out | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def fast_decode_plan( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| indptr: torch.Tensor, | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
asclauses in these imports are redundant. You can simplify them for better readability.