[API change] Allow using torch.Tensor for scales for trtllm-gen attention#2084
Conversation
Signed-off-by: Siyuan Fu <[email protected]>
Walkthrough

Adds tensor-or-scalar support for attention scaling across the Python APIs and C++ FMHA launchers using `tvm::ffi::Variant`; accepts device-resident scale tensors (passed as float pointers) or host scalars, applies on-device log2e scaling for tensor inputs, and updates the FMHA cubin artifact path and checksum constants.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Py as Python API
    participant Bind as FFI Binder (C++)
    participant Launcher as trtllm_*_launcher
    participant Runner as FMHA Runner
    Py->>Bind: call API with bmm1_scale, bmm2_scale (float or Tensor)
    alt Tensor inputs
        note right of Bind: assert dtype float32<br/>compute tensor * log2e on device
        Bind->>Launcher: Variant(tensor) + provide bmm1_scale_log2_ptr & bmm2_scale_ptr
    else Scalar inputs
        note right of Bind: keep/convert as double
        Bind->>Launcher: Variant(double) + pass nullptr for scale pointers
    end
    Launcher->>Runner: set runner_params.scaleSoftmaxLog2Ptr & outputScalePtr
    Runner->>Runner: execute FMHA using pointer scales if present else scalar scales
```
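The tensor-vs-scalar dispatch in the diagram can be sketched in Python. This is an illustrative reconstruction, not the actual binding code; the helper name `resolve_bmm1` and its return convention are invented for this sketch.

```python
import math

import torch

log2e = math.log2(math.e)


def resolve_bmm1(bmm1_scale):
    """Illustrative sketch of the tensor-vs-scalar dispatch above.

    Returns (device_tensor_or_None, host_scalar_or_None); exactly one side
    is populated, mirroring the Variant(tensor)/Variant(double) split.
    """
    if isinstance(bmm1_scale, torch.Tensor):
        # Tensor path: must be float32; apply log2e out of place.
        assert bmm1_scale.dtype == torch.float32
        return bmm1_scale * log2e, None
    # Scalar path: keep as a host double; the C++ side would receive
    # nullptr for the pointer-backed scale.
    return None, float(bmm1_scale)
```

For a scalar input the pointer side stays empty; for a float32 tensor the host-scalar side stays empty and the tensor is returned log2e-scaled.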
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks: ❌ failed checks (1 warning), ✅ passed checks (2 passed).
Signed-off-by: Siyuan Fu <[email protected]>
/bot run

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

Signed-off-by: Siyuan Fu <[email protected]>

[CANCELING] Pipeline #38436074: canceled
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/decode.py (1)
1883-1901: In-place scale multiply causes drift

Same issue here: `bmm1_scale *= log2e` updates the caller's tensor. If the caller caches that buffer (common in decode loops), the scaling compounds every step. Please switch to a non-in-place multiply or clone first. (docs.pytorch.org)
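The drift is easy to demonstrate; the loop below stands in for repeated decode steps reusing a cached scale buffer (the variable names are illustrative):

```python
import math

import torch

log2e = math.log2(math.e)

# Caller's cached scale buffer, reused across decode steps.
cached_scale = torch.tensor([0.125], dtype=torch.float32)

# In-place multiply inside the API compounds on every call:
for _ in range(3):
    cached_scale *= log2e
# cached_scale now holds 0.125 * log2e**3 instead of the intended 0.125.

# Out-of-place multiply leaves the caller's buffer untouched:
fresh = torch.tensor([0.125], dtype=torch.float32)
for _ in range(3):
    scaled = fresh * log2e  # new tensor each call; `fresh` stays 0.125
```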
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- csrc/trtllm_fmha_kernel_launcher.cu (13 hunks)
- flashinfer/artifacts.py (2 hunks)
- flashinfer/decode.py (13 hunks)
- flashinfer/prefill.py (9 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (0 hunks)
💤 Files with no reviewable changes (1)
- include/flashinfer/trtllm/fmha/kernelParams.h
```cpp
  auto maybe_bmm2_scale_value = bmm2_scale.as<double>();
  auto maybe_bmm1_scale_log2_tensor = bmm1_scale.as<ffi::Tensor>();
  auto maybe_bmm2_scale_tensor = bmm2_scale.as<ffi::Tensor>();
  TVM_FFI_CHECK(maybe_bmm1_scale_value.has_value() || maybe_bmm1_scale_log2_tensor.has_value(),
                "bmm1_scale must be either a double or a tensor");
  TVM_FFI_CHECK(maybe_bmm2_scale_value.has_value() || maybe_bmm2_scale_tensor.has_value(),
                "bmm2_scale must be either a double or a tensor");
  double bmm1_scale_value =
      maybe_bmm1_scale_value.has_value() ? maybe_bmm1_scale_value.value() : 1.0;
  double bmm2_scale_value =
      maybe_bmm2_scale_value.has_value() ? maybe_bmm2_scale_value.value() : 1.0;
  float* bmm1_scale_log2_ptr =
      maybe_bmm1_scale_log2_tensor.has_value()
          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
          : nullptr;
  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
                              : nullptr;
```
Guard tensor-based scales with dtype checks
When bmm*_scale comes in as a tensor, we immediately reinterpret the storage as float*. Callers can legally hand us torch.Float16/torch.BFloat16 today, so this reinterpret cast will read garbage and corrupt the softmax/output scales. Please gate the tensor branch with a dtype == dl_float32 check (and emit a clear error otherwise) before taking the pointer, and apply the same fix in the context and ragged code paths.
```diff
@@
-  float* bmm1_scale_log2_ptr =
-      maybe_bmm1_scale_log2_tensor.has_value()
-          ? static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr())
-          : nullptr;
-  float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
-                              ? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
-                              : nullptr;
+  float* bmm1_scale_log2_ptr = nullptr;
+  if (maybe_bmm1_scale_log2_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm1_scale_log2_tensor.value().dtype(), dl_float32)
+        << "bmm1_scale tensor must be float32";
+    bmm1_scale_log2_ptr =
+        static_cast<float*>(maybe_bmm1_scale_log2_tensor.value().data_ptr());
+  }
+  float* bmm2_scale_ptr = nullptr;
+  if (maybe_bmm2_scale_tensor.has_value()) {
+    TVM_FFI_ICHECK_EQ(maybe_bmm2_scale_tensor.value().dtype(), dl_float32)
+        << "bmm2_scale tensor must be float32";
+    bmm2_scale_ptr =
+        static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr());
+  }
```

Please mirror this guard in trtllm_paged_attention_context and trtllm_ragged_attention.
Also applies to: 338-356, 503-521
🤖 Prompt for AI Agents
csrc/trtllm_fmha_kernel_launcher.cu lines 260-277: when bmm1_scale or bmm2_scale
is a tensor the code currently reinterpret_casts data_ptr() to float* without
checking dtype which will misread half/bfloat tensors; modify the tensor branch
to first check the tensor dtype is float32 (dl_float32) and TVM_FFI_CHECK/throw
a clear error if not, then take the data_ptr() as float*; apply the identical
dtype-guard and error message to the similar blocks at lines 338-356 and 503-521
and also mirror these dtype guards in the corresponding
trtllm_paged_attention_context and trtllm_ragged_attention code paths.
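Why the unguarded reinterpret cast reads garbage can be shown on the host with NumPy; the values here are arbitrary examples, not taken from the kernel:

```python
import numpy as np

# Two float16 scales occupy the same 4 bytes as a single float32.
scales_fp16 = np.array([1.5, 2.0], dtype=np.float16)

# Reinterpreting the raw bytes as float32 (what the unguarded
# static_cast<float*> effectively does) yields an unrelated value.
reinterpreted = scales_fp16.view(np.float32)

# reinterpreted[0] is nowhere near 1.5, so the softmax/output scale
# would be silently corrupted without the dtype check.
```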
```python
        if isinstance(bmm1_scale, torch.Tensor):
            assert bmm1_scale.dtype == torch.float32
            bmm1_scale *= log2e
        if isinstance(bmm2_scale, torch.Tensor):
            assert bmm2_scale.dtype == torch.float32
```
Don’t mutate caller tensors when applying log2e
Same issue here: `bmm1_scale *= log2e` alters the input tensor in place, so repeated invocations accumulate the scaling and yield incorrect scales. Use an out-of-place multiply before launching the kernel.
```diff
-        if isinstance(bmm1_scale, torch.Tensor):
-            assert bmm1_scale.dtype == torch.float32
-            bmm1_scale *= log2e
+        if isinstance(bmm1_scale, torch.Tensor):
+            assert bmm1_scale.dtype == torch.float32
+            bmm1_scale = bmm1_scale * log2e
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
        if isinstance(bmm1_scale, torch.Tensor):
            assert bmm1_scale.dtype == torch.float32
            bmm1_scale = bmm1_scale * log2e
        if isinstance(bmm2_scale, torch.Tensor):
            assert bmm2_scale.dtype == torch.float32
```
🤖 Prompt for AI Agents
In flashinfer/decode.py around lines 2296 to 2301, the code currently does an
in-place scale (bmm1_scale *= log2e) which mutates the caller's tensor; change
this to an out-of-place multiplication and reassign the result to bmm1_scale
(for example use torch.mul or the * operator) so a new tensor is produced on the
same dtype/device and the original caller tensor is not modified; ensure the
result remains float32 and mirrored onto the correct device; also review nearby
bmm2_scale handling and apply the same out-of-place pattern if it will be scaled
later.
/bot run

Signed-off-by: Siyuan Fu <[email protected]>

[CANCELING] Pipeline #38436713: canceled

/bot run

@IwakuraRein is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

/bot run

Signed-off-by: Siyuan Fu <[email protected]>
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
490-499: Scale conversion logic is correct.

The conditional conversion between tensor and scalar forms based on `device_scale` correctly handles all cases:

- Tensor → scalar when `device_scale=False`
- Scalar → tensor when `device_scale=True`
- Preserves the existing type when already in the desired form

This ensures comprehensive test coverage across both FP8 (where scales may be tensors) and FP16/BF16 (where scales are scalars) data types.

Consider extracting this conversion logic to a helper function to reduce duplication with the identical logic at lines 798-807:

```python
def convert_scale_for_test(scale, device_scale, device):
    """Convert scale between tensor and scalar based on test parameter."""
    if isinstance(scale, torch.Tensor) and not device_scale:
        return scale.item()
    elif not isinstance(scale, torch.Tensor) and device_scale:
        return torch.tensor(scale, device=device, dtype=torch.float32)
    return scale
```

Then use:

```python
bmm1_scale = convert_scale_for_test(q_scale * k_scale * sm_scale, device_scale, GPU_DEVICE)
bmm2_scale = convert_scale_for_test(v_scale / o_scale, device_scale, GPU_DEVICE)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/attention/test_trtllm_gen_attention.py (13 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/decode.py (1)
trtllm_batch_decode_with_kv_cache (2061-2335)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
tests/attention/test_trtllm_gen_attention.py (3)
374-374: Excellent test coverage expansion.

Adding `device_scale` as a test parameter ensures both device-resident tensor scales and host scalar scales are validated across all test configurations. This comprehensive approach aligns well with the PR's goal of supporting `Union[float, torch.Tensor]` for scale parameters.
1128-1128: Forward-looking test design.

Setting `device_scale=True` in this expected-to-fail test ensures that when the head_dim=256 issue is resolved, the test will immediately validate tensor scale support for that configuration. This is good defensive test engineering.
969-969: LGTM: Consistent parameter threading.

The `device_scale` parameter is correctly added to the test matrix and properly propagated through the function call hierarchy. The default value of `False` in `_test_trtllm_batch_decode` (line 650) provides a sensible default for focused test cases.

Also applies to: 650-650, 986-986, 1005-1005
/bot run

[FAILED] Pipeline #38539973: 6/17 passed
Signed-off-by: Siyuan Fu <[email protected]>
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/attention/test_trtllm_gen_attention.py (1)
830-839: Same conversion logic as the prefill path (see comment on lines 457-466).

This conversion logic is identical to the prefill path. Consider using the helper function suggested in the previous comment to eliminate this duplication.
🧹 Nitpick comments (2)
tests/attention/test_trtllm_gen_attention.py (2)
457-466: Consider extracting the scale conversion logic into a helper function.

The conversion logic between tensor and scalar forms for `bmm1_scale` and `bmm2_scale` is duplicated in both the prefill path (lines 457-466) and the decode path (lines 830-839). Extracting this into a helper function would reduce code duplication and improve maintainability.

Example helper function:

```python
def convert_scale_form(scale, device_scale: bool, device):
    """Convert scale between tensor and scalar forms based on device_scale flag."""
    if isinstance(scale, torch.Tensor) and not device_scale:
        return scale.item()
    elif not isinstance(scale, torch.Tensor) and device_scale:
        return torch.tensor(scale, device=device, dtype=torch.float32)
    return scale
```

Then use it as:

```python
bmm1_scale = convert_scale_form(q_scale * k_scale * sm_scale, device_scale, GPU_DEVICE)
bmm2_scale = convert_scale_form(v_scale / o_scale, device_scale, GPU_DEVICE)
```
611-612: Consider explicitly parametrizing device_scale for more comprehensive test coverage.

The test passes `kv_dtype == "fp8"` as the `device_scale` argument, which means:

- fp8 tests always use device-side scales
- non-fp8 tests always use host-side scales

While this aligns with the natural scale types (fp8 quantization produces tensor scales), it limits test coverage. The specialized tests (`test_trtllm_batch_decode_bs1`, `test_trtllm_batch_decode_head_dim_256`, `test_trtllm_batch_decode_long_sequence_length`) parametrize `device_scale` explicitly with `[True, False]`, providing better coverage of both code paths across all dtypes.

For consistency and completeness, consider whether this test should also parametrize `device_scale` explicitly, or document why the current approach is intentional.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tests/attention/test_trtllm_gen_attention.py (19 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_attention.py (1)
flashinfer/decode.py (1)
trtllm_batch_decode_with_kv_cache (2061-2335)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (2)
tests/attention/test_trtllm_gen_attention.py (2)
1035-1035: Note: xqa backend converts device scales to host scalars internally.

When `backend="xqa"` and `kv_dtype="fp8"`, this test creates device-side tensor scales (`device_scale=True`), but the xqa backend immediately converts them back to host scalars (see flashinfer/decode.py lines 2158-2162). While this doesn't cause incorrect behavior, it does result in unnecessary tensor creation and conversion.

This is expected behavior since device-side scale support for xqa was removed in PR #2033 (as noted in the PR description's TODO). The test still validates that the conversion works correctly.
1057-1057: Good test coverage with explicit device_scale parametrization.

These specialized tests explicitly parametrize `device_scale` with `[True, False]`, providing comprehensive coverage of both device-side and host-side scale paths across different scenarios (bs1, head_dim=256, and long sequences). This is more thorough than the general test approach and ensures both code paths work correctly for all dtype combinations.

Since these tests use the trtllm-gen backend exclusively (which supports device scales), they avoid the unnecessary tensor-to-scalar conversions that would occur with the xqa backend.
Also applies to: 1126-1126, 1188-1188
/bot run

[SUCCESS] Pipeline #38646833: 10/18 passed
…tion (flashinfer-ai#2084)

- change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`; notice that when using a tensor, log2e must be applied
- remove the `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` in the `xqa_batch_decode_with_kv_cache_mla`
- update trtllm-gen FMHA kernels

TODO: do the same refactor for the xqa kernels. Support for device-side scales was removed in flashinfer-ai#2033

Summary by CodeRabbit:

- New Features: attention scale parameters now accept either floats or 1-element tensors across prefill, decode and runtime; tensor scales are validated and applied on-device, and pointer-backed scale paths are supported.
- Chores: updated FMHA artifact path and checksum constants; added a public utility import and removed an obsolete inline comment.
- Tests: updated tests to exercise device/tensor-or-scalar scale flows, removed legacy per-tensor call-site args, and added device-scale parametrization for several test variants.

Signed-off-by: Siyuan Fu <[email protected]>
📌 Description
- Change `bmm1_scale` and `bmm2_scale` to `Union[float, torch.Tensor]`. Note that when using a tensor, log2e must be applied.
- Remove `bmm1_scale_log2_tensor` and `bmm2_scale_tensor` in `xqa_batch_decode_with_kv_cache_mla`.

TODO: do the same refactor for the xqa kernels. Support for device-side scales was removed in #2033.
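A minimal sketch of how a caller might construct either accepted form of the scales. The helper name `make_scales` is invented for illustration; the per-tensor factors follow the combinations used in the tests (`q_scale * k_scale * sm_scale` and `v_scale / o_scale`):

```python
import torch


def make_scales(q_scale, k_scale, sm_scale, v_scale, o_scale, device=None):
    """Hypothetical helper: build bmm1/bmm2 scales as host floats, or as
    device-resident float32 tensors when a device is given."""
    bmm1 = q_scale * k_scale * sm_scale
    bmm2 = v_scale / o_scale
    if device is None:
        return bmm1, bmm2  # host-scalar path
    # Device-tensor path: must be float32 per the FFI-side dtype guard.
    return (
        torch.tensor(bmm1, dtype=torch.float32, device=device),
        torch.tensor(bmm2, dtype=torch.float32, device=device),
    )
```

Either return form can then be passed as `bmm1_scale`/`bmm2_scale` to the attention APIs.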
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes