
chore/feat: A2A + MoE benchmark; add routed counterpart for trtllm_gen_fp8_fused_moe #2379

Merged
bkryu merged 5 commits into flashinfer-ai:main from rosenrodt:chore-combined-a2a-moe-bench on Jan 29, 2026

Conversation

@rosenrodt
Contributor

@rosenrodt rosenrodt commented Jan 20, 2026

📌 Description

  • Added a --real_math flag to the moe_comm.py bench script to run A2A + MoE end to end: A2A dispatch -> trtllm MoE (nvfp4 or fp8_block_scale) -> A2A combine
  • Added support for trtllm_gen_fp8_routed_fused_moe, specifically for the A2A + MoE benchmark

Example:

$ mpirun -np 2 python benchmarks/flashinfer_benchmark.py \
  --routine moe_a2a_dispatch_combine \
  --num_tokens 1024 \
  --hidden_size 7168 \
  --intermediate_size 2048 \
  --num_experts 256 \
  --top_k 8 --quant_dtype nvfp4 \
  --real_math \
  --per_phase_timing 

[INFO] args = Namespace(routine='moe_a2a_dispatch_combine', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=1, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=False, repro_command='', num_tokens=1024, hidden_size=7168, num_experts=256, top_k=8, input_dtype='bfloat16', quant_dtype='nvfp4', real_math=True, intermediate_size=2048, max_num_tokens=1024, validate=False, nvtx=False, per_phase_timing=True, scale_dtype=torch.float8_e4m3fn)
[INFO] args = Namespace(routine='moe_a2a_dispatch_combine', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=False, allow_output_mismatch=False, random_seed=42, verbose=1, output_path=None, num_iters=30, dry_run_iters=5, case_tag=None, generate_repro_command=False, repro_command='', num_tokens=1024, hidden_size=7168, num_experts=256, top_k=8, input_dtype='bfloat16', quant_dtype='nvfp4', real_math=True, intermediate_size=2048, max_num_tokens=1024, validate=False, nvtx=False, per_phase_timing=True, scale_dtype=torch.float8_e4m3fn)
[INFO] Running test_moe_a2a_dispatch_combine
[INFO] ep_size=2, rank=0
[INFO] Inter-rank traffic: dispatch=7.899 MiB, combine=27.863 MiB
[PERF] a2a_dispatch   :: median time 1.757 ms; std 1.778 ms; achieved tflops nan TFLOPs/sec; achieved tb_per_sec 0.005 TB/sec
[PERF] moe_kernel     :: median time 0.734 ms; std 1.262 ms; achieved tflops 1966.389 TFLOPs/sec; achieved tb_per_sec 8.693 TB/sec
[PERF] a2a_combine    :: median time 3.440 ms; std 1.705 ms; achieved tflops nan TFLOPs/sec; achieved tb_per_sec 0.008 TB/sec
[PERF] a2a_total      :: median time 5.182 ms; std 2.106 ms; achieved tflops nan TFLOPs/sec; achieved tb_per_sec 0.007 TB/sec
[INFO] The reported achieved tflops/tb_per_sec is the aggregate FLOPS/bandwidth of all participating ranks based on timing results of rank 0. Could observe rank-to-rank variations.
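
As a sanity check on the numbers above, the reported MoE kernel TFLOPs line is consistent with a standard gated-MoE FLOP count. This is a rough sketch, not the benchmark's actual calculate_moe_tflops helper:

```python
def moe_tflops(num_tokens, hidden_size, intermediate_size, top_k, time_ms, ep_size=1):
    """Rough gated-MoE TFLOPs/sec estimate (illustrative, not the benchmark's helper).

    gemm1 maps hidden -> 2*intermediate (gate + up projection) and gemm2 maps
    intermediate -> hidden; each MAC counts as 2 FLOPs, applied to top_k experts
    per token. The reported figure aggregates over all participating ranks.
    """
    flops_per_token_per_expert = (
        2 * hidden_size * (2 * intermediate_size)  # gemm1
        + 2 * intermediate_size * hidden_size      # gemm2
    )
    total_flops = num_tokens * top_k * flops_per_token_per_expert
    return ep_size * total_flops / (time_ms * 1e-3) / 1e12
```

With the run above (1024 tokens, hidden 7168, intermediate 2048, top_k 8, 0.734 ms median, 2 ranks) this gives roughly 1966 TFLOPs/sec, in line with the logged value.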

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Selected unit tests (see below) passed locally after the changes to the CUDA/C++ source files:

python -m pytest -v tests/moe/test_trtllm_gen_routed_fused_moe.py
python -m pytest -v tests/moe/test_trtllm_gen_fused_moe.py

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Pre-packed routing support for FP8 block-scale Mixture-of-Experts via a new public API.
    • Real MoE kernel execution path with FP8/NVFP4 and block-scale quantization for benchmarking.
    • Comprehensive MoE utilities for weight generation, FP4/FP8 quantization, routing, and performance metrics.
  • Documentation

    • Expanded MoE benchmarking docs with A2A dispatch/combine scenarios and examples.
  • Tests

    • Added FP8 routed MoE test coverage.


@coderabbitai
Contributor

coderabbitai bot commented Jan 20, 2026

📝 Walkthrough


This PR introduces comprehensive MoE A2A (All-to-All) benchmarking infrastructure with FP4/FP8 quantization support. It adds a new utilities module (moe_utils.py) providing quantization, routing, and metric calculation functions, refactors existing benchmarks to use these utilities, extends the MoE A2A communication benchmark with real math support and pre-packed routing, and updates kernel launchers and Python APIs to enable pre-packed routing paths.

Changes

Cohort / File(s) Summary
Documentation
benchmarks/README.md, benchmarks/samples/sample_testlist.txt
Added MoE A2A dispatch/combine benchmark documentation with flags, launch examples, and test configurations for various quantization modes (none, FP8, NVFP4, FP8 block-scale).
MoE Utilities (New)
benchmarks/routines/moe_utils.py
New 766-line module providing: FP4/FP8 quantization/dequantization, weight generation, routing computation, top-k packing (Triton), MoE performance metrics, layout processing, and CLI argument utilities.
Benchmark Refactoring
benchmarks/routines/moe.py
Migrated from direct imports of WeightLayout and convert_to_block_layout to using moe_utils equivalents; replaced legacy helpers (e.g., quant_fp4_simple -> quantize_fp4), standardized weight/quantization flows, and integrated performance metric calculations.
MoE A2A Benchmarking
benchmarks/routines/moe_comm.py
Added real MoE kernel dispatch with _init_moe_weights(), extended _create_moe_inputs() for block-scale quantization, added --real_math, --intermediate_size, and --quant_dtype arguments, integrated MoE kernel timing and bandwidth metrics, and extended validation for FP8 block-scale paths.
Kernel Launcher Updates
csrc/trtllm_fused_moe_kernel_launcher.cu
Extended Fp8BlockScaleLauncher to accept optional pre-computed expert_indices and expert_weights, added routing fallback logic, updated entry point signatures (trtllm_fp8_block_scale_moe, trtllm_fp4_block_scale_moe) to propagate pre-packed routing data.
Python API Additions
flashinfer/fused_moe/__init__.py, flashinfer/fused_moe/core.py
Added new public function trtllm_fp8_block_scale_routed_moe() for pre-packed routing; updated trtllm_fp8_block_scale_moe_op() to support optional routing_logits, topk_ids, and expert_weights with dual-mode routing logic.
Testing
tests/moe/test_trtllm_gen_routed_fused_moe.py
Added test_trtllm_gen_fp8_routed_fused_moe() to test FP8 routed MoE paths with full input generation, routing, and reference validation.

Sequence Diagram(s)

sequenceDiagram
    participant Benchmark as MoE A2A<br/>Benchmark
    participant Utils as moe_utils<br/>Utilities
    participant Pack as Triton Packing<br/>(pack_topk_ids)
    participant Kernel as Kernel Launcher<br/>(FP8BlockScale)
    participant Backend as TRT-LLM<br/>Backend

    Benchmark->>Utils: generate_moe_weights()
    Utils-->>Benchmark: gemm1/gemm2 weights

    Benchmark->>Utils: quantize_fp4/quantize_fp8()
    Utils-->>Benchmark: quantized data + scales

    Benchmark->>Utils: compute_routing()
    Utils-->>Benchmark: routing_weights, selected_experts

    Benchmark->>Pack: pack_topk_ids_triton()
    Pack-->>Benchmark: packed expert_indices, expert_weights

    Benchmark->>Kernel: dispatch with packed routing<br/>(expert_indices, expert_weights)
    Kernel->>Kernel: check_routing()<br/>(use precomputed data)

    Kernel->>Backend: run FP8 block-scale MoE<br/>(hidden_states, scales, weights)
    Backend-->>Kernel: output tensors

    Kernel-->>Benchmark: results

    Benchmark->>Utils: calculate_moe_tflops()<br/>calculate_moe_kernel_bandwidth()
    Utils-->>Benchmark: performance metrics
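
The compute_routing step in the diagram can be approximated by a plain softmax top-k. This is a hedged sketch; the real moe_utils.compute_routing supports multiple routing methods and may weight or normalize the selected experts differently:

```python
import math

def compute_routing(logits, top_k):
    """Illustrative softmax top-k routing for one token.

    Returns the indices of the top_k experts and their softmax weights,
    renormalized over the selected set (one common MoE convention).
    """
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]   # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]
    norm = sum(probs[i] for i in idx)
    return idx, [probs[i] / norm for i in idx]
```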

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • Anerudhan
  • djmmoss
  • yzh119
  • cyx-6
  • bkryu
  • kahyunnam

Poem

🐰 Hops through quantized pathways bright,
FP4 and FP8 in flight,
Pre-packed routing speeds the way,
A2A benchmarks dance and sway,
MoE experts now can measure fast! 🚀

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage (Warning): Docstring coverage is 68.42%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check (Passed): The title clearly describes the main changes: adding A2A + MoE benchmarking capability and a routed MoE function variant.
  • Description check (Passed): The description provides context about the --real_math flag, a usage example with an mpirun command, example output, and notes on testing completion.




@gemini-code-assist
Contributor

Summary of Changes

Hello @rosenrodt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the benchmarking infrastructure for Mixture-of-Experts (MoE) models by integrating All-to-All communication with actual MoE kernel execution. It expands the supported quantization formats to include FP8 block-scale and introduces a new API for routed MoE, which streamlines the process of evaluating MoE layers when routing decisions are pre-determined. These changes provide more accurate and flexible tools for performance analysis and optimization of MoE architectures.

Highlights

  • New A2A + MoE Benchmark: Introduced a new benchmark routine, moe_a2a_dispatch_combine, to evaluate the performance of All-to-All communication combined with actual Mixture-of-Experts (MoE) kernel computation. This benchmark supports various quantization types, including FP8 and NVFP4, and can now execute real MoE math.
  • FP8 Block-Scale Quantization Support: Added comprehensive support for FP8 block-scale quantization within the new A2A + MoE benchmark, allowing for more realistic performance analysis of MoE layers using this quantization scheme.
  • Routed Fused MoE API for FP8: Implemented trtllm_fp8_block_scale_routed_moe, a new API that enables MoE operations with pre-computed routing decisions (packed expert IDs and weights). This is particularly beneficial for scenarios like CUDA Graph capture and distributed MoE, where routing is handled externally.
  • MoE Utilities Refactoring: Refactored and consolidated common MoE-related utility functions, such as quantization/dequantization, weight generation, and performance metric calculations, into a new dedicated module benchmarks/routines/moe_utils.py for improved code organization and reusability.
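
The packed routing layout that the routed API consumes, (expert_id << 16) | bf16(weight) bits per the launcher comments in this thread, can be sketched in pure Python. These helper names are illustrative, not part of the FlashInfer API:

```python
import struct

def bf16_bits(x: float) -> int:
    # bfloat16 is the upper 16 bits of the IEEE-754 float32 encoding.
    # Truncation is used here for simplicity; real kernels may round-to-nearest-even.
    (u32,) = struct.unpack("<I", struct.pack("<f", x))
    return u32 >> 16

def pack_slot(expert_id: int, weight: float) -> int:
    # One int32 per (token, top_k) slot: (expert_id << 16) | weight_bf16.view(int16).
    return (expert_id << 16) | bf16_bits(weight)

def unpack_slot(packed: int):
    # Recover the expert id and an approximate (bf16-precision) weight.
    expert_id = packed >> 16
    (w,) = struct.unpack("<f", struct.pack("<I", (packed & 0xFFFF) << 16))
    return expert_id, w
```

In the actual benchmark this packing is done on-GPU by a Triton kernel (pack_topk_ids) over [num_tokens, top_k] tensors; the scalar version above just shows the bit layout.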


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for A2A (All-to-All) communication and MoE (Mixture of Experts) benchmarking, specifically adding a routed counterpart for trtllm_gen_fp8_fused_moe. The changes involve refactoring common MoE utilities into a new moe_utils.py file, updating benchmark scripts to use these utilities, and extending the C++ kernel launcher to support pre-computed routing for FP8 block-scale MoE. The documentation in benchmarks/README.md has been updated to reflect the new MoE communication flags and examples. A new test case for trtllm_gen_fp8_routed_fused_moe has also been added to ensure correctness.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
flashinfer/fused_moe/core.py (2)

1572-1635: Autotuner crashes when routing_logits is None.
MoERunner.get_valid_tactics reads routing_logits.shape[0]; routed calls pass None, so this will throw before any kernel launch. Use a meta placeholder for tuning inputs (as in the FP4 path).

🛠️ Proposed fix
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-            hidden_states_scale,
-        ]
+        routing_logits_for_tuning = (
+            routing_logits
+            if routing_logits is not None
+            else torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+        )
+        inputs = [
+            output,
+            routing_logits_for_tuning,
+            topk_ids,
+            expert_weights,
+            hidden_states,
+            hidden_states_scale,
+        ]

1689-1714: Silence Ruff ARG001 in the fake op signature.
Unused parameters can be prefixed with _ to keep lint clean.

♻️ Proposed cleanup
-def _fake_trtllm_fp8_block_scale_moe(
-    routing_logits: Optional[torch.Tensor],
-    topk_ids: Optional[torch.Tensor],
-    expert_weights: Optional[torch.Tensor],
+def _fake_trtllm_fp8_block_scale_moe(
+    _routing_logits: Optional[torch.Tensor],
+    _topk_ids: Optional[torch.Tensor],
+    _expert_weights: Optional[torch.Tensor],
@@
-    tune_max_num_tokens: int = 8192,
+    _tune_max_num_tokens: int = 8192,
 ):
benchmarks/routines/moe_comm.py (1)

593-636: Exact traffic calc misses fp8_block_scale.
_calculate_exact_comm_traffic falls through to the unquantized path, so dispatch/combine bandwidth is overstated for fp8 block‑scale payloads. Add a branch to include 1‑byte activations plus float32 block scales (per 128 elements).

🛠️ Proposed fix
-    quant_dtype: None, "fp8", or "nvfp4"
+    quant_dtype: None, "fp8", "nvfp4", or "fp8_block_scale"
@@
     if quant_dtype == "nvfp4":
         # NVFP4: 0.5 bytes per element + block scales
         hidden_bytes_per_token = hidden_size // 2
         scale_bytes_per_token = (hidden_size // 16) * 1  # float8_e4m3fn
     elif quant_dtype == "fp8":
         # FP8: 1 byte per element
         hidden_bytes_per_token = hidden_size * 1
         scale_bytes_per_token = 0
+    elif quant_dtype == "fp8_block_scale":
+        # FP8 block scale: 1 byte per element + float32 scales per 128 elements
+        hidden_bytes_per_token = hidden_size * 1
+        scale_bytes_per_token = (hidden_size // 128) * 4  # float32
     else:
         # No quantization
         element_size = torch.tensor([], dtype=input_dtype).element_size()
         hidden_bytes_per_token = hidden_size * element_size
         scale_bytes_per_token = 0
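
Under the proposed branch, the per-token payload for each quant_dtype works out as follows (an illustrative pure-Python version of the per-dtype arithmetic; the helper name is hypothetical):

```python
def per_token_payload_bytes(hidden_size, quant_dtype=None, input_elem_size=2):
    """Bytes moved per token on the wire: activation data plus any block scales."""
    if quant_dtype == "nvfp4":
        # 4-bit data + one e4m3 scale byte per 16 elements
        return hidden_size // 2 + hidden_size // 16
    if quant_dtype == "fp8":
        # 1 byte per element, no scales
        return hidden_size
    if quant_dtype == "fp8_block_scale":
        # 1 byte per element + one float32 scale per 128 elements
        return hidden_size + (hidden_size // 128) * 4
    # unquantized (e.g., bfloat16 with input_elem_size == 2)
    return hidden_size * input_elem_size
```

For hidden_size=7168 this gives 4032 bytes (nvfp4), 7168 (fp8), 7392 (fp8_block_scale), and 14336 (bf16), which shows why the missing fp8_block_scale branch overstates bandwidth only slightly for large hidden sizes but is still worth fixing.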
flashinfer/fused_moe/__init__.py (1)

17-58: Add trtllm_fp8_block_scale_routed_moe to the top-level flashinfer/__init__.py exports.

The function is already exported from flashinfer.fused_moe but missing from the main API. Since trtllm_fp4_block_scale_routed_moe is already exported at the top level, trtllm_fp8_block_scale_routed_moe should be added to maintain API consistency and ensure users can access it via flashinfer.trtllm_fp8_block_scale_routed_moe.

csrc/trtllm_fused_moe_kernel_launcher.cu (2)

758-843: Validate precomputed expert_weights before wiring into the routing workspace.
When expert_weights is provided, missing shape/dtype checks can lead to OOB writes or incorrect routing weights. Add the same shape/dtype validation used for internally allocated buffers (expect [num_tokens, top_k] and bfloat16).

🔧 Suggested guard additions
   void check_routing() const override {
     // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
     if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
       // Pre-computed routing: expert_indices is a packed tensor
       // Format: (expert_id << 16) | (weight_bf16.view(int16))
       TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
           << "expert_indices and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
           << "expert_indices dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
     }
+
+    if (expert_weights.ndim() == 2 && expert_weights.size(0) > 0) {
+      TVM_FFI_ICHECK_EQ(expert_weights.size(0), hidden_states.size(0))
+          << "expert_weights and hidden_states must have same number of tokens.";
+      TVM_FFI_ICHECK_EQ(expert_weights.size(1), args->top_k)
+          << "expert_weights dim1 must match top_k.";
+      TVM_FFI_ICHECK_EQ(expert_weights.dtype(), dl_bfloat16)
+          << "expert_weights must be bfloat16.";
+    }

     FusedMoeLauncher::check_routing_common();

1741-1856: Add shape/dtype validation for FP4 expert_indices/expert_weights buffers.
These buffers are now caller-supplied and passed into the routing kernel. Without validation, a malformed tensor can cause OOB writes in the routing phase.

🔧 Suggested guard additions
   if (routing_logits.has_value()) {
     TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
                    routing_logits.value().dtype() == dl_bfloat16)
         << "routing_logits must be float or bfloat16.";
     TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
     TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
         << "routing_logits has incorrect shape.";
     if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
       TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
           << "routing_logits must be float.";
     }
   }
+
+  TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
+  TVM_FFI_ICHECK_EQ(expert_indices.size(0), num_tokens)
+      << "expert_indices dim0 must match num_tokens.";
+  TVM_FFI_ICHECK_EQ(expert_indices.size(1), top_k)
+      << "expert_indices dim1 must match top_k.";
+  TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32)
+      << "expert_indices must be int32.";
+
+  TVM_FFI_ICHECK_EQ(expert_weights.ndim(), 2) << "expert_weights must be 2D.";
+  TVM_FFI_ICHECK_EQ(expert_weights.size(0), num_tokens)
+      << "expert_weights dim0 must match num_tokens.";
+  TVM_FFI_ICHECK_EQ(expert_weights.size(1), top_k)
+      << "expert_weights dim1 must match top_k.";
+  TVM_FFI_ICHECK_EQ(expert_weights.dtype(), dl_bfloat16)
+      << "expert_weights must be bfloat16.";
🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 254-259: The comment and code disagree: instead of asserting
intermediate_size must be provided when args.real_math is true, implement the
documented default by detecting if args.intermediate_size is None and setting it
to args.hidden_size * 4 before the real_math/quant checks; then remove the
assert that forces a provided value (or adjust it to allow the default). Refer
to args.intermediate_size, args.real_math, and args.hidden_size in the
moe_comm.py block to locate and update the logic accordingly.

In `@benchmarks/routines/moe_utils.py`:
- Around line 453-491: The function calculate_moe_tflops currently accepts
num_experts but never uses it; either remove num_experts from the function
signature or mark it explicitly unused (e.g., rename to _num_experts or add a
comment/type ignore) to satisfy linters. Update all call sites if you remove the
parameter; if you choose to mark it unused, change the parameter name in
calculate_moe_tflops to _num_experts and add a short comment like "# unused -
kept for API compatibility" so the symbol is clearly identified and lint
warnings are suppressed.
- Around line 717-770: The function quantize_and_pack_nvfp4 silently replaces
block_scales_reshaped with all-ones when the reshape size mismatches, masking
bugs; change this to surface the error instead: validate the shape after calling
quantize_fp4 (check block_scales.numel() or block_scales_reshaped.numel() vs
expected_scale_elems) and either raise a descriptive ValueError (including
tensor.shape, expected_scale_elems, and block_scales.numel()) or at minimum log
a clear warning via the existing logging mechanism before failing, rather than
silently returning ones; keep the fallback only for a deliberate debug mode if
needed.
- Around line 248-343: The quantize/dequantize pair disagree on block_scales
layout: quantize_fp8_block_scale currently transposes block_scales to
[num_blocks, num_tokens] (block_scales_transposed) but
dequantize_fp8_block_scale expects [num_tokens, num_blocks]; fix by returning
block_scales in [num_tokens, num_blocks] from quantize_fp8_block_scale (i.e.,
remove the transpose and return block_scales.contiguous() instead of
block_scales_transposed) so the shapes match the docstring and
dequantize_fp8_block_scale; keep variable names block_scales and update any
callers if they relied on the transposed layout.
- Around line 655-689: process_fp8_weight_layout currently returns early when
use_shuffled_weight is False, skipping BlockMajorK conversion; change the logic
so you always obtain a uint8 view first (e.g., uint8_tensor =
tensor.view(torch.uint8)), only call shuffle_matrix_a when use_shuffled_weight
is True, then if weight_layout == WeightLayout.BlockMajorK call
convert_to_block_layout(uint8_tensor, block_k) (keep epilogue_tile_m for shuffle
and block_k=128 as before), and finally return the
result.view(torch.float8_e4m3fn); this keeps shuffle_matrix_a and
convert_to_block_layout usage (and their parameters) intact while ensuring
BlockMajorK conversion runs even without shuffling.

In `@benchmarks/routines/moe.py`:
- Around line 504-509: The tuple unpacking of quantize_fp4_batched assigns
gemm1_scales_global and gemm2_scales_global but those globals are unused and
trigger Ruff RUF059; change the unpacking to either drop or prefix these
variables with an underscore (e.g., _gemm1_scales_global, _gemm2_scales_global)
in the calls to quantize_fp4_batched so the linter no longer flags them (look
for the lines using quantize_fp4_batched that assign gemm1_weights_fp4_bytes,
gemm1_scales_fp4_bytes, gemm1_scales_global and the corresponding gemm2_*
assignment).

In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1655-1662: The guard for precomputed routing is too lax (uses
expert_indices.data_ptr() != nullptr) and can allow empty tensors that later
crash in prepare_routing(); change the predicate so use_precomputed_routing
mirrors the launcher's check by requiring expert_indices to be 2-D and non-empty
— e.g. set use_precomputed_routing = (expert_indices.ndimension() == 2 &&
expert_indices.size(0) > 0) — and keep the TVM_FFI_ICHECK(...) that requires
either routing_logits or this stricter precomputed-routing condition.
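
The block-scale layout discussed above (one scale per 128-element block, scales stored [num_tokens, num_blocks]) can be sketched in pure Python for a single token. This omits the actual float8_e4m3fn cast and the helper names are hypothetical:

```python
def quantize_row_block_scale(row, block=128, fp8_max=448.0):
    """Per-block scale quantization for one token's activations.

    scale = amax(block) / fp8_max, data stored as value / scale. Stacking the
    returned scales row by row yields the [num_tokens, num_blocks] layout.
    """
    qdata, scales = [], []
    for i in range(0, len(row), block):
        chunk = row[i:i + block]
        s = max(abs(v) for v in chunk) / fp8_max
        s = s if s > 0 else 1.0  # avoid div-by-zero for all-zero blocks
        scales.append(s)
        qdata.extend(v / s for v in chunk)
    return qdata, scales

def dequantize_row_block_scale(qdata, scales, block=128):
    # Inverse transform, reading scales in the same [num_blocks] order.
    return [q * scales[i // block] for i, q in enumerate(qdata)]
```

Keeping quantize and dequantize on the same [num_tokens, num_blocks] orientation, as the review comment asks, means the roundtrip works without a transpose.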
🧹 Nitpick comments (2)
benchmarks/README.md (1)

32-33: Well-documented MOE Communication section.

The new moe_a2a_dispatch_combine routine is clearly described with key requirements (mpirun, quantization options, real math support) highlighted.

📝 Optional: Address markdown linting preference

The static analysis tool flags 4-space indentation for nested list items. While the current 4-space indentation is consistent throughout the file, you could optionally adjust to 2 spaces if you want to satisfy the linter:

 - MOE Communication:
-    - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
+  - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.

Note: This would require updating all nested list items in the file for consistency.

benchmarks/routines/moe_comm.py (1)

467-468: Ruff TRY003: long message in inline ValueError.
If TRY003 is enforced, consider a # noqa: TRY003 or a small custom exception type.

🧹 Example suppression
-        raise ValueError(f"Unsupported quant_dtype for real computation: {quant_dtype}")
+        raise ValueError(
+            f"Unsupported quant_dtype for real computation: {quant_dtype}"
+        )  # noqa: TRY003

Comment on lines 504 to 509

    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
-       quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True)
+       quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
    )
    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
-       quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True)
+       quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
    )

⚠️ Potential issue | 🟡 Minor

Unused quant scale globals trigger Ruff RUF059.
Prefix with _ (or drop) to avoid lint failures.

♻️ Proposed cleanup
-    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
+    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, _gemm1_scales_global = (
         quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
     )
@@
-    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
+    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, _gemm2_scales_global = (
         quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
     )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-   gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
+   gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, _gemm1_scales_global = (
        quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
    )
-   gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
+   gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, _gemm2_scales_global = (
        quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
    )
🧰 Tools
🪛 Ruff (0.14.13)

504-504: Unpacked variable gemm1_scales_global is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


507-507: Unpacked variable gemm2_scales_global is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


@yzh119 yzh119 force-pushed the chore-combined-a2a-moe-bench branch from 491d262 to d15d55f on January 20, 2026 11:02
…py benchmark

fp8_block_scale; real moe math wip

add fp8 block scale routed moe; should work

resolve oom

wip: add fused topk id/weight padding

remove extra d2d copy

extract utility funcs

add moe tflops/bw calculation

add guard

add example testlist

tidy

wip: minor refactor

extract common utils

rename

tidy; all should pass
@rosenrodt rosenrodt force-pushed the chore-combined-a2a-moe-bench branch from 86adf34 to 5d66242 on January 21, 2026 07:37

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)

1299-1313: Use fp8_block_scale for bandwidth accounting in block‑scale mode.

Line 1308–1309 use "fp8" formats, which omits block‑scale overhead. Use "fp8_block_scale" to reflect scales.

🛠️ Suggested fix
-        input_format="fp8",
-        weight_format="fp8",
+        input_format="fp8_block_scale",
+        weight_format="fp8_block_scale",
flashinfer/fused_moe/core.py (2)

1547-1635: Avoid passing None into AutoTuner inputs when routing_logits is absent.

When routing_logits is None, inputs currently includes a None entry, which can break AutoTuner shape handling. Provide a meta tensor placeholder (as done in FP4) to keep tuner inputs tensor‑only.

🛠️ Suggested fix
-        inputs = [
+        routing_logits_for_tuning = (
+            routing_logits
+            if routing_logits is not None
+            else torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+        )
+        inputs = [
             output,
-            routing_logits,
+            routing_logits_for_tuning,
             topk_ids,
             expert_weights,
             hidden_states,
             hidden_states_scale,
         ]

1690-1714: Silence Ruff ARG001 for unused fake‑op parameters.

The fake op signature now includes unused args; Ruff will flag these. Prefix unused parameters with _ or add # noqa: ARG001 on the definition.

🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 32-33: The nested list under the "MOE Communication:" bullet has
incorrect indentation; adjust the nested bullet for `moe_a2a_dispatch_combine`
to use a 2-space indent (align the nested dash two spaces under the parent list
marker) so the nested item conforms to MD007; ensure the code span and
description remain on the same line after the nested dash.

In `@benchmarks/routines/moe.py`:
- Around line 389-393: The code calls compute_routing(routing_logits.float(),
top_k) and ignores routing_bias; update the non‑DeepSeek path to add
routing_bias into routing_logits before selection so selected_experts and
bandwidth metrics reflect the bias. Specifically, where compute_routing is used
(function/variable names: compute_routing, routing_logits, routing_bias,
selected_experts, top_k), add routing_bias (with proper dtype/device
broadcasting and a None check) to routing_logits (e.g., routing_logits +
routing_bias) and pass the biased logits to compute_routing, then return
selected_experts.

In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 279-282: The tests generate routing_logits as torch.bfloat16 which
is incompatible with the FP8 block-scale routing kernel; change the tensor
creation for routing_logits to use torch.float32 so the FP8 block-scale routing
receives float32 inputs. Locate the variable routing_logits in the test (the
torch.rand(...) call that currently casts to torch.bfloat16) and remove or
replace the .to(torch.bfloat16) with .to(torch.float32) (or simply create with
dtype=torch.float32) so downstream code like the FP8 block-scale routing path
receives float32 logits.
♻️ Duplicate comments (2)
benchmarks/routines/moe.py (1)

504-509: Prefix unused FP4 global scales to avoid Ruff RUF059.

The globals are unpacked but unused; Ruff will flag this.

♻️ Suggested cleanup
-    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
+    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, _gemm1_scales_global = (
         quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
     )
@@
-    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
+    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, _gemm2_scales_global = (
         quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
     )
benchmarks/routines/moe_comm.py (1)

253-257: Comment/behavior mismatch for intermediate_size.

The comment at line 254 says "Default intermediate_size to hidden_size * 4 if not specified" (based on the past review), but the code asserts that intermediate_size must be provided when real_math=True. Either implement the default or update the comment to reflect the actual behavior.

🧹 Nitpick comments (1)
benchmarks/routines/moe_comm.py (1)

1317-1318: Potentially redundant FP8 conversion.

When quant_dtype == "fp8_block_scale", the hidden states are already quantized to float8_e4m3fn in _create_moe_inputs (line 539). The conversion at line 1318 appears redundant. While it's a no-op for already-FP8 data, consider removing it or adding a comment clarifying intent.

Suggested change
-                    # Convert hidden states to FP8
-                    hidden_fp8 = hidden_flat.to(torch.float8_e4m3fn)
+                    # hidden_flat is already FP8 from quantize_fp8_block_scale
+                    hidden_fp8 = hidden_flat

Comment on lines +32 to +33
- MOE Communication:
- `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.

⚠️ Potential issue | 🟡 Minor

Fix nested list indentation to satisfy MD007.

Line 33 is flagged by markdownlint for list indent. Align the nested bullet with the configured 2‑space indent.

🛠️ Suggested fix
-    - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
+  - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)

33-33: Unordered list indentation
Expected: 2; Actual: 4

(MD007, ul-indent)


Comment on lines 389 to 393
# For other routing methods, use simple top-k as approximation
# This is accurate for Default, Renormalize, RenormalizeNaive, TopK
# and approximate for Llama4
-_, selected_experts = _compute_routing(routing_logits.float(), top_k)
+_, selected_experts = compute_routing(routing_logits.float(), top_k)
return selected_experts

⚠️ Potential issue | 🟡 Minor

Apply routing_bias for accurate expert selection in non‑DeepSeek paths.

Line 392 uses compute_routing on raw logits, which ignores routing_bias when it’s provided. That skews selected_experts and bandwidth metrics. Consider adding the bias before routing.

🛠️ Suggested fix
-        _, selected_experts = compute_routing(routing_logits.float(), top_k)
+        logits = routing_logits.float()
+        if routing_bias is not None:
+            logits = logits + routing_bias.float()
+        _, selected_experts = compute_routing(logits, top_k)
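The biased top-k selection the suggestion describes can be sketched standalone for a single token (pure Python, illustrative only — not the benchmark's actual compute_routing implementation):

```python
def topk_with_bias(logits, bias, k):
    # Add routing_bias to the logits before top-k expert selection,
    # mirroring the suggested fix; bias may be None.
    scores = [float(x) for x in logits]
    if bias is not None:
        scores = [s + float(b) for s, b in zip(scores, bias)]
    # indices of the k largest biased scores, best first
    return sorted(range(len(scores)), key=lambda i: -scores[i])[:k]

# A negative bias on expert 1 demotes it out of the top-2.
print(topk_with_bias([0.1, 0.9, 0.5, 0.2], [0.0, -1.0, 0.0, 0.0], 2))  # [2, 3]
print(topk_with_bias([0.1, 0.9, 0.5, 0.2], None, 2))                   # [1, 2]
```

Ignoring the bias would select experts [1, 2] in both cases, which is exactly the skew in selected_experts (and thus bandwidth metrics) the comment points out.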

Comment on lines +279 to +282
# Generate random routing logits for reference
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
torch.bfloat16
)

⚠️ Potential issue | 🟠 Major

Use float32 routing_logits for FP8 block‑scale routing.

Line 279–282 uses bfloat16 logits, but FP8 block‑scale routing expects float32 (per kernel constraints noted elsewhere in this PR). Using bfloat16 risks incorrect routing or runtime errors.

🛠️ Suggested fix
-    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
-        torch.bfloat16
-    )
+    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
+        torch.float32
+    )

@rosenrodt
Contributor Author

Selected unit tests (see below) passed locally after the changes to the CUDA/C source files:

python -m pytest -v tests/moe/test_trtllm_gen_routed_fused_moe.py
python -m pytest -v tests/moe/test_trtllm_gen_fused_moe.py

@aleozlx
Collaborator

aleozlx commented Jan 23, 2026

/bot run

@aleozlx aleozlx left a comment

lgtm thank you!

@flashinfer-bot
Collaborator

GitLab MR !260 has been created, and the CI pipeline #42365843 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #42365843: 9/20 passed

@rosenrodt
Contributor Author

@aleozlx I see most of the CI errors are unrelated:

failed: 404 Client Error: Not Found for url: https://edge.urm.nvidia.com/artifactory/sw-kernelinferencelibrary-public-generic-unit-test/037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e/

And:

NotImplementedError: "mul_cuda" not implemented for 'Float8_e4m3fn'
[2/2] Replaying 20260123_140516_501_pid951116_mm_fp8_call0001...
Replaying mm_fp8 from /tmp/pytest-of-yowu/pytest-1/test_mm_fp8_replay0/test_dumps/20260123_140516_501_pid951116_mm_fp8_call0001
  Args: 3, Kwargs: []
Executing flashinfer.gemm.gemm_base.mm_fp8...
Dumped inputs to: /tmp/pytest-of-yowu/pytest-1/test_mm_fp8_replay0/test_dumps/20260123_140516_709_pid951116_mm_fp8_call0002 (size: 20.03 MB, total: 4/1000 dumps)
Inputs dumped to: /tmp/pytest-of-yowu/pytest-1/test_mm_fp8_replay0/test_dumps/20260123_140516_709_pid951116_mm_fp8_call0002
================================================================================
[2026-01-23 14:05:16] FlashInfer API Call: mm_fp8
================================================================================

@aleozlx
Collaborator

aleozlx commented Jan 26, 2026

agree. this is good to go
cc @yzh119

@bkryu bkryu enabled auto-merge (squash) January 27, 2026 17:28
@aleozlx aleozlx requested a review from nv-yunzheq as a code owner January 27, 2026 18:35
@aleozlx
Collaborator

aleozlx commented Jan 27, 2026

there was a pipeline blockage, i did a git merge main to restart the "checks"

@aleozlx
Collaborator

aleozlx commented Jan 27, 2026

everything passed now
asked @jimmyzho to rubber-stamp it in for the time being

@jimmyzho jimmyzho left a comment

approved on behalf of @aleozlx

@aleozlx
Collaborator

aleozlx commented Jan 27, 2026

@yzh119 there is some sort of bug on GitHub.
I have asked @bkryu and @jimmyzho to help approve it, but as soon as they approve, they drop from the required-approver list, rendering their approval ineffective.
As a result, the PR is still blocked for the same reason as two days ago.

@aleozlx
Collaborator

aleozlx commented Jan 27, 2026

could you help merge this pls
@yzh119
