tests: add bias testing to nvfp4 moe#2585
Conversation
remove sparse
**Summary of Changes** (Gemini Code Assist)

This pull request enhances the NVFP4 Mixture of Experts (MoE) implementation by adding comprehensive bias support. The changes modify the argument structures to accommodate bias terms for both GEMM operations within the MoE layer, and update the dequantization and reference computation logic to correctly apply these biases. The primary impact is the introduction of new test cases that verify the accuracy of the NVFP4 MoE kernel when biases are present, ensuring the robustness and correctness of the MoE functionality under these conditions.
📝 Walkthrough

Adds runtime-configurable GEMM bias support to the FP4 MoE tests by introducing `gemm1_bias` and `gemm2_bias` arguments through the reference and kernel invocation paths.
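As a sketch of what "GEMM bias support" means in the MoE reference path: each expert's GEMM output gets that expert's bias row added elementwise. A minimal pure-Python illustration (names and shapes are illustrative; the real tests use torch tensors, e.g. `gemm2_bias` of shape `[num_experts, hidden_size]`):

```python
def gemm_with_bias(x, w, bias):
    # x: [tokens][k], w: [n][k] (one weight row per output feature), bias: [n]
    out = []
    for row in x:
        out.append([
            sum(a * b for a, b in zip(row, w_row)) + bias[j]
            for j, w_row in enumerate(w)
        ])
    return out

# Two experts, each with its own weight matrix and its own bias row.
experts_w = [
    [[1.0, 0.0], [0.0, 1.0]],   # expert 0: identity
    [[2.0, 0.0], [0.0, 2.0]],   # expert 1: scale by 2
]
experts_bias = [
    [10.0, 20.0],               # expert 0 bias row
    [-1.0, -2.0],               # expert 1 bias row
]

token = [[3.0, 4.0]]
print(gemm_with_bias(token, experts_w[0], experts_bias[0]))  # [[13.0, 24.0]]
print(gemm_with_bias(token, experts_w[1], experts_bias[1]))  # [[5.0, 6.0]]
```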
Code Review
This pull request adds test coverage for bias support in the nvfp4 MoE implementation. The changes are well-contained within the test file and correctly add gemm1_bias and gemm2_bias to the reference implementations and new tests. The new tests for gemm1_bias, gemm2_bias, and both biases together are comprehensive. I have one suggestion to improve the maintainability of the new test code by refactoring duplicated logic.
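On the duplication point: a common way to collapse near-identical bias tests is to parametrize a single shared helper over which biases are enabled. A hedged sketch of the idea (the helper and flag names here are hypothetical stand-ins, not the PR's actual structure; with pytest this would be a `@pytest.mark.parametrize` over the same cases):

```python
def run_bias_case(use_gemm1_bias: bool, use_gemm2_bias: bool) -> str:
    # Hypothetical shared helper standing in for _run_fp4_moe_with_bias:
    # builds each bias only when enabled, then would compare kernel vs reference.
    gemm1_bias = [0.1] if use_gemm1_bias else None
    gemm2_bias = [0.2] if use_gemm2_bias else None
    return f"gemm1={gemm1_bias is not None}, gemm2={gemm2_bias is not None}"

# One parametrized loop replaces three near-identical test bodies
# (gemm1 only, gemm2 only, both biases together).
cases = [(True, False), (False, True), (True, True)]
for g1, g2 in cases:
    print(run_bias_case(g1, g2))
```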
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
**3195-3207**: Set `torch.random.manual_seed(0)` before creating bias tensors for full reproducibility.

In all three test functions, `torch.random.manual_seed(0)` is called after the bias tensors are created via `torch.randn(...)`. This means bias values depend on whatever RNG state was left by the prior parametrized test case. While the reference-vs-kernel comparison is still valid (both see the same bias), this makes individual test failures harder to reproduce in isolation.

Proposed fix (example for `test_nvfp4_moe_gemm2_bias`; apply analogously to the other two):

```diff
 num_experts, top_k = 8, 2
 device = "cuda"
+torch.random.manual_seed(0)

 # gemm2_bias shape: [num_experts, hidden_size], dtype float32
 gemm2_bias = torch.randn(
     (num_experts, hidden_size), device=device, dtype=torch.float32
 )
-torch.random.manual_seed(0)

 kernel_output, ref_output = _run_fp4_moe_with_bias(
```

Also applies to: 3234-3246, 3271-3287
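The reproducibility point generalizes beyond torch: with any global RNG, values drawn before the seed call depend on leftover state from earlier tests. A minimal sketch using Python's stdlib `random` module as a stand-in for torch's global RNG (the function names are illustrative, not from the test file):

```python
import random

def make_bias_then_seed():
    # Bias drawn BEFORE seeding: depends on whatever state the RNG is in.
    bias = random.random()
    random.seed(0)
    return bias

def seed_then_make_bias():
    # Seeding FIRST pins the bias to a reproducible value.
    random.seed(0)
    return random.random()

# Simulate two test runs preceded by different prior RNG activity.
random.seed(123); _ = random.random()
a1 = make_bias_then_seed()
random.seed(456); _ = random.random()
a2 = make_bias_then_seed()

random.seed(123); _ = random.random()
b1 = seed_then_make_bias()
random.seed(456); _ = random.random()
b2 = seed_then_make_bias()

print(a1 == a2)  # bias varies with prior state
print(b1 == b2)  # True: seeding first makes the bias reproducible
```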
**3138**: Inconsistent `weight_processing` dict key: `"shuffle"` instead of `"use_shuffled_weight"`.

All other call sites (e.g., `run_moe_test` at line 2545, `FP8BlockScaleMoe.prepare_static_weights_for_kernel` at line 946) use `"use_shuffled_weight"` as the key. `FP4Moe.prepare_static_weights_for_kernel` happens to ignore the `weight_processing` parameter entirely, so this doesn't cause a runtime failure today, but it would silently break if FP4 ever starts using that dict.

Proposed fix:

```diff
-{"shuffle": True, "layout": WeightLayout.MajorK},
+{"use_shuffled_weight": True, "layout": WeightLayout.MajorK},
```
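One way to keep such key typos from being silently ignored is a small validation helper that rejects unknown keys. This is a hypothetical guard, not code from the test file, and the string `"MajorK"` below merely stands in for the project's `WeightLayout.MajorK` enum:

```python
# Hypothetical guard: reject unknown keys in a weight-processing config dict
# so a typo like "shuffle" fails fast instead of being silently dropped.
ALLOWED_KEYS = {"use_shuffled_weight", "layout"}

def validate_weight_processing(cfg: dict) -> dict:
    unknown = set(cfg) - ALLOWED_KEYS
    if unknown:
        raise KeyError(f"unknown weight_processing keys: {sorted(unknown)}")
    return cfg

validate_weight_processing({"use_shuffled_weight": True, "layout": "MajorK"})  # ok
try:
    validate_weight_processing({"shuffle": True, "layout": "MajorK"})
except KeyError as e:
    print(e)  # flags the misspelled "shuffle" key
```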
Additional actionable comment in `tests/moe/test_trtllm_gen_fused_moe.py`:

**3172-3174**: Inconsistent types passed for `routing_method_type` and `activation_type`.

The target function signature expects ints (`routing_method_type: int = 0`, `activation_type: int = ActivationType.Swiglu.value`), but some call sites supply enum objects (e.g., `self.config["activation_type"]`) while others use `.value`. Update all callers to pass the enum's integer value consistently, e.g. change `self.config["activation_type"]` to `self.config["activation_type"].value` (and likewise for `routing_method_type`), so every call matches the signature.
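The type mismatch flagged above is easy to reproduce with a plain `enum.Enum`: an enum member is not an `int`, so passing it where the signature expects an int differs from passing `.value`. A minimal sketch (the `ActivationType` and `run_kernel` below are stand-ins, not the project's actual definitions):

```python
from enum import Enum

class ActivationType(Enum):  # stand-in for the project's enum
    Swiglu = 0
    Relu = 1

def run_kernel(activation_type: int = ActivationType.Swiglu.value) -> int:
    # Mirrors the reviewed signature: a plain int is expected.
    assert isinstance(activation_type, int), type(activation_type)
    return activation_type

print(run_kernel(ActivationType.Relu.value))  # 1: consistent int call site
try:
    run_kernel(ActivationType.Relu)  # enum object, not an int
except AssertionError:
    print("enum member is not an int; pass .value")
```

Alternatively, deriving the enum from `enum.IntEnum` would make members pass `isinstance(x, int)` checks directly, though consistently using `.value` keeps call sites explicit.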
aleozlx
left a comment
looks good to me
bot comments are reasonable to address, pls take a look
🧹 Nitpick comments (2)
tests/moe/test_trtllm_gen_fused_moe.py (2)
**571-572**: Consider using `.get()` for backward-compatible kwargs access.

`kwargs["gemm1_bias"]` / `kwargs["gemm2_bias"]` will raise `KeyError` if any future caller of `call_moe` omits these. Using `.get("gemm1_bias", None)` is consistent with how `enable_autotune` is already handled in this same method.

♻️ Proposed fix:

```diff
-gemm1_bias = kwargs["gemm1_bias"]
-gemm2_bias = kwargs["gemm2_bias"]
+gemm1_bias = kwargs.get("gemm1_bias", None)
+gemm2_bias = kwargs.get("gemm2_bias", None)
```
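The behavioral difference is just that of `dict.__getitem__` vs `dict.get`. A minimal sketch (the function below is an illustrative stand-in for the reviewed `call_moe` method, not its real body):

```python
def call_moe_sketch(**kwargs):
    # .get() tolerates callers that omit the optional bias kwargs,
    # where kwargs["gemm1_bias"] would raise KeyError for such callers.
    gemm1_bias = kwargs.get("gemm1_bias", None)
    gemm2_bias = kwargs.get("gemm2_bias", None)
    return gemm1_bias, gemm2_bias

print(call_moe_sketch())                  # (None, None): omitted kwargs are fine
print(call_moe_sketch(gemm1_bias=[1.0]))  # ([1.0], None)
```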
**2186-2187**: Bias propagation added to non-FP4 reference paths without corresponding production support.

`run_moe_reference_dsfp8`, `run_moe_reference_bf16`, `run_moe_reference_per_tensor_scale_fp8`, and `run_moe_reference_mxint4` now forward `gemm1_bias` / `gemm2_bias` into `run_moe_dequant`, but their production counterparts (`trtllm_fp8_block_scale_moe`, `trtllm_bf16_moe`, etc.) do not accept or apply biases. Any future test that passes non-`None` biases with these quant modes will silently mismatch between reference and production outputs. Consider adding an assertion in those reference functions that biases are `None` if production doesn't support them, e.g.:

```python
assert args.gemm1_bias is None and args.gemm2_bias is None, \
    "GEMM bias not supported for FP8/BF16/MxInt4 production kernels"
```
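A self-contained sketch of that fail-fast guard pattern, with `MoeArgs` and the reference function as hypothetical stand-ins for the test file's actual args object and `run_moe_reference_bf16`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeArgs:  # stand-in for the test file's args object
    gemm1_bias: Optional[list] = None
    gemm2_bias: Optional[list] = None

def run_moe_reference_bf16_sketch(args: MoeArgs) -> str:
    # Fail fast: the production kernel path has no bias support, so refusing
    # biased args here prevents a silent reference-vs-kernel mismatch later.
    assert args.gemm1_bias is None and args.gemm2_bias is None, \
        "GEMM bias not supported for FP8/BF16/MxInt4 production kernels"
    return "ok"  # the real function would run the dequant reference here

print(run_moe_reference_bf16_sketch(MoeArgs()))  # ok
try:
    run_moe_reference_bf16_sketch(MoeArgs(gemm1_bias=[0.1]))
except AssertionError as e:
    print(e)  # fails fast instead of silently diverging
```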
@aleozlx Just refactored the test and now it is directly calling
/bot run |
[FAILED] Pipeline #44471261: 14/20 passed |
📌 Description
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Tests
Refactor