chore: Add TRTLLM MoE A2A benchmark #2354
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds a new MoE all-to-all communication benchmark routine.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Runner as Benchmark Runner
    participant MPI
    participant Device as CUDA Device
    participant Quant as Quantization
    participant Moe as MoeAlltoAll
    participant Net as Network
    User->>Runner: start moe_comm with args
    Runner->>MPI: init communicator
    MPI-->>Runner: rank, world_size
    Runner->>Device: set CUDA device / init
    Runner->>Runner: build inputs (hidden_states, token_experts)
    alt validation enabled
        Runner->>Quant: quantize inputs (FP8 / NVFP4)
        Quant-->>Runner: quantized tensors
        Runner->>Moe: dispatch
        Moe->>Net: all-to-all exchange
        Net-->>Moe: exchanged shards
        Moe-->>Runner: combine output
        Runner->>Quant: dequantize
        Runner->>Runner: validate vs fake_moe
    end
    loop benchmarking iterations
        Runner->>Quant: quantize
        Runner->>Moe: dispatch + combine
        Moe->>Net: all-to-all
        Device-->>Runner: record CUDA events / timings
    end
    Runner->>Runner: compute median/std/bandwidth
    Runner-->>User: print results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 3/3 passed
Summary of Changes

Hello @rosenrodt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a comprehensive benchmark for Mixture-of-Experts (MoE) all-to-all communication operations within the FlashInfer framework. The new benchmark allows detailed performance analysis of the dispatch and combine phases, supports quantization schemes such as FP8 and NVFP4, and includes a robust validation step that verifies the functional correctness of the communication patterns, ensuring reliable performance measurements for distributed MoE models.
Code Review
This pull request introduces a new benchmark for TRTLLM MoE All-to-All (A2A) communication, which is a significant addition. The new benchmark routine is well-structured, covering various configurations including quantization (FP8, NVFP4) and providing robust validation logic. The code is comprehensive and includes detailed setup for MPI, performance timing, and correctness checks. My review identifies a couple of areas for improvement: one to correct a function's return type hint for better code clarity and static analysis, and another to fix a potential UnboundLocalError in the results reporting logic to prevent crashes under certain configurations.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/flashinfer_benchmark_utils.py`:
- Around line 72-81: The output schema list for "moe_comm" uses "quant_format"
but the producer code in moe_comm.py writes the field as cur_res["quant_dtype"],
causing the CSV column to be empty; update the schema entry in the "moe_comm"
list (replace "quant_format" with "quant_dtype") so the schema name matches the
producer, and verify any CSV/header generation uses the same symbol
(quant_dtype) to ensure values from cur_res["quant_dtype"] are written.
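This kind of schema/producer key drift is easy to reproduce in isolation. The sketch below is a hypothetical minimal model (plain dicts and a `row_for` helper, not the actual benchmark utilities) of why the CSV column comes out empty:

```python
import csv
import io

def row_for(schema, result):
    """Build one CSV row; a schema key missing from the result dict yields an empty cell."""
    return [result.get(col, "") for col in schema]

# Producer writes "quant_dtype", but the schema says "quant_format".
schema = ["routine", "quant_format", "median_time_ms"]
cur_res = {"routine": "moe_a2a_dispatch_combine", "quant_dtype": "nvfp4", "median_time_ms": 1.23}

buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(schema)
writer.writerow(row_for(schema, cur_res))

assert row_for(schema, cur_res)[1] == ""  # mismatched key -> empty column

# Renaming the schema entry to match the producer fixes it.
fixed_schema = ["routine", "quant_dtype", "median_time_ms"]
assert row_for(fixed_schema, cur_res)[1] == "nvfp4"
```

The fix is purely a rename on the schema side; the producer dict stays untouched.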
🧹 Nitpick comments (6)
tests/comm/test_trtllm_moe_alltoall.py (1)
442-446: Minor style difference with duplicated `fake_moe` in `moe_comm.py`. This condition uses a slightly different style than the copy in `benchmarks/routines/moe_comm.py` (lines 651-655), which uses a chained comparison (`ep_rank * num_experts_per_rank <= expert_id < ...`). Both are functionally equivalent, but consider aligning them for consistency, since the `moe_comm.py` version explicitly notes it was copied from this file.

benchmarks/routines/moe_comm.py (5)
309-318: Potential issue with tensor comparison on line 312. When `scale` is passed as a tensor (e.g., `torch.tensor([0.0])`), the comparison `scale != 0.0` returns a tensor, not a Python boolean. This could cause unexpected behavior. However, in the current usage, `scale` is either None (computed fresh with clamped amax >= 1e-6) or passed from MPI sync where it is also non-zero, so this is unlikely to trigger in practice.

Suggested defensive fix

```diff
-    inv_scale = 1.0 / scale if scale != 0.0 else 0.0
+    inv_scale = 1.0 / scale if scale.item() != 0.0 else 0.0
```
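For intuition, here is a pure-Python sketch (no torch; `FP8_E4M3_MAX` and the helper names are illustrative assumptions, not the benchmark's actual code) of why the clamped amax makes the zero branch effectively unreachable:

```python
FP8_E4M3_MAX = 448.0  # assumed max magnitude of the float8_e4m3 format

def per_tensor_scale(amax: float, eps: float = 1e-6) -> float:
    # Clamping amax from below keeps the scale strictly positive.
    return max(amax, eps) / FP8_E4M3_MAX

def inv_scale(scale: float) -> float:
    # Defensive guard mirroring the benchmark's `scale != 0.0` check.
    return 1.0 / scale if scale != 0.0 else 0.0

# Even an all-zero tensor (amax == 0) yields a nonzero scale,
# so the else-branch of inv_scale never fires in practice.
s = per_tensor_scale(0.0)
assert s > 0.0 and inv_scale(s) > 0.0
assert inv_scale(0.0) == 0.0  # the guard still protects direct zero input
```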
607-669: Code duplication with test file - consider a future refactor. The `fake_moe` function is duplicated from `tests/comm/test_trtllm_moe_alltoall.py`. While this is acceptable for now (and acknowledged with a comment), consider extracting it to a shared utility module if this pattern continues to grow.
676-676: Unused parameter: `token_final_scales`. The `token_final_scales` parameter is passed but never used in the validation logic. Consider removing it or prefixing it with an underscore if kept for API consistency.

Proposed fix

```diff
-    token_final_scales: torch.Tensor,
+    _token_final_scales: torch.Tensor,  # Unused, kept for API consistency
```

Or remove it from the function signature and call sites if not needed for future use.
891-891: Unused variable: `local_rank`. The `local_rank` return value is not used after unpacking. Prefix it with an underscore to indicate an intentional discard.

Proposed fix

```diff
-    comm, rank, world_size, local_rank = _setup_mpi_and_device()
+    comm, rank, world_size, _local_rank = _setup_mpi_and_device()
```
949-956: Unused variable: `scale_factor`. The `scale_factor` is unpacked but not used directly since it is already included in `input_payloads`. Prefix it with an underscore.

Proposed fix

```diff
 (
     hidden_states,
     hidden_states_original,
     token_selected_experts,
     token_final_scales,
-    scale_factor,
+    _scale_factor,
     global_scale,
     input_payloads,
 ) = _create_moe_inputs(
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/flashinfer_benchmark.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/moe_comm.py
- tests/comm/test_trtllm_moe_alltoall.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/comm/test_trtllm_moe_alltoall.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
benchmarks/routines/flashinfer_benchmark_utils.py
🧬 Code graph analysis (2)
benchmarks/flashinfer_benchmark.py (1)
benchmarks/routines/moe_comm.py (2)
- `parse_moe_comm_args` (172-263)
- `run_moe_comm_test` (155-169)
benchmarks/routines/moe_comm.py (5)
flashinfer/comm/trtllm_moe_alltoall.py (4)
- `MoeAlltoAll` (393-732)
- `dispatch` (596-653)
- `get_combine_payload_tensor_in_workspace` (699-732)
- `combine` (656-696)

flashinfer/comm/mapping.py (2)
- `Mapping` (21-475)
- `local_rank` (391-392)

flashinfer/testing/utils.py (1)
- `bench_gpu_time` (1508-1655)

benchmarks/routines/flashinfer_benchmark_utils.py (1)
- `dtype_str_to_torch_dtype` (159-173)

flashinfer/fused_moe/utils.py (1)
- `_` (157-163)
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py
169-169: Avoid specifying long messages outside the exception class
(TRY003)
676-676: Unused function argument: token_final_scales
(ARG001)
891-891: Unpacked variable local_rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
902-902: Do not catch blind exception: Exception
(BLE001)
954-954: Unpacked variable scale_factor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (16)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
123-125: LGTM - New `moe_comm` API group registered correctly. The `moe_a2a_dispatch_combine` routine is properly registered under the new `moe_comm` category. Note that this routine doesn't appear in `routine_cc_to_supported_backends`, which is acceptable since the MoE A2A benchmark uses an MPI-based multi-GPU setup and doesn't rely on the same backend filtering mechanism as other routines.

benchmarks/flashinfer_benchmark.py (4)
12-12: LGTM - Import follows existing pattern. The import for the new `moe_comm` module follows the same pattern as other routine imports (attention, gemm, moe).

30-31: LGTM - Routing correctly integrated. The routing for `moe_comm` routines is correctly added following the existing dispatch pattern.

69-72: LGTM - Routine choices extended correctly. The `--routine` argument choices now include `moe_comm` routines, maintaining consistency with how other routine types are registered.

163-164: LGTM - Argument parsing branch added correctly. The argument parsing for `moe_comm` routines follows the established pattern.

tests/comm/test_trtllm_moe_alltoall.py (1)
398-423: LGTM - Type annotations and docstring improve clarity. The added type hints and comprehensive docstring improve maintainability and document the deterministic scaling logic used for validation.
benchmarks/routines/moe_comm.py (10)

16-66: LGTM - Comprehensive documentation with launch examples. The module docstring provides clear usage examples for various configurations (basic, FP8, NVFP4, validation, per-phase timing, NVTX profiling), which will be helpful for users.

93-126: LGTM - Well-designed context manager for deferred timing. The `cuda_event_timer` correctly implements deferred timing by collecting events without synchronization, allowing efficient batched timing measurement.
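The deferred pattern is easy to model on the CPU. The sketch below is a hypothetical analogue using `time.perf_counter` in place of `torch.cuda.Event` pairs; the real `cuda_event_timer` records events inside the loop and resolves `elapsed_time()` only after a final synchronize:

```python
import time
from contextlib import contextmanager

class DeferredTimer:
    """Collect (start, stop) pairs without resolving them inside the hot loop.

    CUDA analogue: record start/stop torch.cuda.Event pairs, then call
    elapsed_time() once after a single torch.cuda.synchronize() at the end.
    """
    def __init__(self):
        self._pairs = []

    @contextmanager
    def measure(self):
        start = time.perf_counter()      # CUDA analogue: start_event.record()
        try:
            yield
        finally:
            stop = time.perf_counter()   # CUDA analogue: stop_event.record()
            self._pairs.append((start, stop))

    def elapsed_ms(self):
        # Resolved once, after the loop -- the "deferred" part.
        return [(stop - start) * 1e3 for start, stop in self._pairs]

timer = DeferredTimer()
for _ in range(3):
    with timer.measure():
        sum(range(10_000))  # stand-in for dispatch/combine work
assert len(timer.elapsed_ms()) == 3
assert all(t >= 0.0 for t in timer.elapsed_ms())
```

Keeping the resolution step out of the measured region avoids perturbing the timings themselves.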
155-169: LGTM - Routing function follows established pattern. The `run_moe_comm_test` function correctly routes to the appropriate test implementation and includes a TODO for future expansion.

172-263: LGTM - Argument parsing is well-structured. The argument parser correctly handles all required and optional parameters with sensible defaults.

266-291: LGTM - MPI and device setup follows standard patterns. The `_setup_mpi_and_device` correctly uses `MPI.COMM_TYPE_SHARED` for node-local rank detection, which is the proper way to handle multi-node multi-GPU setups.
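As a model of what the shared-memory split computes, here is a pure-Python sketch (the `local_ranks` helper is hypothetical; the real code calls `comm.Split_type(MPI.COMM_TYPE_SHARED)` and reads the sub-communicator's rank):

```python
def local_ranks(hosts):
    """hosts[global_rank] -> hostname. Local rank = position among ranks
    on the same host, in global-rank order, which is what the
    MPI.COMM_TYPE_SHARED split produces."""
    seen = {}
    out = []
    for host in hosts:
        out.append(seen.get(host, 0))
        seen[host] = out[-1] + 1
    return out

# Two nodes with 4 GPUs each: local rank maps each global rank to a device id.
hosts = ["node0"] * 4 + ["node1"] * 4
assert local_ranks(hosts) == [0, 1, 2, 3, 0, 1, 2, 3]
```

The local rank is then a safe argument for `torch.cuda.set_device`, even in multi-node runs where global ranks exceed the per-node GPU count.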
427-534: LGTM - Input creation handles quantization modes correctly. The function properly handles three modes (no quantization, FP8, NVFP4) and correctly synchronizes the global scale across MPI ranks. The return type hint now correctly specifies all 7 return values.

537-604: LGTM - Bandwidth calculation correctly accounts for quantization and phases. The bandwidth calculation properly handles different quantization formats and the asymmetric dispatch/combine data sizes (quantized dispatch, activation-dtype combine).
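The time-to-bandwidth conversion involved can be sketched as follows (illustrative helper, not the benchmark's exact code):

```python
def bandwidth_tb_per_s(bytes_moved: float, time_ms: float) -> float:
    """TB/s = bytes / seconds / 1e12."""
    return bytes_moved / (time_ms * 1e-3) / 1e12

# Sanity check: 1 TB moved in one second is 1 TB/s by construction.
assert abs(bandwidth_tb_per_s(1e12, 1000.0) - 1.0) < 1e-12

# Example: 1 GiB exchanged in 0.5 ms
print(f"{bandwidth_tb_per_s(2**30, 0.5):.3f} TB/s")  # ~2.147 TB/s
```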
900-905: Broad exception catch is acceptable here. The broad `except Exception` catch for MNNVL initialization is acceptable since MNNVL can fail for various platform-specific reasons (missing hardware support, insufficient permissions, etc.). The error is logged and handled gracefully.

1100-1103: Past review issue resolved - per-phase metrics properly initialized. The per-phase timing metrics are now correctly initialized to `np.nan` before the conditional block, preventing the `UnboundLocalError` that was flagged in a previous review.

1158-1180: Note: Per-phase metrics stored but not in CSV output schema. The result dictionary includes per-phase metrics (`dispatch_time`, `combine_time`, etc.) which are useful for debugging, but these fields are not in the `output_column_dict["moe_comm"]` schema, so they won't appear in the CSV output. This is acceptable if console output is sufficient; otherwise, consider extending the schema.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 285-291: The function _calculate_fp4_global_scale returns a Python
float when tensor_amax == 0.0 which violates the annotated torch.Tensor return
type; change the implementation so both branches return a torch.Tensor (e.g.,
create a torch.tensor(0.0, dtype=torch.float32, device=tensor.device) for the
zero case and ensure the non-zero branch constructs a torch.tensor from the
computed value with matching dtype/device), and keep using
tensor.abs().max().to(torch.float32) (tensor_amax) to determine the branch.
♻️ Duplicate comments (1)
benchmarks/routines/moe_comm.py (1)
1218-1221: Previous concern addressed. The per-phase metrics are now properly initialized with `np.nan` defaults before the conditional block, preventing `UnboundLocalError` when `enable_per_phase_timing` is `False`.
🧹 Nitpick comments (2)
benchmarks/routines/moe_comm.py (2)
773-790: Consider prefixing unused parameter with underscore. `token_final_scales` is passed for API consistency but not used in the validation logic. Prefixing it with `_` would silence the linter and signal intentional non-use.

```diff
 def _validate_moe_a2a(
     moe_a2a: MoeAlltoAll,
     hidden_states: torch.Tensor,
     hidden_states_original: torch.Tensor,
     token_selected_experts: torch.Tensor,
-    token_final_scales: torch.Tensor,
+    _token_final_scales: torch.Tensor,
     input_payloads: List[torch.Tensor],
```
991-996: Unused variables: consider prefixing with underscore. `local_rank` at line 993 and `scale_factor` at line 1056 are unpacked but never used. The device is already set inside `_setup_mpi_and_device()`. Prefixing with `_` signals intentional non-use.

```diff
-    comm, rank, world_size, local_rank = _setup_mpi_and_device()
+    comm, rank, world_size, _local_rank = _setup_mpi_and_device()
```

```diff
-    scale_factor,
+    _scale_factor,
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchmarks/routines/moe_comm.py
🧰 Additional context used
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py
169-169: Avoid specifying long messages outside the exception class
(TRY003)
778-778: Unused function argument: token_final_scales
(ARG001)
993-993: Unpacked variable local_rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1004-1004: Do not catch blind exception: Exception
(BLE001)
1056-1056: Unpacked variable scale_factor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (8)
benchmarks/routines/moe_comm.py (8)
93-136: LGTM! The context managers are well-designed with proper try/finally handling and conditional execution support. The deferred timing pattern (recording events without immediate sync) is efficient for batched benchmarking.
427-534: LGTM! The input creation logic is comprehensive. The MPI-synced global scale approach (using `allreduce` with `MPI.MAX`) correctly ensures all ranks use a consistent scale factor that accommodates the maximum range across all ranks. The return type annotation now correctly reflects all 7 return values.
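The `MPI.MAX` sync can be modeled in pure Python (the `allreduce_max` helper is hypothetical; the real code calls `comm.allreduce(..., op=MPI.MAX)`):

```python
def allreduce_max(per_rank_values):
    """Model of comm.allreduce(x, op=MPI.MAX): every rank ends up
    with the same reduced value."""
    m = max(per_rank_values)
    return [m for _ in per_rank_values]

# Each rank computes a local amax; syncing with MAX guarantees one
# global scale input that covers the widest dynamic range on any rank.
local_amax = [3.5, 7.25, 1.0, 5.0]
synced = allreduce_max(local_amax)
assert synced == [7.25] * 4
assert len(set(synced)) == 1  # all ranks agree on the scale input
```

Agreement matters because a rank quantizing with a smaller scale than its peers would dequantize its received shards incorrectly.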
537-624: LGTM! The traffic calculation correctly accounts for the MoE A2A communication pattern: tokens are sent once per unique destination rank even when selecting multiple experts on that rank. The per-token iteration is appropriate for accurate benchmarking setup.

627-706: LGTM! The bandwidth calculation properly supports both exact traffic (from actual routing analysis) and estimated traffic (uniform distribution fallback). The formula correctly converts milliseconds to seconds and bytes to terabytes.

710-770: LGTM! The deterministic fake MoE transformation is well-designed for validation. Using `torch.float32` for the summation (line 767) reduces numerical errors during the round-trip verification.
1162-1173: LGTM! The benchmark correctly uses `bench_gpu_time` with appropriate settings: CUDA graphs are disabled when per-phase timing is enabled (since events are inserted mid-execution), and cold L2 cache simulation provides realistic performance numbers.

1195-1216: LGTM! Using the maximum time across ranks for each iteration correctly reflects the synchronous nature of collective communication - the operation completes only when the slowest rank finishes.
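The max-over-ranks aggregation can be sketched in plain Python (illustrative helper and timings, not the benchmark's actual data):

```python
from statistics import median, stdev

def collective_times(per_rank_iter_times_ms):
    """per_rank_iter_times_ms[rank][i] -> per-iteration max across ranks.
    A collective finishes only when the slowest rank does."""
    return [max(times) for times in zip(*per_rank_iter_times_ms)]

# 2 ranks x 3 iterations of hypothetical millisecond timings
per_rank = [
    [1.0, 1.2, 1.1],  # rank 0
    [1.3, 1.0, 1.4],  # rank 1
]
iter_times = collective_times(per_rank)
assert iter_times == [1.3, 1.2, 1.4]
print(median(iter_times), stdev(iter_times))  # median 1.3, stdev ~0.1
```

Taking the median after the per-iteration max keeps a single slow rank from being averaged away.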
1002-1007: The broad `Exception` catch is appropriate and follows the established pattern in the codebase. The `MnnvlMemory.initialize()` method can raise multiple exception types (torch exceptions, pynvml exceptions) beyond just `RuntimeError`, and narrowing the catch would miss valid failure scenarios. Keep as-is.

Likely an incorrect or invalid review comment.
/bot run
Force-pushed from 109e6ba to c33661e
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 1011-1049: Validate input before constructing Mapping and
MoeAlltoAll: check that hidden_size % 16 == 0 and if not log an error (use
rank==0) and return res; check that top_k <= num_experts and return with error
if violated; move the all_num_tokens = comm.allgather(num_tokens) and
runtime_max_tokens_per_rank = max(all_num_tokens) up before creating MoeAlltoAll
and validate that max_num_tokens >= runtime_max_tokens_per_rank (log error and
return if not). Reference the variables/objects hidden_size, top_k, num_experts,
all_num_tokens, runtime_max_tokens_per_rank, max_num_tokens, Mapping, and
MoeAlltoAll when making these checks.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- benchmarks/flashinfer_benchmark.py
- benchmarks/routines/flashinfer_benchmark_utils.py
- benchmarks/routines/moe_comm.py
- tests/comm/test_trtllm_moe_alltoall.py
🧰 Additional context used
📓 Path-based instructions (1)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/comm/test_trtllm_moe_alltoall.py
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to flashinfer/__init__.py : Export new operations in `flashinfer/__init__.py` to make them available as public API
Applied to files:
- benchmarks/flashinfer_benchmark.py
- benchmarks/routines/flashinfer_benchmark_utils.py
🧬 Code graph analysis (1)
benchmarks/flashinfer_benchmark.py (1)
benchmarks/routines/moe_comm.py (2)
- `parse_moe_comm_args` (172-263)
- `run_moe_comm_test` (155-169)
🪛 Ruff (0.14.11)
benchmarks/routines/moe_comm.py
169-169: Avoid specifying long messages outside the exception class
(TRY003)
779-779: Unused function argument: token_final_scales
(ARG001)
994-994: Unpacked variable local_rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1005-1005: Do not catch blind exception: Exception
(BLE001)
1057-1057: Unpacked variable scale_factor is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (9)
benchmarks/routines/flashinfer_benchmark_utils.py (1)
68-146: Clean integration of `moe_comm` schema and API group. Nice, straightforward extension of the benchmark schema and API registry.
benchmarks/routines/moe_comm.py (6)
93-283: Solid helper utilities and argument parsing setup. Context managers, ordered printing, and MPI/device setup are clean and readable.

285-426: Quantization helpers are well-scoped and consistent. The FP8/NVFP4 quantize/dequantize flow is clear and self-contained.

428-535: Input construction flow is clear and deterministic. The MPI-synced global scale and payload assembly look good for benchmarking.

538-707: Traffic and bandwidth calculations are well-documented. Nice separation between exact routing-based traffic and uniform estimates.

774-977: Validation path is comprehensive and well-guarded. The round-trip compare and rank-ordered error reporting are robust.

1071-1306: Benchmark timing/aggregation flow looks solid. Max-over-ranks aggregation plus actual-traffic bandwidth reporting is a good choice for synchronous collectives.
benchmarks/flashinfer_benchmark.py (1)
12-176: `moe_comm` routing and arg parsing integration look correct. Routine dispatch and CLI plumbing are consistent with existing patterns.
tests/comm/test_trtllm_moe_alltoall.py (1)
398-457: Type annotations and docstring are a nice clarity boost. No functional changes; the signature now aligns with the benchmark helper.
```python
num_tokens = args.num_tokens
hidden_size = args.hidden_size
num_experts = args.num_experts
top_k = args.top_k
max_num_tokens = args.max_num_tokens
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
quant_dtype = args.quant_dtype

res = []

# Validate parameters
if num_experts % ep_size != 0:
    if rank == 0:
        print(
            f"[ERROR] num_experts ({num_experts}) must be divisible by ep_size ({ep_size})"
        )
    return res

# Create mapping
mapping = Mapping(
    rank=rank,
    tp_size=ep_size,
    moe_ep_size=ep_size,
    world_size=world_size,
)

# Create MoeAlltoAll instance
moe_a2a = MoeAlltoAll(
    mapping=mapping,
    max_num_tokens=max_num_tokens,
    top_k=top_k,
    num_experts=num_experts,
    hidden_size=hidden_size,
)

# Synchronize all_num_tokens across ranks
all_num_tokens = comm.allgather(num_tokens)
runtime_max_tokens_per_rank = max(all_num_tokens)
```
🧩 Analysis chain
🏁 Script executed:

```shell
rg -n "fp4_quantize|nvfp4|block_size|sf_vec_size|E2M1" -i --type-list | head -20
```

Repository: flashinfer-ai/flashinfer
Length of output: 457

🏁 Script executed:

```shell
# Search for NVFP4/FP4 related code and constraints
rg -n "fp4|nvfp4" -i -A 3 -B 3 --max-count=50
```

Length of output: 50381

🏁 Script executed:

```shell
# Look for block size or shape constraints related to quantization
rg -n "hidden_size.*16|16.*hidden|block.*size" -A 2 -B 2
```

Length of output: 50381

🏁 Script executed:

```shell
# Check the file mentioned in the review
head -50 benchmarks/routines/moe_comm.py
wc -l benchmarks/routines/moe_comm.py
```

Length of output: 2229

🏁 Script executed:

```shell
# Check the specific lines mentioned in the review
sed -n '1000,1060p' benchmarks/routines/moe_comm.py
```

Length of output: 1664

🏁 Script executed:

```shell
# Find MoeAlltoAll class implementation and validation
rg -n "class MoeAlltoAll|def __init__" -A 10 | grep -A 10 "MoeAlltoAll"
```

Length of output: 1567

🏁 Script executed:

```shell
# Search for top_k validation in the codebase
rg -n "top_k.*num_experts|top_k.*>" --type py | head -20
```

Length of output: 2122

🏁 Script executed:

```shell
# Look for max_num_tokens validation or assertions
rg -n "max_num_tokens" -B 2 -A 2 --type py | head -40
```

Length of output: 2783

🏁 Script executed:

```shell
# Check if there are runtime errors related to these constraints in tests
rg -n "top_k|max_num_tokens" tests/moe/ --type py | head -30
```

Length of output: 2538

🏁 Script executed:

```shell
# Check MoeAlltoAll implementation for validations
sed -n '508,600p' flashinfer/comm/trtllm_moe_alltoall.py
```

Length of output: 3728

🏁 Script executed:

```shell
# Check what the test_mnnvl_moe_alltoall.py test does with max_num_tokens
sed -n '285,310p' tests/comm/test_mnnvl_moe_alltoall.py
```

Length of output: 795

🏁 Script executed:

```shell
# Look at the specific function in moe_comm.py that needs validation
sed -n '990,1060p' benchmarks/routines/moe_comm.py
```

Length of output: 1970
Add input validation to prevent NVFP4 shape errors and workspace under-allocation.
NVFP4 requires hidden_size divisible by 16 (block scale alignment). Additionally, top_k > num_experts and max_num_tokens < runtime_max_tokens_per_rank cause runtime errors or workspace under-sizing.
🛠️ Proposed fix
```diff
 # Validate parameters
 if num_experts % ep_size != 0:
     if rank == 0:
         print(
             f"[ERROR] num_experts ({num_experts}) must be divisible by ep_size ({ep_size})"
         )
     return res
+if top_k > num_experts:
+    if rank == 0:
+        print(
+            f"[ERROR] top_k ({top_k}) must be <= num_experts ({num_experts})"
+        )
+    return res
+if quant_dtype == "nvfp4" and hidden_size % 16 != 0:
+    if rank == 0:
+        print(
+            f"[ERROR] hidden_size ({hidden_size}) must be a multiple of 16 for nvfp4"
+        )
+    return res
@@
 runtime_max_tokens_per_rank = max(all_num_tokens)
+if max_num_tokens < runtime_max_tokens_per_rank:
+    if rank == 0:
+        print(
+            f"[ERROR] max_num_tokens ({max_num_tokens}) must be >= runtime_max_tokens_per_rank ({runtime_max_tokens_per_rank})"
+        )
+    return res
```

🤖 Prompt for AI Agents
In `@benchmarks/routines/moe_comm.py` around lines 1011 - 1049, Validate input
before constructing Mapping and MoeAlltoAll: check that hidden_size % 16 == 0
and if not log an error (use rank==0) and return res; check that top_k <=
num_experts and return with error if violated; move the all_num_tokens =
comm.allgather(num_tokens) and runtime_max_tokens_per_rank = max(all_num_tokens)
up before creating MoeAlltoAll and validate that max_num_tokens >=
runtime_max_tokens_per_rank (log error and return if not). Reference the
variables/objects hidden_size, top_k, num_experts, all_num_tokens,
runtime_max_tokens_per_rank, max_num_tokens, Mapping, and MoeAlltoAll when
making these checks.
[FAILED] Pipeline #41825007: 15/20 passed
yzh119 left a comment
Hopper CI failure doesn't look relevant.
The PR itself looks good to me, thanks @rosenrodt!
📌 Description

**No changes to library code. Will not launch internal CI.**

#2354 added moe_comm microbenchmarks to the benchmark harness. However, that PR causes failures in non-MPI environments because `flashinfer_benchmark.py` blindly imports routines for all modules. This PR refactors `flashinfer_benchmark.py` to use lazy imports for routine modules instead of importing all modules at startup. This avoids loading unnecessary dependencies when running specific benchmarks. Users without MPI installed can now run all non-MPI benchmarks without import errors.

🚀 Pull Request Checklist

✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Summary by CodeRabbit
- **Chores**: Optimized benchmark execution with deferred loading of benchmark routines, reducing unnecessary resource consumption during startup.
📌 Description
This PR adds a benchmark script for the MoE A2A kernels added in #2102.
Supports bf16, nvfp4, and per-tensor fp8.
A2A dispatch communicates in quantized format (fp8, nvfp4). A2A combine communicates in original format (bfloat16/float16).
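To make that asymmetry concrete, here is an illustrative per-token payload estimate (assumed byte sizes: fp8 = 1 byte/element, nvfp4 = half a byte/element plus one FP8 scale per 16-element block; a sketch, not the benchmark's exact accounting):

```python
def dispatch_bytes_per_token(hidden_size: int, quant_dtype: str) -> int:
    """Approximate on-wire bytes per token for the A2A dispatch phase."""
    if quant_dtype == "bf16":
        return hidden_size * 2  # unquantized 16-bit activations
    if quant_dtype == "fp8":
        return hidden_size      # 1 byte per element
    if quant_dtype == "nvfp4":
        return hidden_size // 2 + hidden_size // 16  # packed 4-bit + block scales
    raise ValueError(quant_dtype)

hidden_size = 7168
combine_bytes = hidden_size * 2  # combine always moves the original activation dtype
for q in ("bf16", "fp8", "nvfp4"):
    print(q, dispatch_bytes_per_token(hidden_size, q), "dispatch vs", combine_bytes, "combine")
```

With fp8 or nvfp4, the dispatch phase moves 2-3.5x fewer bytes than the combine phase for the same tokens.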
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Chores
Tests