
chore: Add TRTLLM MoE A2A benchmark #2354

Merged
yzh119 merged 5 commits into flashinfer-ai:main from rosenrodt:chore-moe-a2a-bench
Jan 18, 2026

Conversation

@rosenrodt
Contributor

@rosenrodt rosenrodt commented Jan 14, 2026

📌 Description

This PR adds a benchmark script for the MoE A2A kernels added in #2102.

Supports bf16, nvfp4, and per-tensor fp8.

A2A dispatch communicates in the quantized format (fp8, nvfp4); A2A combine communicates in the original activation format (bfloat16/float16).

# Example usage
mpirun -np 2 python benchmarks/flashinfer_benchmark.py --routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 -v --quant_dtype nvfp4

# Check with NVTX + nsys
nsys profile -f true -o moe_a2a --trace=cuda,nvtx --trace-fork-before-exec=true --cuda-event-trace=true --cuda-graph-trace=node mpirun -np 2 python benchmarks/flashinfer_benchmark.py --routine moe_a2a_dispatch_combine --num_tokens 1024 --hidden_size 7168 --num_experts 256 --top_k 8 -v --nvtx --quant_dtype nvfp4

# DS-R1 DEP8 example output (numbers are for demo purposes only)
$ mpirun -np 8 python benchmarks/flashinfer_benchmark.py --routine moe_a2a_dispatch_combine --num_tokens 8192 --hidden_size 7168 --num_experts 256 --top_k 8 --validate -vv --quant_dtype nvfp4
[INFO] Actual inter-rank traffic: dispatch=1170.583 MiB, combine=4149.633 MiB
[PERF] a2a_total      :: median time 1.073 ms; std 0.001 ms; achieved tflops nan TFLOPs/sec; achieved tb_per_sec 5.198 TB/sec
[INFO] The reported achieved tb_per_sec is the aggregate bandwidth of all participating ranks.
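
The reported traffic split (smaller dispatch, larger combine) follows from dispatch sending quantized payloads while combine returns full-width activations. A rough back-of-envelope can be sketched in plain Python; the helper name and the uniform-routing upper bound are illustrative assumptions, not the benchmark's exact accounting, which tracks actual per-token routing:

```python
def estimate_a2a_traffic_bytes(
    num_tokens: int,
    hidden_size: int,
    top_k: int,
    ep_size: int,
    dispatch_bytes_per_elem: float,
    combine_bytes_per_elem: float,
) -> tuple[float, float]:
    """Upper-bound per-rank traffic: each token reaches at most
    min(top_k, ep_size) distinct destination ranks (one copy per rank),
    and only the (ep_size - 1) / ep_size fraction crosses rank boundaries."""
    dest_ranks = min(top_k, ep_size)
    cross_fraction = (ep_size - 1) / ep_size
    dispatch = num_tokens * dest_ranks * hidden_size * dispatch_bytes_per_elem * cross_fraction
    combine = num_tokens * dest_ranks * hidden_size * combine_bytes_per_elem * cross_fraction
    return dispatch, combine

# nvfp4 packs two 4-bit elements per byte (~0.5 B/elem, scale factors ignored
# here); bf16 combine moves 2 B/elem, hence roughly 4x the dispatch volume.
d, c = estimate_a2a_traffic_bytes(8192, 7168, 8, 8, 0.5, 2.0)
```

With the element sizes above, combine traffic comes out at exactly four times dispatch traffic, matching the direction (though not the exact values) of the demo output.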

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added MoE all-to-all communication benchmark with FP8 and NVFP4 options, per‑phase timing, optional NVTX profiling, and bandwidth metrics.
    • Added new output schema for MoE communication and included it in reported benchmark columns; new benchmark API group added.
  • Chores

    • Integrated MoE communication routine into benchmark configuration, argument parsing, and dispatch.
  • Tests

    • Enhanced test helper with type annotations and a descriptive docstring.


@coderabbitai
Contributor

coderabbitai bot commented Jan 14, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Adds a new MoE all-to-all communication benchmark routine moe_comm with a large new benchmarking module (MPI/CUDA setup, FP8/NVFP4 quantization, dispatch/combine flows, timing, validation), integrates schema and CLI parsing, and updates a test's fake_moe signature with type annotations.

Changes

Cohort / File(s) Summary
Benchmark CLI & routing
benchmarks/flashinfer_benchmark.py
Added moe_comm to top-level --routine choices, wired dispatch branch to call run_moe_comm_test, and added routine-specific arg parsing.
Benchmark utilities / schemas
benchmarks/routines/flashinfer_benchmark_utils.py
Added output_column_dict["moe_comm"], included it in full_output_columns, and registered benchmark_apis["moe_comm"] with endpoint moe_a2a_dispatch_combine.
MoE communication benchmark module
benchmarks/routines/moe_comm.py
New module implementing MPI/device setup, FP8/NVFP4 quantize/dequantize, input generation, fake-MoE validation, dispatch/combine benchmarking, per-phase timing, NVTX hooks; exposes run_moe_comm_test, parse_moe_comm_args, and test_moe_a2a_dispatch_combine. Attention: MPI/device mapping, quantization math, synchronization, and bandwidth calculations.
Tests — fake MoE signature
tests/comm/test_trtllm_moe_alltoall.py
Added full Python type annotations and docstring to fake_moe() signature; implementation unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Runner as Benchmark Runner
    participant MPI
    participant Device as CUDA Device
    participant Quant as Quantization
    participant Moe as MoeAlltoAll
    participant Net as Network

    User->>Runner: start moe_comm with args
    Runner->>MPI: init communicator
    MPI-->>Runner: rank, world_size
    Runner->>Device: set CUDA device / init
    Runner->>Runner: build inputs (hidden_states, token_experts)
    alt validation enabled
        Runner->>Quant: quantize inputs (FP8 / NVFP4)
        Quant-->>Runner: quantized tensors
        Runner->>Moe: dispatch
        Moe->>Net: all-to-all exchange
        Net-->>Moe: exchanged shards
        Moe-->>Runner: combine output
        Runner->>Quant: dequantize
        Runner->>Runner: validate vs fake_moe
    end
    loop benchmarking iterations
        Runner->>Quant: quantize
        Runner->>Moe: dispatch + combine
        Moe->>Net: all-to-all
        Device-->>Runner: record CUDA events / timings
    end
    Runner->>Runner: compute median/std/bandwidth
    Runner-->>User: print results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • cyx-6
  • jiahanc
  • bkryu
  • yzh119
  • kahyunnam
  • aleozlx

Poem

🐰 Hopping through tensors, I quantify and roam,

Dispatching carrots from rank to rank, back home.
FP8 whispers while NVFP4 hums a tune,
MPI drums softly under the moon.
Benchmark carrots—bandwidth blooms soon!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — Passed: The title 'chore: Add TRTLLM MoE A2A benchmark' clearly and directly describes the main change: adding a new benchmark for MoE A2A communication kernels.
  • Description check — Passed: The description includes a clear summary of changes, example usage commands, output examples, and notes that pre-commit checks were completed, though the test-related checkbox remains unchecked.
  • Docstring coverage — Passed: Docstring coverage is 82.61%, which is sufficient; the required threshold is 80.00%.




@gemini-code-assist
Contributor

Summary of Changes

Hello @rosenrodt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a comprehensive benchmark for Mixture-of-Experts (MoE) All-to-All communication operations within the FlashInfer framework. The new benchmark allows for detailed performance analysis of the dispatch and combine phases, supporting various quantization schemes like FP8 and NVFP4. It also includes a robust validation step to verify the functional correctness of the communication patterns, ensuring reliable performance measurements for distributed MoE models.

Highlights

  • New MoE All-to-All Communication Benchmark: Added a new benchmark for Mixture-of-Experts (MoE) All-to-All communication, specifically focusing on the dispatch_combine routine to measure end-to-end performance.
  • Quantization Support: Introduced support for benchmarking MoE communication with FP8 and NVFP4 quantization, allowing performance evaluation under different precision settings and data formats.
  • Validation Mechanism: Implemented a robust validation mechanism within the benchmark to ensure the correctness of the MoE All-to-All operations, including handling of quantized data and comparing results against a deterministic fake MoE transformation.
  • Integration into Benchmark Framework: Integrated the new MoE communication routines and argument parsing into the main flashinfer_benchmark.py script and utility functions, making it accessible through the existing benchmarking infrastructure.
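
The per-tensor FP8 path in the highlights boils down to deriving one scale from the tensor's absolute maximum. A minimal plain-Python sketch of that idea (the function name is illustrative; the benchmark itself works on torch tensors and clamps amax the same way):

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3

def per_tensor_fp8_scale(values: list[float], eps: float = 1e-6) -> float:
    """Map the tensor's absolute maximum onto the FP8 dynamic range.

    Quantize:   q = x / scale, cast to fp8
    Dequantize: x ≈ q * scale
    """
    amax = max((abs(v) for v in values), default=0.0)
    return max(amax, eps) / FP8_E4M3_MAX  # clamp amax so scale is never zero

scale = per_tensor_fp8_scale([0.25, -448.0, 3.0])
```

Because the amax here equals the FP8 maximum, the derived scale is 1.0 and the values round-trip with no range compression.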


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new benchmark for TRTLLM MoE All-to-All (A2A) communication, which is a significant addition. The new benchmark routine is well-structured, covering various configurations including quantization (FP8, NVFP4) and providing robust validation logic. The code is comprehensive and includes detailed setup for MPI, performance timing, and correctness checks. My review identifies a couple of areas for improvement: one to correct a function's return type hint for better code clarity and static analysis, and another to fix a potential UnboundLocalError in the results reporting logic to prevent crashes under certain configurations.

@rosenrodt rosenrodt marked this pull request as ready for review January 15, 2026 02:23
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 72-81: The output schema list for "moe_comm" uses "quant_format"
but the producer code in moe_comm.py writes the field as cur_res["quant_dtype"],
causing the CSV column to be empty; update the schema entry in the "moe_comm"
list (replace "quant_format" with "quant_dtype") so the schema name matches the
producer, and verify any CSV/header generation uses the same symbol
(quant_dtype) to ensure values from cur_res["quant_dtype"] are written.
🧹 Nitpick comments (6)
tests/comm/test_trtllm_moe_alltoall.py (1)

442-446: Minor style difference with duplicated fake_moe in moe_comm.py.

This condition uses a slightly different style than the copy in benchmarks/routines/moe_comm.py (lines 651-655), which uses chained comparison (ep_rank * num_experts_per_rank <= expert_id < ...). Both are functionally equivalent, but consider aligning them for consistency since the moe_comm.py version explicitly notes it was copied from this file.
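
The two styles mentioned are indeed equivalent; a small sketch with hypothetical helper names makes the equivalence concrete:

```python
def expert_on_rank_chained(expert_id: int, ep_rank: int, num_experts_per_rank: int) -> bool:
    # chained-comparison style, as in benchmarks/routines/moe_comm.py
    return ep_rank * num_experts_per_rank <= expert_id < (ep_rank + 1) * num_experts_per_rank

def expert_on_rank_two_clause(expert_id: int, ep_rank: int, num_experts_per_rank: int) -> bool:
    # two-clause `and` style, as in the test file; functionally identical
    return (expert_id >= ep_rank * num_experts_per_rank) and (
        expert_id < (ep_rank + 1) * num_experts_per_rank
    )
```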

benchmarks/routines/moe_comm.py (5)

309-318: Potential issue with tensor comparison on line 312.

When scale is passed as a tensor (e.g., torch.tensor([0.0])), the comparison scale != 0.0 returns a tensor, not a boolean. This could cause unexpected behavior. However, in the current usage, scale is either None (computed fresh with clamped amax >= 1e-6) or passed from MPI sync where it's also non-zero, so this is unlikely to trigger in practice.

Suggested defensive fix
-    inv_scale = 1.0 / scale if scale != 0.0 else 0.0
+    inv_scale = 1.0 / scale if scale.item() != 0.0 else 0.0

607-669: Code duplication with test file - consider future refactor.

The fake_moe function is duplicated from tests/comm/test_trtllm_moe_alltoall.py. While this is acceptable for now (and acknowledged with a comment), consider extracting to a shared utility module if this pattern continues to grow.


676-676: Unused parameter: token_final_scales.

The token_final_scales parameter is passed but never used in the validation logic. Consider removing it or prefixing with underscore if kept for API consistency.

Proposed fix
-    token_final_scales: torch.Tensor,
+    _token_final_scales: torch.Tensor,  # Unused, kept for API consistency

Or remove from function signature and call sites if not needed for future use.


891-891: Unused variable: local_rank.

The local_rank return value is not used after unpacking. Prefix with underscore to indicate intentional discard.

Proposed fix
-    comm, rank, world_size, local_rank = _setup_mpi_and_device()
+    comm, rank, world_size, _local_rank = _setup_mpi_and_device()

949-956: Unused variable: scale_factor.

The scale_factor is unpacked but not used directly since it's already included in input_payloads. Prefix with underscore.

Proposed fix
     (
         hidden_states,
         hidden_states_original,
         token_selected_experts,
         token_final_scales,
-        scale_factor,
+        _scale_factor,
         global_scale,
         input_payloads,
     ) = _create_moe_inputs(
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0277fd and b43a786.

📒 Files selected for processing (4)
  • benchmarks/flashinfer_benchmark.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/moe_comm.py
  • tests/comm/test_trtllm_moe_alltoall.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_moe_alltoall.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • benchmarks/routines/flashinfer_benchmark_utils.py
🧬 Code graph analysis (2)
benchmarks/flashinfer_benchmark.py (1)
benchmarks/routines/moe_comm.py (2)
  • parse_moe_comm_args (172-263)
  • run_moe_comm_test (155-169)
benchmarks/routines/moe_comm.py (5)
flashinfer/comm/trtllm_moe_alltoall.py (4)
  • MoeAlltoAll (393-732)
  • dispatch (596-653)
  • get_combine_payload_tensor_in_workspace (699-732)
  • combine (656-696)
flashinfer/comm/mapping.py (2)
  • Mapping (21-475)
  • local_rank (391-392)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (1508-1655)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
  • dtype_str_to_torch_dtype (159-173)
flashinfer/fused_moe/utils.py (1)
  • _ (157-163)
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py

169-169: Avoid specifying long messages outside the exception class

(TRY003)


676-676: Unused function argument: token_final_scales

(ARG001)


891-891: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


902-902: Do not catch blind exception: Exception

(BLE001)


954-954: Unpacked variable scale_factor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (16)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

123-125: LGTM - New moe_comm API group registered correctly.

The moe_a2a_dispatch_combine routine is properly registered under the new moe_comm category. Note that this routine doesn't appear in routine_cc_to_supported_backends, which is acceptable since the MoE A2A benchmark uses MPI-based multi-GPU setup and doesn't rely on the same backend filtering mechanism as other routines.

benchmarks/flashinfer_benchmark.py (4)

12-12: LGTM - Import follows existing pattern.

The import for the new moe_comm module follows the same pattern as other routine imports (attention, gemm, moe).


30-31: LGTM - Routing correctly integrated.

The routing for moe_comm routines is correctly added following the existing dispatch pattern.


69-72: LGTM - Routine choices extended correctly.

The --routine argument choices now include moe_comm routines, maintaining consistency with how other routine types are registered.


163-164: LGTM - Argument parsing branch added correctly.

The argument parsing for moe_comm routines follows the established pattern.

tests/comm/test_trtllm_moe_alltoall.py (1)

398-423: LGTM - Type annotations and docstring improve clarity.

The added type hints and comprehensive docstring improve maintainability and document the deterministic scaling logic used for validation.

benchmarks/routines/moe_comm.py (10)

16-66: LGTM - Comprehensive documentation with launch examples.

The module docstring provides clear usage examples for various configurations (basic, FP8, NVFP4, validation, per-phase timing, NVTX profiling), which will be helpful for users.


93-126: LGTM - Well-designed context manager for deferred timing.

The cuda_event_timer correctly implements deferred timing by collecting events without synchronization, allowing efficient batched timing measurement.
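
The deferred pattern can be illustrated with a CPU analog (names are illustrative; the real context manager records `torch.cuda.Event` pairs and synchronizes once before reading elapsed times):

```python
import time
from contextlib import contextmanager

class DeferredTimer:
    """CPU analog of the deferred CUDA-event pattern: collect (start, stop)
    pairs cheaply inside the hot loop, reduce to durations only at the end."""

    def __init__(self) -> None:
        self._pairs: list[tuple[float, float]] = []

    @contextmanager
    def region(self):
        start = time.perf_counter()
        try:
            yield
        finally:
            # record only; no per-iteration reduction or synchronization
            self._pairs.append((start, time.perf_counter()))

    def elapsed_ms(self) -> list[float]:
        return [(stop - start) * 1e3 for start, stop in self._pairs]

timer = DeferredTimer()
for _ in range(3):
    with timer.region():
        sum(range(1000))  # stand-in for the dispatch/combine work
durations = timer.elapsed_ms()
```

Keeping the hot loop free of reductions is what makes the per-phase timing cheap enough to leave enabled during benchmarking.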


155-169: LGTM - Routing function follows established pattern.

The run_moe_comm_test function correctly routes to the appropriate test implementation and includes a TODO for future expansion.


172-263: LGTM - Argument parsing is well-structured.

The argument parser correctly handles all required and optional parameters with sensible defaults.


266-291: LGTM - MPI and device setup follows standard patterns.

The _setup_mpi_and_device correctly uses MPI.COMM_TYPE_SHARED for node-local rank detection, which is the proper way to handle multi-node multi-GPU setups.


427-534: LGTM - Input creation handles quantization modes correctly.

The function properly handles three modes (no quantization, FP8, NVFP4) and correctly synchronizes the global scale across MPI ranks. The return type hint now correctly specifies all 7 return values.


537-604: LGTM - Bandwidth calculation correctly accounts for quantization and phases.

The bandwidth calculation properly handles different quantization formats and the asymmetric dispatch/combine data sizes (quantized dispatch, activation-dtype combine).


900-905: Broad exception catch is acceptable here.

The broad except Exception catch for MNNVL initialization is acceptable since MNNVL can fail for various platform-specific reasons (missing hardware support, insufficient permissions, etc.). The error is logged and handled gracefully.


1100-1103: Past review issue resolved - per-phase metrics properly initialized.

The per-phase timing metrics are now correctly initialized to np.nan before the conditional block, preventing the UnboundLocalError that was flagged in a previous review.


1158-1180: Note: Per-phase metrics stored but not in CSV output schema.

The result dictionary includes per-phase metrics (dispatch_time, combine_time, etc.) which are useful for debugging, but these fields are not in the output_column_dict["moe_comm"] schema, so they won't appear in the CSV output. This is acceptable if console output is sufficient; otherwise, consider extending the schema.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 285-291: The function _calculate_fp4_global_scale returns a Python
float when tensor_amax == 0.0 which violates the annotated torch.Tensor return
type; change the implementation so both branches return a torch.Tensor (e.g.,
create a torch.tensor(0.0, dtype=torch.float32, device=tensor.device) for the
zero case and ensure the non-zero branch constructs a torch.tensor from the
computed value with matching dtype/device), and keep using
tensor.abs().max().to(torch.float32) (tensor_amax) to determine the branch.
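
The shape of the fix can be sketched in plain Python; the (E2M1 max × E4M3 max) / amax formula is an assumption based on common NVFP4 conventions, and in the actual fix both branches would return a float32 torch.Tensor on the input's device rather than a bare float:

```python
FP4_E2M1_MAX = 6.0    # largest magnitude of the NVFP4 (E2M1) element format
FP8_E4M3_MAX = 448.0  # largest magnitude of the E4M3 scale-factor format

def fp4_global_scale(amax: float) -> float:
    """Both branches return the same type, fixing the annotation mismatch
    (the original returned a Python float only in the amax == 0 case)."""
    if amax == 0.0:
        return 0.0
    return (FP4_E2M1_MAX * FP8_E4M3_MAX) / amax
```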
♻️ Duplicate comments (1)
benchmarks/routines/moe_comm.py (1)

1218-1221: Previous concern addressed.

The per-phase metrics are now properly initialized with np.nan defaults before the conditional block, preventing UnboundLocalError when enable_per_phase_timing is False.

🧹 Nitpick comments (2)
benchmarks/routines/moe_comm.py (2)

773-790: Consider prefixing unused parameter with underscore.

token_final_scales is passed for API consistency but not used in the validation logic. Prefixing with _ would silence the linter and signal intentional non-use.

 def _validate_moe_a2a(
     moe_a2a: MoeAlltoAll,
     hidden_states: torch.Tensor,
     hidden_states_original: torch.Tensor,
     token_selected_experts: torch.Tensor,
-    token_final_scales: torch.Tensor,
+    _token_final_scales: torch.Tensor,
     input_payloads: List[torch.Tensor],

991-996: Unused variables: consider prefixing with underscore.

local_rank at line 993 and scale_factor at line 1056 are unpacked but never used. The device is already set inside _setup_mpi_and_device(). Prefixing with _ signals intentional non-use.

-    comm, rank, world_size, local_rank = _setup_mpi_and_device()
+    comm, rank, world_size, _local_rank = _setup_mpi_and_device()
-        scale_factor,
+        _scale_factor,
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 90517cb and 109e6ba.

📒 Files selected for processing (1)
  • benchmarks/routines/moe_comm.py
🧰 Additional context used
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py

169-169: Avoid specifying long messages outside the exception class

(TRY003)


778-778: Unused function argument: token_final_scales

(ARG001)


993-993: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1004-1004: Do not catch blind exception: Exception

(BLE001)


1056-1056: Unpacked variable scale_factor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (8)
benchmarks/routines/moe_comm.py (8)

93-136: LGTM!

The context managers are well-designed with proper try/finally handling and conditional execution support. The deferred timing pattern (recording events without immediate sync) is efficient for batched benchmarking.


427-534: LGTM!

The input creation logic is comprehensive. The MPI-synced global scale approach (using allreduce with MPI.MAX) correctly ensures all ranks use a consistent scale factor that accommodates the maximum range across all ranks. The return type annotation now correctly reflects all 7 return values.


537-624: LGTM!

The traffic calculation correctly accounts for the MoE A2A communication pattern: tokens are sent once per unique destination rank even when selecting multiple experts on that rank. The per-token iteration is appropriate for accurate benchmarking setup.


627-706: LGTM!

The bandwidth calculation properly supports both exact traffic (from actual routing analysis) and estimated traffic (uniform distribution fallback). The formula correctly converts milliseconds to seconds and bytes to terabytes.
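
That unit conversion is compact enough to state directly (a hypothetical helper, not the benchmark's function; it uses decimal terabytes to match the TB/sec in the example output):

```python
def aggregate_bandwidth_tb_per_sec(total_bytes: float, median_time_ms: float) -> float:
    """Total bytes moved by all participating ranks over a median time in
    milliseconds, expressed as aggregate TB/s."""
    return (total_bytes / 1e12) / (median_time_ms / 1e3)

bw = aggregate_bandwidth_tb_per_sec(1e12, 500.0)  # 1 TB moved in 0.5 s
```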


710-770: LGTM!

The deterministic fake MoE transformation is well-designed for validation. Using torch.float32 for the summation (line 767) reduces numerical errors during the round-trip verification.


1162-1173: LGTM!

The benchmark correctly uses bench_gpu_time with appropriate settings: CUDA graphs are disabled when per-phase timing is enabled (since events are inserted mid-execution), and cold L2 cache simulation provides realistic performance numbers.


1195-1216: LGTM!

Using the maximum time across ranks for each iteration correctly reflects the synchronous nature of collective communication—the operation completes only when the slowest rank finishes.


1002-1007: The broad Exception catch is appropriate and follows the established pattern in the codebase. The MnnvlMemory.initialize() method can raise multiple exception types (torch exceptions, pynvml exceptions) beyond just RuntimeError, and narrowing the catch would miss valid failure scenarios. Keep as-is.

Likely an incorrect or invalid review comment.


Collaborator

@aleozlx aleozlx left a comment


lgtm

@aleozlx
Collaborator

aleozlx commented Jan 15, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !243 has been created, and the CI pipeline #41825007 is currently running. I'll report back once the pipeline job completes.

@rosenrodt rosenrodt force-pushed the chore-moe-a2a-bench branch from 109e6ba to c33661e on January 16, 2026 13:40
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 1011-1049: Validate input before constructing Mapping and
MoeAlltoAll: check that hidden_size % 16 == 0 and if not log an error (use
rank==0) and return res; check that top_k <= num_experts and return with error
if violated; move the all_num_tokens = comm.allgather(num_tokens) and
runtime_max_tokens_per_rank = max(all_num_tokens) up before creating MoeAlltoAll
and validate that max_num_tokens >= runtime_max_tokens_per_rank (log error and
return if not). Reference the variables/objects hidden_size, top_k, num_experts,
all_num_tokens, runtime_max_tokens_per_rank, max_num_tokens, Mapping, and
MoeAlltoAll when making these checks.
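
The suggested guards can be sketched as plain checks; the function name and the log-and-return-an-error-list shape are illustrative (the benchmark prints on rank 0 and returns early), and the hidden_size % 16 threshold comes from this review comment rather than a verified kernel constraint:

```python
def validate_moe_comm_args(
    hidden_size: int,
    top_k: int,
    num_experts: int,
    ep_size: int,
    max_num_tokens: int,
    runtime_max_tokens_per_rank: int,
) -> list[str]:
    """Pre-flight checks to run before constructing Mapping / MoeAlltoAll."""
    errors = []
    if num_experts % ep_size != 0:
        errors.append(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size})")
    if hidden_size % 16 != 0:
        errors.append(f"hidden_size ({hidden_size}) must be a multiple of 16")
    if top_k > num_experts:
        errors.append(f"top_k ({top_k}) must not exceed num_experts ({num_experts})")
    if max_num_tokens < runtime_max_tokens_per_rank:
        errors.append(
            f"max_num_tokens ({max_num_tokens}) is below the runtime max "
            f"tokens per rank ({runtime_max_tokens_per_rank})"
        )
    return errors
```

Collecting all violations at once (instead of failing on the first) gives users a complete picture before they relaunch an expensive multi-rank job.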
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 109e6ba and c33661e.

📒 Files selected for processing (4)
  • benchmarks/flashinfer_benchmark.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
  • benchmarks/routines/moe_comm.py
  • tests/comm/test_trtllm_moe_alltoall.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/comm/test_trtllm_moe_alltoall.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API

Applied to files:

  • benchmarks/flashinfer_benchmark.py
  • benchmarks/routines/flashinfer_benchmark_utils.py
🧬 Code graph analysis (1)
benchmarks/flashinfer_benchmark.py (1)
benchmarks/routines/moe_comm.py (2)
  • parse_moe_comm_args (172-263)
  • run_moe_comm_test (155-169)
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py

169-169: Avoid specifying long messages outside the exception class

(TRY003)


779-779: Unused function argument: token_final_scales

(ARG001)


994-994: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


1005-1005: Do not catch blind exception: Exception

(BLE001)


1057-1057: Unpacked variable scale_factor is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/flashinfer_benchmark_utils.py (1)

68-146: Clean integration of moe_comm schema and API group.

Nice, straightforward extension of the benchmark schema and API registry.

benchmarks/routines/moe_comm.py (6)

93-283: Solid helper utilities and argument parsing setup.

Context managers, ordered printing, and MPI/device setup are clean and readable.


285-426: Quantization helpers are well-scoped and consistent.

The FP8/NVFP4 quantize/dequantize flow is clear and self-contained.


428-535: Input construction flow is clear and deterministic.

The MPI-synced global scale and payload assembly look good for benchmarking.


538-707: Traffic and bandwidth calculations are well-documented.

Nice separation between exact routing-based traffic and uniform estimates.


774-977: Validation path is comprehensive and well-guarded.

The round-trip compare and rank-ordered error reporting are robust.


1071-1306: Benchmark timing/aggregation flow looks solid.

Max-over-ranks aggregation plus actual-traffic bandwidth reporting is a good choice for synchronous collectives.

benchmarks/flashinfer_benchmark.py (1)

12-176: moe_comm routing and arg parsing integration look correct.

Routine dispatch and CLI plumbing are consistent with existing patterns.

tests/comm/test_trtllm_moe_alltoall.py (1)

398-457: Type annotations and docstring are a nice clarity boost.

No functional changes; the signature now aligns with the benchmark helper.


Comment on lines +1011 to +1049
num_tokens = args.num_tokens
hidden_size = args.hidden_size
num_experts = args.num_experts
top_k = args.top_k
max_num_tokens = args.max_num_tokens
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
quant_dtype = args.quant_dtype

res = []

# Validate parameters
if num_experts % ep_size != 0:
if rank == 0:
print(
f"[ERROR] num_experts ({num_experts}) must be divisible by ep_size ({ep_size})"
)
return res

# Create mapping
mapping = Mapping(
rank=rank,
tp_size=ep_size,
moe_ep_size=ep_size,
world_size=world_size,
)

# Create MoeAlltoAll instance
moe_a2a = MoeAlltoAll(
mapping=mapping,
max_num_tokens=max_num_tokens,
top_k=top_k,
num_experts=num_experts,
hidden_size=hidden_size,
)

# Synchronize all_num_tokens across ranks
all_num_tokens = comm.allgather(num_tokens)
runtime_max_tokens_per_rank = max(all_num_tokens)
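
The allgather-then-max step above can be sketched single-process, with the `comm.allgather` result replaced by a precomputed per-rank list (names are illustrative); the guard reflects that the pre-allocated workspace must cover the largest per-rank token count:

```python
def runtime_max_tokens(all_num_tokens: list[int], max_num_tokens: int) -> int:
    """Mirror of the allgather-then-max step: every rank learns the
    largest per-rank token count, which must fit the workspace."""
    runtime_max = max(all_num_tokens)
    if max_num_tokens < runtime_max:
        raise ValueError(
            f"max_num_tokens ({max_num_tokens}) must be >= "
            f"runtime max tokens per rank ({runtime_max})"
        )
    return runtime_max


# Example: three ranks report 1024, 8192, and 4096 tokens respectively.
print(runtime_max_tokens([1024, 8192, 4096], max_num_tokens=8192))  # 8192
```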

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

rg -n "fp4_quantize|nvfp4|block_size|sf_vec_size|E2M1" -i --type-list | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 457


🏁 Script executed:

# Search for NVFP4/FP4 related code and constraints
rg -n "fp4|nvfp4" -i -A 3 -B 3 --max-count=50


🏁 Script executed:

# Look for block size or shape constraints related to quantization
rg -n "hidden_size.*16|16.*hidden|block.*size" -A 2 -B 2


🏁 Script executed:

# Check the file mentioned in the review
head -50 benchmarks/routines/moe_comm.py
wc -l benchmarks/routines/moe_comm.py


🏁 Script executed:

# Check the specific lines mentioned in the review
sed -n '1000,1060p' benchmarks/routines/moe_comm.py


🏁 Script executed:

# Find MoeAlltoAll class implementation and validation
rg -n "class MoeAlltoAll|def __init__" -A 10 | grep -A 10 "MoeAlltoAll"


🏁 Script executed:

# Search for top_k validation in the codebase
rg -n "top_k.*num_experts|top_k.*>" --type py | head -20


🏁 Script executed:

# Look for max_num_tokens validation or assertions
rg -n "max_num_tokens" -B 2 -A 2 --type py | head -40


🏁 Script executed:

# Check if there are runtime errors related to these constraints in tests
rg -n "top_k|max_num_tokens" tests/moe/ --type py | head -30


🏁 Script executed:

# Check MoeAlltoAll implementation for validations
sed -n '508,600p' flashinfer/comm/trtllm_moe_alltoall.py


🏁 Script executed:

# Check what the test_mnnvl_moe_alltoall.py test does with max_num_tokens
sed -n '285,310p' tests/comm/test_mnnvl_moe_alltoall.py


🏁 Script executed:

# Look at the specific function in moe_comm.py that needs validation
sed -n '990,1060p' benchmarks/routines/moe_comm.py


Add input validation to prevent NVFP4 shape errors and workspace under-allocation.

NVFP4 requires hidden_size to be divisible by 16 (block scale-factor alignment). Additionally, top_k > num_experts triggers runtime errors, and max_num_tokens < runtime_max_tokens_per_rank under-allocates the communication workspace.

🛠️ Proposed fix
     # Validate parameters
     if num_experts % ep_size != 0:
         if rank == 0:
             print(
                 f"[ERROR] num_experts ({num_experts}) must be divisible by ep_size ({ep_size})"
             )
         return res
+    if top_k > num_experts:
+        if rank == 0:
+            print(
+                f"[ERROR] top_k ({top_k}) must be <= num_experts ({num_experts})"
+            )
+        return res
+    if quant_dtype == "nvfp4" and hidden_size % 16 != 0:
+        if rank == 0:
+            print(
+                f"[ERROR] hidden_size ({hidden_size}) must be a multiple of 16 for nvfp4"
+            )
+        return res
@@
     runtime_max_tokens_per_rank = max(all_num_tokens)
+    if max_num_tokens < runtime_max_tokens_per_rank:
+        if rank == 0:
+            print(
+                f"[ERROR] max_num_tokens ({max_num_tokens}) must be >= runtime_max_tokens_per_rank ({runtime_max_tokens_per_rank})"
+            )
+        return res
🤖 Prompt for AI Agents
In `@benchmarks/routines/moe_comm.py` around lines 1011-1049: validate input
before constructing Mapping and MoeAlltoAll: check that hidden_size % 16 == 0
and if not log an error (use rank==0) and return res; check that top_k <=
num_experts and return with error if violated; move the all_num_tokens =
comm.allgather(num_tokens) and runtime_max_tokens_per_rank = max(all_num_tokens)
up before creating MoeAlltoAll and validate that max_num_tokens >=
runtime_max_tokens_per_rank (log error and return if not). Reference the
variables/objects hidden_size, top_k, num_experts, all_num_tokens,
runtime_max_tokens_per_rank, max_num_tokens, Mapping, and MoeAlltoAll when
making these checks.

@flashinfer-bot

[FAILED] Pipeline #41825007: 15/20 passed


@yzh119 yzh119 left a comment


Hopper CI failure doesn't look relevant.
The PR itself looks good to me, thanks @rosenrodt !

@yzh119 yzh119 merged commit 7108377 into flashinfer-ai:main Jan 18, 2026
8 checks passed
yzh119 pushed a commit that referenced this pull request Jan 21, 2026

## 📌 Description

**No changes to library code. Will not launch internal CI**

#2354 added moe_comm microbenchmarks to the benchmark harness. However,
the PR causes failures in non-MPI environments because
`flashinfer_benchmark.py` unconditionally imports every routine module at startup.

This PR refactors `flashinfer_benchmark.py` to use lazy imports for
routine modules instead of importing all modules at startup. This avoids
loading unnecessary dependencies when running specific benchmarks. Users
without MPI installed can now run all non-MPI benchmarks without import
errors.
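
The lazy-import pattern described above can be sketched as follows; the registry contents are hypothetical stand-ins (`json` for an always-available routine, `mpi4py.MPI` for the optional MPI-backed one), not the harness's actual mapping:

```python
import importlib

# Map routine names to backing modules; a module is imported only when
# its routine is actually requested, so a missing optional dependency
# (e.g. mpi4py) only fails its own routine instead of the whole CLI.
_ROUTINE_MODULES = {
    "json_routine": "json",
    "mpi_routine": "mpi4py.MPI",
}


def load_routine(name: str):
    """Import a routine's module on first use instead of at startup."""
    return importlib.import_module(_ROUTINE_MODULES[name])


# Only the requested routine's module gets imported:
print(load_routine("json_routine").dumps({"ok": True}))  # {"ok": true}
```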



## 🔍 Related Issues


## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes




## Summary by CodeRabbit

* **Chores**
* Optimized benchmark execution for improved efficiency with deferred
loading of benchmark routines, reducing unnecessary resource consumption
during startup.


