
[TRTLLM-9831][perf] Use TMA.RED to improve effective memory bandwidth #10987

Merged

syuoni merged 2 commits into NVIDIA:main from sherry-1001:main on Jan 27, 2026

Conversation

@sherry-1001 (Collaborator) commented on Jan 26, 2026

Summary by CodeRabbit

Release Notes

  • New Features

    • Added block reduction optimization path for GEMM kernel operations, enabling improved performance for certain computation scenarios.
  • Refactor

    • Enhanced kernel implementation with block reduction support across multiple data types (BF16, FP32, FP16), including updated memory layout handling and epilogue processing.


Description

The performance of REDG.128bit is highly sensitive to the data distribution, more precisely to the address distribution and contention pattern. In contrast, UBLKRED/BLKRED consolidate many fine-grained REDG operations into regular, coarse-grained bulk accesses, making performance more stable with much smaller variance.
UBLKRED can yield up to a 50% performance improvement in some cases, so we enable UBLKRED by default for the finalize fusion kernel.
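
At the PTX level, REDG corresponds roughly to per-address red.global operations, while UBLKRED/BLKRED correspond to the bulk cp.reduce.async.bulk path. Below is a hedged CUDA sketch of the contrast (illustrative only; the actual kernel is written in CuteDSL, and both function names here are hypothetical):

```cuda
// Hedged sketch (requires sm_90+): fine-grained REDG vs. bulk block reduce.
// Both function names are illustrative; only the PTX instructions are real.

// Fine-grained path: one red.global per address. Throughput depends heavily
// on how destination addresses are distributed and how much they contend.
__device__ void redg_add_f32(float* gmem_dst, float val) {
    asm volatile("red.global.add.f32 [%0], %1;"
                 :: "l"(gmem_dst), "f"(val) : "memory");
}

// Bulk path: one cp.reduce.async.bulk reduces `bytes` of shared memory into
// global memory as a single coarse-grained, TMA-style transaction.
// Both addresses must be 16-byte aligned and `bytes` a multiple of 16.
__device__ void blk_reduce_add_f32(float* gmem_dst, const float* smem_src,
                                   unsigned bytes) {
    unsigned smem_addr =
        static_cast<unsigned>(__cvta_generic_to_shared(smem_src));
    asm volatile(
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 "
        "[%0], [%1], %2;"
        :: "l"(gmem_dst), "r"(smem_addr), "r"(bytes) : "memory");
    asm volatile("cp.async.bulk.commit_group;");
    // Wait for the bulk group to finish reading smem before reusing it.
    asm volatile("cp.async.bulk.wait_group.read 0;");
}
```

The bulk form turns many scattered read-modify-write transactions into one aligned, contiguous transfer, which is why its performance varies far less with address distribution.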

Test Coverage

TensorRT-LLM/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
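
For example, a pre-merge run restricted to one GPU type with fail-fast disabled, using only flags documented above:

/bot run --gpu-type "H100_PCIe" --disable-fail-fast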

kill

Kill all running builds associated with the pull request.

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Signed-off-by: zhichen jiang <zhichenj@NVIDIA.com>
@coderabbitai bot (Contributor) commented on Jan 26, 2026

📝 Walkthrough

A block-reduction (blkred) code path is introduced to a GEMM kernel, controlled by a new use_blkred parameter. Changes include extended kernel signatures, epilogue logic for block-reduce operations, memory layout adjustments, and new block-reduce utility functions (bf16, fp16, fp32 variants).

Changes

  • Kernel Custom Ops Integration (tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py): Enables use_blkred=True in the GEMM kernel invocation within the finalize fusion path.

  • Block-Reduction Kernel Implementation (tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py): Introduces an optional block-reduction path controlled by the use_blkred parameter. Extends the kernel __init__, _compute_stages, and _setup_attributes signatures. Adds epilogue shared-memory copy/partition logic (epilog_smem_copy_and_partition), conditional epilogue accumulation routing, block-reduce utility invocations, and updated memory layout handling for block reduce (c_smem_layout_staged). Modifies SharedStorage to conditionally include the shared C buffer.

  • Block-Reduction Utility Operators (tensorrt_llm/_torch/cute_dsl_kernels/blackwell/utils.py): Introduces three new DSL user-op blocks: blk_reduce_bf16, blk_reduce_fp32, blk_reduce_fp16. Each wraps inline CUDA assembly with CP-reduce intrinsics for vectorized reduction across global and shared memory.
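
As a rough CUDA illustration of the CP-reduce intrinsic these user ops wrap (a hedged sketch: the function name and signature are hypothetical, and the real ops emit inline assembly through the DSL rather than a __device__ function), the bf16 variant could look like:

```cuda
#include <cuda_bf16.h>

// Hedged sketch of what a bf16 CP-reduce wrapper can map to in raw CUDA
// (sm_90+). The actual blk_reduce_bf16 in utils.py is a CuteDSL user-op;
// this function name and signature are hypothetical.
__device__ void blk_reduce_bf16_sketch(__nv_bfloat16* gmem_dst,
                                       const __nv_bfloat16* smem_src,
                                       unsigned bytes) {
    unsigned smem_addr =
        static_cast<unsigned>(__cvta_generic_to_shared(smem_src));
    // .noftz keeps subnormal bf16 inputs un-flushed, as required for
    // .add on bf16/f16 types.
    asm volatile(
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.noftz.bf16 "
        "[%0], [%1], %2;"
        :: "l"(gmem_dst), "r"(smem_addr), "r"(bytes) : "memory");
}
```

The fp16 and fp32 variants would differ only in the .type (and .noftz) qualifiers of the instruction.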

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 61.54%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title clearly summarizes the main change: using TMA.RED (UBLKRED/BLKRED) to improve effective memory bandwidth in the finalize fusion kernel.
  • Description check ✅ Passed: The description includes the key sections: the problem (REDG.128bit sensitivity), the solution (UBLKRED/BLKRED), and test coverage. However, it lacks a detailed explanation of what was changed and why this specific approach was chosen.


@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

In `tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around lines 538-547: the C shared-memory sizing computed in _compute_stages is inconsistent with c_smem_layout_staged. Instead of using swizzled_pad = 16 // (out_dtype.width // 8), which underestimates the byte count for FP32, compute the staged C stride with the same formula as c_smem_layout_staged (i.e., cta_tile_shape_mnk[1] + 8) and derive the C byte count from that stride and out_dtype.width. Update _compute_stages and the other affected locations (the blocks around where c_smem_layout_staged is used) so that the A/B stage counts and SMEM capacity calculations match the layout and avoid underestimation for FP32.
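
A minimal sketch of the suggested accounting, assuming the +8-element pad that c_smem_layout_staged uses (the helper name is hypothetical, and C is used here only to mirror the Python arithmetic):

```cuda
#include <cstddef>

// Hedged sketch of the suggested sizing fix: derive the staged C byte count
// from the same padded stride that c_smem_layout_staged uses.
// cta_tile_n stands for cta_tile_shape_mnk[1]; the helper name is made up.
size_t c_stage_bytes(size_t rows, size_t cta_tile_n,
                     size_t out_dtype_width_bits) {
    size_t stride_elems = cta_tile_n + 8;              // same pad as the layout
    size_t bytes_per_elem = out_dtype_width_bits / 8;  // 2 for BF16/FP16, 4 for FP32
    // The old pad of 16 / bytes_per_elem elements shrinks to 4 elements for
    // FP32, underestimating the stage size versus the fixed +8-element pad.
    return rows * stride_elems * bytes_per_elem;
}
```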

@kaiyux (Member) commented on Jan 26, 2026

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #33528 [ run ] triggered by Bot. Commit: 18de6bf

@tensorrt-cicd (Collaborator) commented:

PR_Github #33528 [ run ] completed with state SUCCESS. Commit: 18de6bf
/LLM/main/L0_MergeRequest_PR pipeline #25864 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

@syuoni (Collaborator) commented on Jan 26, 2026

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:

PR_Github #33612 [ run ] triggered by Bot. Commit: 18de6bf

@tensorrt-cicd (Collaborator) commented:

PR_Github #33612 [ run ] completed with state SUCCESS. Commit: 18de6bf
/LLM/main/L0_MergeRequest_PR pipeline #25930 completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Signed-off-by: zhichen jiang <zhichenj@NVIDIA.com>
@sherry-1001 (Collaborator, Author) commented:

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:

PR_Github #33673 [ run ] triggered by Bot. Commit: 731d020

@tensorrt-cicd (Collaborator) commented:

PR_Github #33673 [ run ] completed with state SUCCESS. Commit: 731d020
/LLM/main/L0_MergeRequest_PR pipeline #25977 completed with status: 'SUCCESS'

@syuoni merged commit fae4985 into NVIDIA:main on Jan 27, 2026
5 checks passed
yzh119 pushed a commit to flashinfer-ai/flashinfer that referenced this pull request Feb 16, 2026

## 📌 Description

This PR is a follow-up to PR #2398, integrating [TRTLLM PR 10987](NVIDIA/TensorRT-LLM#10987): use TMA.RED to improve effective memory bandwidth.

Perf data (tested on GB200):

| Tokens | CuteDSL (main) ms | CuteDSL (TMA.RED) ms | TRTLLM gen ms | CUTLASS ms | Winner | CuteDSL Speedup (main/TMA.RED) |
| -- | -- | -- | -- | -- | -- | -- |
| 1 | 0.064 | 0.064 | 0.053 | 0.099 | TRTLLM | 1.000x |
| 2 | 0.077 | 0.077 | 0.063 | 0.107 | TRTLLM | 1.000x |
| 4 | 0.096 | 0.096 | 0.085 | 0.131 | TRTLLM | 1.000x |
| 8 | 0.096 | 0.096 | 0.091 | 0.131 | TRTLLM | 1.000x |
| 16 | 0.101 | 0.102 | 0.103 | 0.138 | CuteDSL | 0.990x |
| 32 | 0.114 | 0.114 | 0.142 | 0.152 | CuteDSL | 1.000x |
| 62 | 0.122 | 0.122 | 0.183 | 0.163 | CuteDSL | 1.000x |
| 128 | 0.133 | 0.132 | 0.173 | 0.220 | CuteDSL | 1.008x |
| 256 | 0.142 | 0.138 | 0.220 | 0.251 | CuteDSL | 1.029x |
| 512 | 0.190 | 0.183 | 0.271 | 0.333 | CuteDSL | 1.038x |
| 1024 | 0.286 | 0.278 | 0.576 | 0.482 | CuteDSL | 1.029x |
| 2048 | 0.472 | 0.461 | 0.555 | 0.723 | CuteDSL | 1.024x |
| 4096 | 0.855 | 0.824 | 0.873 | 1.278 | CuteDSL | 1.038x |
| 8192 | 1.764 | 1.713 | 1.653 | 2.383 | TRTLLM | 1.030x |




## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).




## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Introduced block-reduction optimization in MOE finalization kernels for improved performance on latest hardware.
  * Added support for block-wise reduction operations across multiple data types (BF16, FP32, FP16).

* **Performance**
  * Optimized GPU memory utilization by reducing unnecessary cross-device data transfers during computation.

