refactor: Move mla code from decode.py to mla.py and add to documentation #2163
base: main
Conversation
/bot run
Walkthrough: MLA decode functions are refactored from decode.py into the new mla.py module.
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes: Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request focuses on improving code organization by refactoring MLA-related decoding functionality into its own module. The change not only enhances modularity but also resolves an issue with documentation generation, ensuring that these functions are correctly exposed in the API reference.
Code Review
This pull request refactors the codebase by moving MLA-related functions (trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla) from decode.py to a new mla.py module. This is a good change for modularity and code organization. The PR also updates the documentation to include these functions, which is a necessary follow-up. The refactoring appears to be done correctly, including maintaining backward compatibility. I have a couple of minor suggestions to improve code style and consistency.
Parameters:
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.
Note:
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
    bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
    bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""
For consistency with trtllm_batch_decode_with_kv_cache_mla and other functions in this file, please update the docstring for xqa_batch_decode_with_kv_cache_mla to use the NumPy docstring format. This improves readability and maintainability, following the spirit of PEP 257 for docstring conventions.
"""
Parameters
----------
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_rope_head_dim: qk_rope_head_dim, must be 64
block_tables: page_table of kv cache, [batch_size, num_pages]
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: additional value per head in the denominator of the softmax.
Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a nice refactoring that moves MLA-related code from decode.py to a new mla.py file, improving code organization. The documentation has also been updated accordingly. I have a few minor suggestions to improve code style and documentation consistency.
from .mla import (
    trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
The as clauses in these imports are redundant. You can simplify them for better readability.
- from .mla import (
-     trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
-     xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
- )
- from .xqa import xqa, xqa_mla as xqa_mla
+ from .mla import (
+     trtllm_batch_decode_with_kv_cache_mla,
+     xqa_batch_decode_with_kv_cache_mla,
+ )
+ from .xqa import xqa, xqa_mla
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
or,
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
| """ | ||
| Parameters: | ||
| query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. | ||
| kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache | ||
| workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use. | ||
| qk_nope_head_dim: qk_nope_head_dim, must be 128 | ||
| kv_lora_rank: kv_lora_rank, must be 512 | ||
| qk_rope_head_dim: qk_rope_head_dim, must be 64 | ||
| block_tables: page_table of kv cache, [batch_size, num_pages] | ||
| seq_lens: query_len | ||
| max_seq_len: max sequence length for kv_cache | ||
| out: output tensor, if not provided, will be allocated internally | ||
| bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor. | ||
| bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor. | ||
| sinks: additional value per head in the denominator of the softmax. | ||
| Note: | ||
| In MLA, the actual BMM1 and BMM2 scales applied would be fused as: | ||
| bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) | ||
| bmm2_scale = v_scale * o_scale | ||
| The two scale factors should be static constant for cuda graph capture. | ||
| Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. | ||
| For static constant scale factors, the scale factors should be provided as float. | ||
| - (bmm1_scale, bmm2_scale) | ||
| For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. | ||
| - (bmm1_scale_log2_tensor, bmm2_scale_tensor) | ||
| - Currently, only fp8 tensor core operation supports this mode. | ||
| When both are provided, the dynamic scale factor tensors will be used. | ||
| """ |
The docstring for this function doesn't follow the same numpy-style format as trtllm_batch_decode_with_kv_cache_mla in this file. For consistency, it would be great to update it.
"""
Parameters
----------
query: torch.Tensor
[batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: torch.Tensor
[num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor
Must be initialized to 0 for its first use.
qk_nope_head_dim: int
qk_nope_head_dim, must be 128
kv_lora_rank: int
kv_lora_rank, must be 512
qk_rope_head_dim: int
qk_rope_head_dim, must be 64
block_tables: torch.Tensor
page_table of kv cache, [batch_size, num_pages]
seq_lens: torch.Tensor
query_len
max_seq_len: int
max sequence length for kv_cache
out: Optional[torch.Tensor]
output tensor, if not provided, will be allocated internally
bmm1_scale: Union[float, torch.Tensor]
fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: Union[float, torch.Tensor]
fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: Optional[List[torch.Tensor]]
additional value per head in the denominator of the softmax.
Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a good refactoring that moves MLA-related code into its own mla.py module and updates the documentation accordingly. This improves code organization. I've identified a potential issue where a parameter could be silently ignored, and have also suggested some improvements for documentation and code style consistency.
bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
    assert bmm2_scale.dtype == torch.float32
if backend == "xqa":
When backend is 'xqa', the function calls xqa_batch_decode_with_kv_cache_mla, which does not support sparse attention (sparse_mla_top_k > 0). However, there is no check to prevent this. If a user calls this function with backend='xqa' and sparse_mla_top_k > 0, the sparse parameter will be silently ignored. You should add a check to raise an error in this case.
if backend == "xqa":
    if sparse_mla_top_k > 0:
        raise ValueError("XQA backend does not support sparse MLA attention.")

from .mla import (
    trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
    xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
)
from .xqa import xqa, xqa_mla as xqa_mla
The aliases for the imported functions are redundant. You can simplify these imports for better readability.
- from .mla import (
-     trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla,
-     xqa_batch_decode_with_kv_cache_mla as xqa_batch_decode_with_kv_cache_mla,
- )
- from .xqa import xqa, xqa_mla as xqa_mla
+ from .mla import (
+     trtllm_batch_decode_with_kv_cache_mla,
+     xqa_batch_decode_with_kv_cache_mla,
+ )
+ from .xqa import xqa, xqa_mla
| """ | ||
| Parameters | ||
| ---------- | ||
| query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. | ||
| kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache | ||
| workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use. | ||
| qk_nope_head_dim: qk_nope_head_dim, must be 128 | ||
| kv_lora_rank: kv_lora_rank, must be 512 | ||
| qk_rope_head_dim: qk_rope_head_dim, must be 64 | ||
| sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. | ||
| block_tables: page_table of kv cache, [batch_size, num_pages] | ||
| seq_lens: query_len | ||
| max_seq_len: max sequence length for kv_cache | ||
| out: output tensor, if not provided, will be allocated internally | ||
| bmm1_scale: fused scale for mla bmm1 input. | ||
| when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. | ||
| bmm2_scale: fused scale for mla bmm2 input. | ||
| when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. | ||
| sinks: additional value per head in the denominator of the softmax. | ||
| backend : str = "auto" | ||
| The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. | ||
| When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. | ||
| For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. | ||
| For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. | ||
| Note | ||
| ---- | ||
| In MLA, the actual BMM1 and BMM2 scales applied would be fused as: | ||
| bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) | ||
| bmm2_scale = v_scale * o_scale | ||
| or, | ||
| bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)) | ||
| bmm2_scale = torch.Tensor([v_scale * o_scale]) | ||
| The two scale factors should be static constant for cuda graph capture. | ||
| Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. | ||
| For static constant scale factors, the scale factors should be provided as float. | ||
| - (bmm1_scale, bmm2_scale) | ||
| For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. | ||
| - (bmm1_scale_log2_tensor, bmm2_scale_tensor) | ||
| - Currently, only fp8 tensor core operation supports this mode. | ||
| When both are provided, the dynamic scale factor tensors will be used. | ||
| """ |
The docstring for trtllm_batch_decode_with_kv_cache_mla is missing parameter types, and the enable_pdl parameter is not documented. Adding these would improve clarity and consistency with the function's type hints.
"""
Parameters
----------
query: torch.Tensor
[batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: torch.Tensor
[num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor
[num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
qk_nope_head_dim: int
qk_nope_head_dim, must be 128
kv_lora_rank: int
kv_lora_rank, must be 512
qk_rope_head_dim: int
qk_rope_head_dim, must be 64
sparse_mla_top_k: int
sparse MLA top k, must be 0 for non-sparse MLA.
block_tables: torch.Tensor
page_table of kv cache, [batch_size, num_pages]
seq_lens: torch.Tensor
query_len
max_seq_len: int
max sequence length for kv_cache
out: Optional[torch.Tensor]
output tensor, if not provided, will be allocated internally
bmm1_scale: Union[float, torch.Tensor]
fused scale for mla bmm1 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
bmm2_scale: Union[float, torch.Tensor]
fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks: Optional[List[torch.Tensor]]
additional value per head in the denominator of the softmax.
enable_pdl: Optional[bool]
Whether to enable Programmatic Dependent Launch (PDL).
backend : str
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
or,
bmm1_scale = torch.Tensor([q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)])
bmm2_scale = torch.Tensor([v_scale * o_scale])
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""| """ | ||
| Parameters: | ||
| query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length. | ||
| kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache | ||
| workspace_buffer: torch.Tensor. Must be initialized to 0 for its first use. | ||
| qk_nope_head_dim: qk_nope_head_dim, must be 128 | ||
| kv_lora_rank: kv_lora_rank, must be 512 | ||
| qk_rope_head_dim: qk_rope_head_dim, must be 64 | ||
| block_tables: page_table of kv cache, [batch_size, num_pages] | ||
| seq_lens: query_len | ||
| max_seq_len: max sequence length for kv_cache | ||
| out: output tensor, if not provided, will be allocated internally | ||
| bmm1_scale: fused scale for mla bmm1 input. Can be a float or a torch.Tensor. | ||
| bmm2_scale: fused scale for mla bmm2 input. Can be a float or a torch.Tensor. | ||
| sinks: additional value per head in the denominator of the softmax. | ||
| Note: | ||
| In MLA, the actual BMM1 and BMM2 scales applied would be fused as: | ||
| bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5) | ||
| bmm2_scale = v_scale * o_scale | ||
| The two scale factors should be static constant for cuda graph capture. | ||
| Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided. | ||
| For static constant scale factors, the scale factors should be provided as float. | ||
| - (bmm1_scale, bmm2_scale) | ||
| For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor. | ||
| - (bmm1_scale_log2_tensor, bmm2_scale_tensor) | ||
| - Currently, only fp8 tensor core operation supports this mode. | ||
| When both are provided, the dynamic scale factor tensors will be used. | ||
| """ |
The docstring for xqa_batch_decode_with_kv_cache_mla is inconsistent with the numpy docstring format used in trtllm_batch_decode_with_kv_cache_mla. It's also missing parameter types and the enable_pdl parameter. For consistency and clarity, it should be updated.
"""
Parameters
----------
query: torch.Tensor
[batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: torch.Tensor
[num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
workspace_buffer: torch.Tensor
torch.Tensor. Must be initialized to 0 for its first use.
qk_nope_head_dim: int
qk_nope_head_dim, must be 128
kv_lora_rank: int
kv_lora_rank, must be 512
qk_rope_head_dim: int
qk_rope_head_dim, must be 64
block_tables: torch.Tensor
page_table of kv cache, [batch_size, num_pages]
seq_lens: torch.Tensor
query_len
max_seq_len: int
max sequence length for kv_cache
out: Optional[torch.Tensor]
output tensor, if not provided, will be allocated internally
bmm1_scale: Union[float, torch.Tensor]
fused scale for mla bmm1 input. Can be a float or a torch.Tensor.
bmm2_scale: Union[float, torch.Tensor]
fused scale for mla bmm2 input. Can be a float or a torch.Tensor.
sinks: Optional[List[torch.Tensor]]
additional value per head in the denominator of the softmax.
enable_pdl: Optional[bool]
Whether to enable Programmatic Dependent Launch (PDL).
Note
----
In MLA, the actual BMM1 and BMM2 scales applied would be fused as:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale
The two scale factors should be static constant for cuda graph capture.
Either (bmm1_scale, bmm2_scale) or (bmm1_scale_log2_tensor, bmm2_scale_tensor) should be provided.
For static constant scale factors, the scale factors should be provided as float.
- (bmm1_scale, bmm2_scale)
For on-device fused scale tensors, which could dynamically change, the scale factors should be provided as torch.Tensor.
- (bmm1_scale_log2_tensor, bmm2_scale_tensor)
- Currently, only fp8 tensor core operation supports this mode.
When both are provided, the dynamic scale factor tensors will be used.
"""There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
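For readers skimming the docstrings above, the fused-scale convention can be made concrete with a short sketch. The per-tensor scale values below are hypothetical placeholders, and folding log2(e) into the tensor variant is an assumption inferred from the bmm1_scale = bmm1_scale * log2e line quoted earlier in this review.

import math
import torch

q_scale = k_scale = v_scale = o_scale = 1.0  # hypothetical per-tensor scales
sm_scale = 1.0                               # hypothetical softmax scale
kv_lora_rank, qk_rope_head_dim = 512, 64
head_dim_qk = kv_lora_rank + qk_rope_head_dim  # 576, per the docstrings

# Static float scales, kept constant for CUDA graph capture:
bmm1_scale = q_scale * k_scale * sm_scale / (head_dim_qk ** 0.5)
bmm2_scale = v_scale * o_scale

# On-device tensor variant (fp8 tensor-core path per the Note); assumption: log2(e) is folded in.
bmm1_scale_log2_tensor = torch.tensor([bmm1_scale * math.log2(math.e)], dtype=torch.float32)
bmm2_scale_tensor = torch.tensor([bmm2_scale], dtype=torch.float32)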
Actionable comments posted: 2
🧹 Nitpick comments (4)
flashinfer/mla.py (4)
87-91: Consider using an underscore for the unpacked but unused variable H. The variable H at line 87 is unpacked but never used. Per the static analysis hint, prefix it with an underscore to indicate it's intentionally unused:
- B_q, Q_len, H, D_q = query.shape
+ B_q, Q_len, _H, D_q = query.shape
The commented-out num_heads check (lines 89-91) with the TODO suggests this might be DeepSeek-specific. Consider either removing the dead code or documenting the decision.
541: Use explicit Optional[bool] instead of an implicit None default. Per PEP 484, the type annotation should explicitly indicate optionality:
- enable_pdl: bool = None,
+ enable_pdl: Optional[bool] = None,
703: Unused parameter max_seq_len should be documented or removed. The max_seq_len parameter is declared but never used in the function body. If this is intentional (e.g., for API consistency with trtllm_batch_decode_with_kv_cache_mla), consider adding a comment. Otherwise, it should be removed to avoid confusion.
708: Use explicit Optional[bool] type annotation.
- enable_pdl: bool = None,
+ enable_pdl: Optional[bool] = None,
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/routines/attention.py (1 hunk)
- docs/api/attention.rst (2 hunks)
- flashinfer/decode.py (1 hunk)
- flashinfer/mla.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
flashinfer/decode.py (2)
flashinfer/mla.py (2)
trtllm_batch_decode_with_kv_cache_mla (526-690), xqa_batch_decode_with_kv_cache_mla (694-804)
flashinfer/xqa.py (4)
xqa (65-112), xqa (148-333), xqa_mla (358-391), xqa_mla (420-530)
benchmarks/routines/attention.py (1)
flashinfer/mla.py (1)
trtllm_batch_decode_with_kv_cache_mla(526-690)
🪛 Ruff (0.14.7)
flashinfer/mla.py
77-77: Avoid specifying long messages outside the exception class
(TRY003)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
81-81: Avoid specifying long messages outside the exception class
(TRY003)
83-83: Avoid specifying long messages outside the exception class
(TRY003)
85-85: Avoid specifying long messages outside the exception class
(TRY003)
87-87: Unpacked variable H is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
93-95: Avoid specifying long messages outside the exception class
(TRY003)
100-102: Avoid specifying long messages outside the exception class
(TRY003)
107-109: Avoid specifying long messages outside the exception class
(TRY003)
111-113: Avoid specifying long messages outside the exception class
(TRY003)
541-541: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
603-605: Avoid specifying long messages outside the exception class
(TRY003)
607-607: Avoid specifying long messages outside the exception class
(TRY003)
609-611: Avoid specifying long messages outside the exception class
(TRY003)
639-639: Avoid specifying long messages outside the exception class
(TRY003)
653-653: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
690-690: Avoid specifying long messages outside the exception class
(TRY003)
703-703: Unused function argument: max_seq_len
(ARG001)
708-708: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
747-749: Avoid specifying long messages outside the exception class
(TRY003)
751-753: Avoid specifying long messages outside the exception class
(TRY003)
755-755: Avoid specifying long messages outside the exception class
(TRY003)
769-769: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
docs/api/attention.rst (1)
50: LGTM! Documentation entries align with the refactored module structure. The new symbols xqa_mla under flashinfer.xqa and trtllm_batch_decode_with_kv_cache_mla under flashinfer.mla are correctly documented and match the module locations in the code changes. Also applies to: 102-105.
benchmarks/routines/attention.py (1)
1839-1852: LGTM! The module path update correctly reflects the refactored MLA function location. The call to flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla is consistent with the PR's refactoring of MLA functions into the dedicated mla.py module. The parameter mappings and shape handling (unsqueeze/squeeze) are appropriate. Consider addressing the TODO comment at line 1844 regarding the hardcoded qk_nope_head_dim=128 to improve code clarity for future maintainers.
flashinfer/decode.py (1)
25-31: LGTM! Backward-compatible re-exports maintain API stability. The explicit aliasing pattern (from .mla import trtllm_batch_decode_with_kv_cache_mla as trtllm_batch_decode_with_kv_cache_mla) correctly re-exports the moved functions, ensuring existing code that imports from flashinfer.decode continues to work without modification. The comment at line 26 effectively documents the rationale.
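A minimal way to sanity-check this compatibility claim is a sketch like the following; it assumes a build of flashinfer containing this PR is importable and is not part of the review itself.

import flashinfer.decode
import flashinfer.mla

# The aliases kept in decode.py should point at the very same objects now defined in mla.py.
assert (
    flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla
    is flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla
)
assert (
    flashinfer.decode.xqa_batch_decode_with_kv_cache_mla
    is flashinfer.mla.xqa_batch_decode_with_kv_cache_mla
)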
if out is None:
    out_shape = query.shape[:-1] + (kv_lora_rank,)
    out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
    batch_size, _, num_q_heads, _ = query.shape
    check_shape_dtype_device(
        out,
        [batch_size, num_q_heads, kv_lora_rank],
        torch.bfloat16,
        query.device,
        "out",
    )
Shape mismatch between allocated output tensor and validation for provided output.
When out is None, the allocated shape is 4D: query.shape[:-1] + (kv_lora_rank,) = (batch_size, q_len, num_heads, kv_lora_rank).
However, when out is provided, it's validated against a 3D shape: [batch_size, num_q_heads, kv_lora_rank].
This inconsistency could cause issues when users provide their own output tensor.
if out is None:
- out_shape = query.shape[:-1] + (kv_lora_rank,)
+ out_shape = (*query.shape[:-1], kv_lora_rank)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
- [batch_size, num_q_heads, kv_lora_rank],
+ [batch_size, 1, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
    )
Alternatively, verify if the 3D validation shape is intentional and the output is squeezed elsewhere, but this should be documented clearly.
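To make the mismatch concrete, here is a shape-only sketch with hypothetical sizes; it does not call the kernel.

import torch

batch_size, q_len, num_heads, kv_lora_rank = 2, 1, 128, 512
head_dim_qk = 576  # kv_lora_rank + qk_rope_head_dim, per the docstrings
query = torch.empty(batch_size, q_len, num_heads, head_dim_qk)

auto_allocated = query.shape[:-1] + (kv_lora_rank,)        # 4D: (2, 1, 128, 512) when out is None
validated_against = (batch_size, num_heads, kv_lora_rank)  # 3D: (2, 128, 512) when out is provided

# A caller who pre-allocates out with the same shape the function would allocate
# internally would therefore fail the current 3D check.
assert tuple(auto_allocated) != validated_against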
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if out is None:
    out_shape = (*query.shape[:-1], kv_lora_rank)
    out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
    batch_size, q_len, num_q_heads, _ = query.shape
    check_shape_dtype_device(
        out,
        [batch_size, q_len, num_q_heads, kv_lora_rank],
        torch.bfloat16,
        query.device,
        "out",
    )
🧰 Tools
🪛 Ruff (0.14.7)
653-653: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
🤖 Prompt for AI Agents
In flashinfer/mla.py around lines 652 to 663, the code allocates out as a 4D
tensor (query.shape[:-1] + (kv_lora_rank,) => [batch_size, q_len, num_q_heads,
kv_lora_rank]) but validates a provided out against a 3D shape [batch_size,
num_q_heads, kv_lora_rank], causing a shape mismatch; update the validation to
expect the same 4D shape (use [batch_size, query.shape[1], num_q_heads,
kv_lora_rank] or equivalent) when out is provided, and keep dtype/device checks
the same, or alternatively change the allocation to 3D if the function truly
expects a 3D output—ensure both allocation and validation use the identical
shape and add a brief comment clarifying the expected out dimensionality.
if out is None:
    out_shape = query.shape[:-1] + (kv_lora_rank,)
    out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
    batch_size, _, num_q_heads, _ = query.shape
    check_shape_dtype_device(
        out,
        [batch_size, num_q_heads, kv_lora_rank],
        torch.bfloat16,
        query.device,
        "out",
    )
Same shape mismatch issue as in trtllm_batch_decode_with_kv_cache_mla.
The allocated output shape is 4D but the validation for provided output checks against 3D. Apply a consistent fix:
if out is None:
- out_shape = query.shape[:-1] + (kv_lora_rank,)
+ out_shape = (*query.shape[:-1], kv_lora_rank)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
- [batch_size, num_q_heads, kv_lora_rank],
+ [batch_size, 1, num_q_heads, kv_lora_rank],
torch.bfloat16,
query.device,
"out",
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if out is None:
    out_shape = (*query.shape[:-1], kv_lora_rank)
    out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
else:
    batch_size, _, num_q_heads, _ = query.shape
    check_shape_dtype_device(
        out,
        [batch_size, 1, num_q_heads, kv_lora_rank],
        torch.bfloat16,
        query.device,
        "out",
    )
🧰 Tools
🪛 Ruff (0.14.7)
769-769: Consider (*query.shape[:-1], kv_lora_rank) instead of concatenation
Replace with (*query.shape[:-1], kv_lora_rank)
(RUF005)
🤖 Prompt for AI Agents
In flashinfer/mla.py around lines 768 to 779, the code allocates out as a 4D
tensor but the provided-output validation expects a 3D tensor, causing a shape
mismatch; fix by allocating out with the same 3D shape the validator expects:
extract batch_size and num_q_heads from query (batch_size, _, num_q_heads, _)
and set out_shape = (batch_size, num_q_heads, kv_lora_rank) when out is None,
preserving dtype and device so the allocated tensor matches the
check_shape_dtype_device call.
/bot run

[FAILED] Pipeline #39508973: 4/20 passed
📌 Description
trtllm_batch_decode_with_kv_cache_mla and xqa_batch_decode_with_kv_cache_mla currently reside in decode.py. This PR moves them to mla.py and makes them show up in the documentation by adding them to attention.rst. Note that adding them to the documentation in the correct place requires this refactor, as the docs generator looks at each module for indexing.
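A quick way to see why the move matters for the docs generator is the following sketch, which assumes a build of flashinfer that includes this PR:

from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla

# After the move, the function is defined in flashinfer.mla, which is what lets the
# docs generator index it there and attention.rst reference it in the right place.
print(trtllm_batch_decode_with_kv_cache_mla.__module__)  # expected: "flashinfer.mla"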
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used my preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed and are passing (unittest, etc.).
Reviewer Notes
Summary by CodeRabbit
New Features
Documentation