
refactor: simplify fp4 rmsnorm#2421

Merged
yzh119 merged 1 commit into flashinfer-ai:main from yzh119:refactor-fp4-norm
Jan 27, 2026

Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Jan 27, 2026

📌 Description

Remove repetitive patterns in the CuTe-DSL-based FP4 RMSNorm code.

More specifically:

  • Use cute.make_rmem_tensor to create a register array instead of explicitly creating one register per element, and use a loop with cutlass.range_constexpr for the elementwise operations.
  • Put common utility functions in fp4_common.py.

Benchmarks show no performance degradation.
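The refactor described above swaps per-element named registers for a single register array plus an unrolled loop. A plain-Python analogue of the pattern (function names here are hypothetical; the real kernels use cute.make_rmem_tensor for the array and cutlass.range_constexpr for the compile-time-unrolled loop):

```python
# "Before" style: one explicitly named register per element (the
# repetition this PR removes).
def scale_8_explicit(x, s):
    r0 = x[0] * s
    r1 = x[1] * s
    r2 = x[2] * s
    r3 = x[3] * s
    r4 = x[4] * s
    r5 = x[5] * s
    r6 = x[6] * s
    r7 = x[7] * s
    return [r0, r1, r2, r3, r4, r5, r6, r7]

# "After" style: a register array plus an unrolled loop, standing in
# for cute.make_rmem_tensor + a cutlass.range_constexpr loop.
def scale_8_loop(x, s):
    r = [0.0] * 8
    for i in range(8):
        r[i] = x[i] * s
    return r
```

Both forms produce identical results; the loop form is what lets the kernel body scale to wider tiles without copy-paste.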

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

cc @bkryu

Summary by CodeRabbit

Refactor

  • Reorganized internal quantization utilities into a shared module to improve code maintainability and reduce duplication. All public APIs remain unchanged and fully compatible.


@coderabbitai
Contributor

coderabbitai bot commented Jan 27, 2026

📝 Walkthrough


Consolidates FP4 quantization utilities and CuTe-DSL intrinsics into a new shared fp4_common.py module, then refactors add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py to import and reuse these utilities instead of maintaining duplicated inline definitions, reducing code duplication while preserving public APIs.

Changes

  • New FP4 utilities module — flashinfer/cute_dsl/fp4_common.py: introduces roughly 40 new public functions: architecture utilities (get_sm_version), PTX intrinsics (set_block_rank, store_shared_remote, elem_pointer), global memory ops (ld_global_v4_u32, st_global_u64), math intrinsics (rcp_approx_ftz, fmin_f32, fmax_f32), half2/bfloat2 SIMD ops, FP8/E4M3/UE8M0 conversions, reduction utilities (warp, block, cluster, row-level), and SF-block processing helpers for quantization workflows.
  • Kernel refactoring to use fp4_common — flashinfer/cute_dsl/add_rmsnorm_fp4quant.py, flashinfer/cute_dsl/rmsnorm_fp4quant.py: both files replace local in-file PTX/DSL definitions and helper implementations with imports from .fp4_common, delegating memory loads, SIMD operations, reductions, and quantization packing to shared utilities. Public APIs (AddRMSNormFP4QuantKernel/RMSNormFP4QuantKernel, add_rmsnorm_fp4quant/rmsnorm_fp4quant, get_sm_version) remain unchanged; internal control flow is restructured around the shared helpers.
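As a rough model of the row/warp-level max-abs reductions listed among those utilities (plain Python standing in for shuffle-based PTX intrinsics; the lane count and offsets here are illustrative, not taken from the actual kernels):

```python
def tree_reduce_max_abs(vals):
    """Model of a warp-style tree reduction: at each step a lane
    combines with the lane `offset` positions away, halving the active
    width until lane 0 holds the row's max-abs value.
    `vals` length must be a power of two."""
    lanes = [abs(v) for v in vals]
    offset = len(lanes) // 2
    while offset >= 1:
        for i in range(offset):
            lanes[i] = max(lanes[i], lanes[i + offset])
        offset //= 2
    return lanes[0]
```

In the actual kernels the per-lane partial maxima travel via warp shuffle intrinsics rather than a Python list, and the block/cluster-level reductions layer shared-memory exchanges on top of this warp-level step.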

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes


Suggested reviewers

  • kaixih
  • aleozlx
  • bkryu
  • jimmyzho


🚥 Pre-merge checks — ✅ 3 passed

  • Title check — ✅ Passed. The title accurately summarizes the main refactoring objective: simplifying the FP4 RMSNorm implementation by extracting common utilities and reducing code repetition.
  • Description check — ✅ Passed. The description is mostly complete, covering the changes made (moving utilities to fp4_common.py, using cute.make_rmem_tensor), including performance benchmark notes, and following the template structure. Minor: the Related Issues section is not filled in, but this is non-critical.
  • Docstring Coverage — ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.


@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves the maintainability and readability of the FP4 RMSNorm implementation by consolidating shared code into a new utility file and abstracting repetitive patterns within the main kernel logic. The refactoring aims to simplify the codebase without affecting the existing performance characteristics.

Highlights

  • Code Centralization: Common utility functions, PTX intrinsics, and reduction logic for FP4 quantization kernels have been moved from add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py into a new shared module, fp4_common.py.
  • Reduced Repetition: The kernel functions in both add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py have been refactored to use new helper functions (e.g., load_8_half2, half2_mul_8, quantize_and_pack_16) instead of explicit, repetitive code for element-wise operations and register array creation.
  • Performance Preservation: Benchmarks confirm that these refactoring changes do not introduce any performance degradation.
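To give a feel for what a quantize_and_pack_16-style helper produces, here is a hedged Python sketch of packing sixteen 4-bit codes into one 64-bit word (the layout is hypothetical; the real helper emits FP4-encoded values via CuTe-DSL intrinsics, and its bit order may differ):

```python
def pack_16_nibbles(codes):
    """Pack sixteen 4-bit codes (each 0..15) into a 64-bit integer,
    with code i occupying bits [4*i, 4*i + 3]."""
    packed = 0
    for i, c in enumerate(codes):
        packed |= (c & 0xF) << (4 * i)
    return packed

def unpack_16_nibbles(packed):
    """Inverse of pack_16_nibbles, for checking the round trip."""
    return [(packed >> (4 * i)) & 0xF for i in range(16)]
```

Packing 16 quantized values per store is what lets the kernel write its FP4 output with a single wide global-memory transaction instead of 16 narrow ones.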





@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request refactors the FP4 RMSNorm implementation by extracting common utility functions and repetitive code patterns into a new fp4_common.py module. The changes simplify add_rmsnorm_fp4quant.py and rmsnorm_fp4quant.py, making them more readable and maintainable. Helper functions such as load_8_half2, compute_y_and_max_abs_f32, and quantize_and_pack_16 reduce code duplication and improve modularity, in line with the stated objective of simplifying the code. Replacing explicit register assignments and manual element-wise operations with these helpers is a clear improvement, and benchmarks confirm no performance degradation, which is crucial for such low-level kernels.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/cute_dsl/add_rmsnorm_fp4quant.py (1)

1134-1141: The docstring promises in-place modification of residual, but this is not guaranteed for 3D inputs.

When input.dim() == 3 and the residual tensor is not contiguous in row-major layout, residual.view(B * S, H).contiguous() creates a copy. The kernel modifies this copy, not the original tensor. Since the function returns only (y_fp4, block_scale) and not the residual, the caller has no way to access the modified value.

To fix:

  1. Update the docstring to clarify that in-place modification only works for 2D inputs or pre-contiguous 3D inputs
  2. Or, reshape without calling .contiguous() (e.g., residual.reshape(B * S, H) when possible), then handle contiguity at the kernel call site
  3. Or, for 3D inputs, copy the result back: residual.copy_(residual_2d.view(B, S, H)) after kernel execution
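The view-versus-copy behavior this comment describes can be modeled in plain Python. Here maybe_contiguous is a hypothetical stand-in for residual.view(B * S, H).contiguous(), which returns the same storage when the view is already row-major contiguous and a fresh copy otherwise:

```python
def maybe_contiguous(t, is_contiguous):
    """Hypothetical stand-in for `view(...).contiguous()`: returns the
    same object when already contiguous, otherwise a fresh copy."""
    return t if is_contiguous else list(t)

residual = [1.0, 2.0, 3.0]

# Contiguous case: the "kernel" writes through to the caller's tensor.
r2d = maybe_contiguous(residual, is_contiguous=True)
r2d[0] = 9.0
assert residual[0] == 9.0

# Non-contiguous 3D case: writes land on the copy and are silently
# lost, which is the bug this review comment describes.
residual = [1.0, 2.0, 3.0]
r2d = maybe_contiguous(residual, is_contiguous=False)
r2d[0] = 9.0
assert residual[0] == 1.0
```

This is why option 3 above (copying the kernel's result back into the original tensor) restores the documented in-place semantics for non-contiguous inputs.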

@yzh119
Collaborator Author

yzh119 commented Jan 27, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !268 has been created, and the CI pipeline #42635293 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unit tests are all passing for relevant Blackwell GPUs on SM100/103/120

Thanks @yzh119, this was a much needed cleanup of the initial version of the kernel

@yzh119 yzh119 merged commit 67fc0a1 into flashinfer-ai:main Jan 27, 2026
24 checks passed