chore/feat: A2A + MoE benchmark; add routed counterpart for trtllm_gen_fp8_fused_moe#2379
Conversation
📝 Walkthrough: This PR introduces comprehensive MoE A2A (All-to-All) benchmarking infrastructure with FP4/FP8 quantization support. It adds a new utilities module (`moe_utils.py`).

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Benchmark as MoE A2A<br/>Benchmark
    participant Utils as moe_utils<br/>Utilities
    participant Pack as Triton Packing<br/>(pack_topk_ids)
    participant Kernel as Kernel Launcher<br/>(FP8BlockScale)
    participant Backend as TRT-LLM<br/>Backend
    Benchmark->>Utils: generate_moe_weights()
    Utils-->>Benchmark: gemm1/gemm2 weights
    Benchmark->>Utils: quantize_fp4/quantize_fp8()
    Utils-->>Benchmark: quantized data + scales
    Benchmark->>Utils: compute_routing()
    Utils-->>Benchmark: routing_weights, selected_experts
    Benchmark->>Pack: pack_topk_ids_triton()
    Pack-->>Benchmark: packed expert_indices, expert_weights
    Benchmark->>Kernel: dispatch with packed routing<br/>(expert_indices, expert_weights)
    Kernel->>Kernel: check_routing()<br/>(use precomputed data)
    Kernel->>Backend: run FP8 block-scale MoE<br/>(hidden_states, scales, weights)
    Backend-->>Kernel: output tensors
    Kernel-->>Benchmark: results
    Benchmark->>Utils: calculate_moe_tflops()<br/>calculate_moe_kernel_bandwidth()
    Utils-->>Benchmark: performance metrics
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes

Hello @rosenrodt, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the benchmarking infrastructure for Mixture-of-Experts (MoE) models by integrating All-to-All communication with actual MoE kernel execution. It expands the supported quantization formats to include FP8 block-scale and introduces a new API for routed MoE, which streamlines the process of evaluating MoE layers when routing decisions are pre-determined. These changes provide more accurate and flexible tools for performance analysis and optimization of MoE architectures.

Highlights
Code Review
This pull request introduces support for A2A (All-to-All) communication and MoE (Mixture of Experts) benchmarking, specifically adding a routed counterpart for trtllm_gen_fp8_fused_moe. The changes involve refactoring common MoE utilities into a new moe_utils.py file, updating benchmark scripts to use these utilities, and extending the C++ kernel launcher to support pre-computed routing for FP8 block-scale MoE. The documentation in benchmarks/README.md has been updated to reflect the new MoE communication flags and examples. A new test case for trtllm_gen_fp8_routed_fused_moe has also been added to ensure correctness.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (6)
flashinfer/fused_moe/core.py (2)
1572-1635: Autotuner crashes when `routing_logits` is `None`.
`MoERunner.get_valid_tactics` reads `routing_logits.shape[0]`; routed calls pass `None`, so this will throw before any kernel launch. Use a meta placeholder for tuning inputs (as in the FP4 path).

🛠️ Proposed fix
```diff
-        inputs = [
-            output,
-            routing_logits,
-            topk_ids,
-            expert_weights,
-            hidden_states,
-            hidden_states_scale,
-        ]
+        routing_logits_for_tuning = (
+            routing_logits
+            if routing_logits is not None
+            else torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+        )
+        inputs = [
+            output,
+            routing_logits_for_tuning,
+            topk_ids,
+            expert_weights,
+            hidden_states,
+            hidden_states_scale,
+        ]
```
1689-1714: Silence Ruff ARG001 in the fake op signature.
Unused parameters can be prefixed with `_` to keep lint clean.

♻️ Proposed cleanup
```diff
-def _fake_trtllm_fp8_block_scale_moe(
-    routing_logits: Optional[torch.Tensor],
-    topk_ids: Optional[torch.Tensor],
-    expert_weights: Optional[torch.Tensor],
+def _fake_trtllm_fp8_block_scale_moe(
+    _routing_logits: Optional[torch.Tensor],
+    _topk_ids: Optional[torch.Tensor],
+    _expert_weights: Optional[torch.Tensor],
@@
-    tune_max_num_tokens: int = 8192,
+    _tune_max_num_tokens: int = 8192,
 ):
```

benchmarks/routines/moe_comm.py (1)
593-636: Exact traffic calc misses `fp8_block_scale`.
`_calculate_exact_comm_traffic` falls through to the unquantized path, so dispatch/combine bandwidth is overstated for fp8 block-scale payloads. Add a branch to include 1-byte activations plus float32 block scales (per 128 elements).

🛠️ Proposed fix
```diff
-        quant_dtype: None, "fp8", or "nvfp4"
+        quant_dtype: None, "fp8", "nvfp4", or "fp8_block_scale"
@@
     if quant_dtype == "nvfp4":
         # NVFP4: 0.5 bytes per element + block scales
         hidden_bytes_per_token = hidden_size // 2
         scale_bytes_per_token = (hidden_size // 16) * 1  # float8_e4m3fn
     elif quant_dtype == "fp8":
         # FP8: 1 byte per element
         hidden_bytes_per_token = hidden_size * 1
         scale_bytes_per_token = 0
+    elif quant_dtype == "fp8_block_scale":
+        # FP8 block scale: 1 byte per element + float32 scales per 128 elements
+        hidden_bytes_per_token = hidden_size * 1
+        scale_bytes_per_token = (hidden_size // 128) * 4  # float32
     else:
         # No quantization
         element_size = torch.tensor([], dtype=input_dtype).element_size()
         hidden_bytes_per_token = hidden_size * element_size
         scale_bytes_per_token = 0
```

flashinfer/fused_moe/__init__.py (1)
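To make the accounting concrete, the per-token payload sizes implied by the branch above can be sketched in plain Python. The helper name is illustrative (not the benchmark's actual API), and the block sizes follow the review comment: one 1-byte scale per 16 NVFP4 elements, one float32 scale per 128 FP8 elements.

```python
def dispatch_bytes_per_token(hidden_size, quant_dtype=None, elem_size=2):
    """Sketch of per-token A2A payload size (activation data + scales).

    `elem_size` is the unquantized element size in bytes (2 = bf16).
    """
    if quant_dtype == "nvfp4":
        # 4-bit data plus one fp8 (1-byte) scale per 16 elements
        return hidden_size // 2 + (hidden_size // 16) * 1
    if quant_dtype == "fp8":
        # 1 byte per element, no per-block scales
        return hidden_size * 1
    if quant_dtype == "fp8_block_scale":
        # 1 byte per element plus one float32 scale per 128 elements
        return hidden_size * 1 + (hidden_size // 128) * 4
    return hidden_size * elem_size  # unquantized


print(dispatch_bytes_per_token(7168, "fp8"))              # 7168
print(dispatch_bytes_per_token(7168, "fp8_block_scale"))  # 7392
print(dispatch_bytes_per_token(7168, "nvfp4"))            # 4032
```

For a DeepSeek-sized hidden dimension of 7168, the block scales add about 3% on top of plain FP8, which is exactly the overstatement the fix above removes.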
17-58: Add `trtllm_fp8_block_scale_routed_moe` to the top-level `flashinfer/__init__.py` exports.
The function is already exported from `flashinfer.fused_moe` but missing from the main API. Since `trtllm_fp4_block_scale_routed_moe` is already exported at the top level, `trtllm_fp8_block_scale_routed_moe` should be added to maintain API consistency and ensure users can access it via `flashinfer.trtllm_fp8_block_scale_routed_moe`.

csrc/trtllm_fused_moe_kernel_launcher.cu (2)
758-843: Validate precomputed `expert_weights` before wiring into the routing workspace.
When `expert_weights` is provided, missing shape/dtype checks can lead to OOB writes or incorrect routing weights. Add the same shape/dtype validation used for internally allocated buffers (expect `[num_tokens, top_k]` and `bfloat16`).

🔧 Suggested guard additions
```diff
   void check_routing() const override {
     // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
     if (expert_indices.ndim() == 2 && expert_indices.size(0) > 0) {
       // Pre-computed routing: expert_indices is a packed tensor
       // Format: (expert_id << 16) | (weight_bf16.view(int16))
       TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(0), hidden_states.size(0))
           << "expert_indices and hidden_states must have same number of tokens.";
       TVM_FFI_ICHECK_EQ(expert_indices.size(1), args->top_k)
           << "expert_indices dim1 must match top_k.";
       TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
     }
+
+    if (expert_weights.ndim() == 2 && expert_weights.size(0) > 0) {
+      TVM_FFI_ICHECK_EQ(expert_weights.size(0), hidden_states.size(0))
+          << "expert_weights and hidden_states must have same number of tokens.";
+      TVM_FFI_ICHECK_EQ(expert_weights.size(1), args->top_k)
+          << "expert_weights dim1 must match top_k.";
+      TVM_FFI_ICHECK_EQ(expert_weights.dtype(), dl_bfloat16)
+          << "expert_weights must be bfloat16.";
+    }
     FusedMoeLauncher::check_routing_common();
```
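The packed routing format referenced in the comment above, `(expert_id << 16) | (weight_bf16.view(int16))`, can be illustrated in plain Python. The helpers below are hypothetical names, and bfloat16 is modeled as float32 bit truncation (keeping only the top 16 bits of the encoding):

```python
import struct

def f32_to_bf16_bits(x):
    """Model bfloat16 as the top 16 bits of the float32 encoding (truncation)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_f32(bits):
    """Re-expand 16 bf16 bits back to a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def pack_entry(expert_id, weight):
    """One int32 per (token, slot): expert id in the high 16 bits,
    bf16 bits of the routing weight in the low 16 bits."""
    return (expert_id << 16) | f32_to_bf16_bits(weight)

def unpack_entry(packed):
    return packed >> 16, bf16_bits_to_f32(packed & 0xFFFF)

packed = pack_entry(42, 0.625)  # 0.625 is exactly representable in bf16
print(unpack_entry(packed))     # (42, 0.625)
```

This is why a single int32 `expert_indices` tensor can carry both the expert id and the routing weight, and why the launcher can recover bf16 `expert_weights` from it without a second dispatch payload.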
1741-1856: Add shape/dtype validation for FP4 `expert_indices`/`expert_weights` buffers.
These buffers are now caller-supplied and passed into the routing kernel. Without validation, a malformed tensor can cause OOB writes in the routing phase.

🔧 Suggested guard additions
```diff
   if (routing_logits.has_value()) {
     TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 ||
                    routing_logits.value().dtype() == dl_bfloat16)
         << "routing_logits must be float or bfloat16.";
     TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D.";
     TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts)
         << "routing_logits has incorrect shape.";
     if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
       TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
           << "routing_logits must be float.";
     }
   }
+
+  TVM_FFI_ICHECK_EQ(expert_indices.ndim(), 2) << "expert_indices must be 2D.";
+  TVM_FFI_ICHECK_EQ(expert_indices.size(0), num_tokens)
+      << "expert_indices dim0 must match num_tokens.";
+  TVM_FFI_ICHECK_EQ(expert_indices.size(1), top_k)
+      << "expert_indices dim1 must match top_k.";
+  TVM_FFI_ICHECK_EQ(expert_indices.dtype(), dl_int32) << "expert_indices must be int32.";
+
+  TVM_FFI_ICHECK_EQ(expert_weights.ndim(), 2) << "expert_weights must be 2D.";
+  TVM_FFI_ICHECK_EQ(expert_weights.size(0), num_tokens)
+      << "expert_weights dim0 must match num_tokens.";
+  TVM_FFI_ICHECK_EQ(expert_weights.size(1), top_k)
+      << "expert_weights dim1 must match top_k.";
+  TVM_FFI_ICHECK_EQ(expert_weights.dtype(), dl_bfloat16)
+      << "expert_weights must be bfloat16.";
```
🤖 Fix all issues with AI agents
In `@benchmarks/routines/moe_comm.py`:
- Around line 254-259: The comment and code disagree: instead of asserting
intermediate_size must be provided when args.real_math is true, implement the
documented default by detecting if args.intermediate_size is None and setting it
to args.hidden_size * 4 before the real_math/quant checks; then remove the
assert that forces a provided value (or adjust it to allow the default). Refer
to args.intermediate_size, args.real_math, and args.hidden_size in the
moe_comm.py block to locate and update the logic accordingly.
In `@benchmarks/routines/moe_utils.py`:
- Around line 453-491: The function calculate_moe_tflops currently accepts
num_experts but never uses it; either remove num_experts from the function
signature or mark it explicitly unused (e.g., rename to _num_experts or add a
comment/type ignore) to satisfy linters. Update all call sites if you remove the
parameter; if you choose to mark it unused, change the parameter name in
calculate_moe_tflops to _num_experts and add a short comment like "# unused -
kept for API compatibility" so the symbol is clearly identified and lint
warnings are suppressed.
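The reason `num_experts` can drop out of the signature follows from the math: only the `top_k` routed experts do work per token, so the total expert count never enters the FLOP total. A rough sketch of the computation (an assumed gated-MLP formula, not necessarily the repo's exact `calculate_moe_tflops`):

```python
def moe_tflops(num_tokens, hidden_size, intermediate_size, top_k, runtime_s):
    """Approximate MoE GEMM TFLOP/s for a gated-MLP expert.

    gemm1: [hidden] -> [2 * intermediate]  (gate + up projections)
    gemm2: [intermediate] -> [hidden]
    Each GEMM costs 2 * M * N FLOPs per token; num_experts does not appear
    because only top_k experts run per token.
    """
    flops_per_token = (
        2 * hidden_size * (2 * intermediate_size)  # gemm1
        + 2 * intermediate_size * hidden_size      # gemm2
    )
    return num_tokens * top_k * flops_per_token / runtime_s / 1e12

print(moe_tflops(num_tokens=1, hidden_size=4, intermediate_size=2,
                 top_k=1, runtime_s=1.0))  # 4.8e-11
```

Under this model, keeping `num_experts` in the signature is purely an API-compatibility choice, which is why renaming it to `_num_experts` is a reasonable lint fix.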
- Around line 717-770: The function quantize_and_pack_nvfp4 silently replaces
block_scales_reshaped with all-ones when the reshape size mismatches, masking
bugs; change this to surface the error instead: validate the shape after calling
quantize_fp4 (check block_scales.numel() or block_scales_reshaped.numel() vs
expected_scale_elems) and either raise a descriptive ValueError (including
tensor.shape, expected_scale_elems, and block_scales.numel()) or at minimum log
a clear warning via the existing logging mechanism before failing, rather than
silently returning ones; keep the fallback only for a deliberate debug mode if
needed.
- Around line 248-343: The quantize/dequantize pair disagree on block_scales
layout: quantize_fp8_block_scale currently transposes block_scales to
[num_blocks, num_tokens] (block_scales_transposed) but
dequantize_fp8_block_scale expects [num_tokens, num_blocks]; fix by returning
block_scales in [num_tokens, num_blocks] from quantize_fp8_block_scale (i.e.,
remove the transpose and return block_scales.contiguous() instead of
block_scales_transposed) so the shapes match the docstring and
dequantize_fp8_block_scale; keep variable names block_scales and update any
callers if they relied on the transposed layout.
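The layout contract described above (scales stored per token, shape `[num_tokens, num_blocks]`) is easy to check with a toy scalar model. `BLOCK` stands in for the 128-element FP8 scale blocks, and no fp8 rounding is modeled; names are illustrative:

```python
BLOCK = 4        # toy stand-in for the 128-element FP8 scale blocks
FP8_MAX = 448.0  # max magnitude of float8_e4m3fn

def quantize_block_scale(tokens):
    """Return (quantized, scales) with scales shaped [num_tokens][num_blocks]."""
    q_rows, scale_rows = [], []
    for row in tokens:
        q, scales = [], []
        for i in range(0, len(row), BLOCK):
            block = row[i:i + BLOCK]
            scale = (max(abs(v) for v in block) or 1.0) / FP8_MAX
            scales.append(scale)
            q.extend(v / scale for v in block)
        q_rows.append(q)
        scale_rows.append(scales)
    return q_rows, scale_rows

def dequantize_block_scale(q_rows, scale_rows):
    """Expects scales as [num_tokens][num_blocks] - the same layout as above."""
    return [[v * scales[i // BLOCK] for i, v in enumerate(q)]
            for q, scales in zip(q_rows, scale_rows)]

tokens = [[0.5, -1.0, 2.0, 0.25, 8.0, -4.0, 1.0, 0.0]]
q, s = quantize_block_scale(tokens)
assert len(s[0]) == len(tokens[0]) // BLOCK          # [num_tokens, num_blocks]
out = dequantize_block_scale(q, s)
assert all(abs(a - b) < 1e-9 for a, b in zip(out[0], tokens[0]))
```

If the quantizer instead returned scales transposed to `[num_blocks, num_tokens]`, the indexing in `dequantize_block_scale` would silently read the wrong scale for every token past the first, which is exactly the mismatch the comment asks to fix.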
- Around line 655-689: process_fp8_weight_layout currently returns early when
use_shuffled_weight is False, skipping BlockMajorK conversion; change the logic
so you always obtain a uint8 view first (e.g., uint8_tensor =
tensor.view(torch.uint8)), only call shuffle_matrix_a when use_shuffled_weight
is True, then if weight_layout == WeightLayout.BlockMajorK call
convert_to_block_layout(uint8_tensor, block_k) (keep epilogue_tile_m for shuffle
and block_k=128 as before), and finally return the
result.view(torch.float8_e4m3fn); this keeps shuffle_matrix_a and
convert_to_block_layout usage (and their parameters) intact while ensuring
BlockMajorK conversion runs even without shuffling.
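As a sanity check on what a BlockMajorK conversion does here (my reading of `convert_to_block_layout`; treat the exact semantics as an assumption), the K dimension of an `[N, K]` matrix is split into `block_k`-wide chunks and the block index becomes outermost:

```python
def to_block_major_k(matrix, block_k):
    """[N][K] -> [K // block_k][N][block_k]: block index becomes outermost."""
    n, k = len(matrix), len(matrix[0])
    assert k % block_k == 0, "K must be divisible by block_k"
    return [
        [[matrix[row][b * block_k + j] for j in range(block_k)]
         for row in range(n)]
        for b in range(k // block_k)
    ]

m = [[1, 2, 3, 4],
     [5, 6, 7, 8]]  # N=2, K=4
print(to_block_major_k(m, block_k=2))
# [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

This reordering is independent of the shuffle step, which is why the fix above must run it unconditionally whenever `weight_layout == WeightLayout.BlockMajorK`, not only when `use_shuffled_weight` is set.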
In `@benchmarks/routines/moe.py`:
- Around line 504-509: The tuple unpacking of quantize_fp4_batched assigns
gemm1_scales_global and gemm2_scales_global but those globals are unused and
trigger Ruff RUF059; change the unpacking to either drop or prefix these
variables with an underscore (e.g., _gemm1_scales_global, _gemm2_scales_global)
in the calls to quantize_fp4_batched so the linter no longer flags them (look
for the lines using quantize_fp4_batched that assign gemm1_weights_fp4_bytes,
gemm1_scales_fp4_bytes, gemm1_scales_global and the corresponding gemm2_*
assignment).
In `@csrc/trtllm_fused_moe_kernel_launcher.cu`:
- Around line 1655-1662: The guard for precomputed routing is too lax (uses
expert_indices.data_ptr() != nullptr) and can allow empty tensors that later
crash in prepare_routing(); change the predicate so use_precomputed_routing
mirrors the launcher's check by requiring expert_indices to be 2-D and non-empty
— e.g. set use_precomputed_routing = (expert_indices.ndimension() == 2 &&
expert_indices.size(0) > 0) — and keep the TVM_FFI_ICHECK(...) that requires
either routing_logits or this stricter precomputed-routing condition.
🧹 Nitpick comments (2)
benchmarks/README.md (1)
32-33: Well-documented MOE Communication section.
The new `moe_a2a_dispatch_combine` routine is clearly described with key requirements (mpirun, quantization options, real math support) highlighted.

📝 Optional: Address markdown linting preference
The static analysis tool flags 4-space indentation for nested list items. While the current 4-space indentation is consistent throughout the file, you could optionally adjust to 2 spaces if you want to satisfy the linter:
```diff
 - MOE Communication:
-    - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
+  - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
```

Note: This would require updating all nested list items in the file for consistency.
benchmarks/routines/moe_comm.py (1)
467-468: Ruff TRY003: long message in inline `ValueError`.
If TRY003 is enforced, consider a `# noqa: TRY003` or a small custom exception type.

🧹 Example suppression
```diff
-    raise ValueError(f"Unsupported quant_dtype for real computation: {quant_dtype}")
+    raise ValueError(
+        f"Unsupported quant_dtype for real computation: {quant_dtype}"
+    )  # noqa: TRY003
```
```diff
 gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
-    quant_fp4_batches_simple(gemm1_weights, num_experts, use_ue8m0, True)
+    quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
 )
 gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
-    quant_fp4_batches_simple(gemm2_weights, num_experts, use_ue8m0, True)
+    quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
 )
```
Unused quant scale globals trigger Ruff RUF059.
Prefix with _ (or drop) to avoid lint failures.
♻️ Proposed cleanup
```diff
-    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
+    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, _gemm1_scales_global = (
         quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
     )
@@
-    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
+    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, _gemm2_scales_global = (
         quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
     )
```
🧰 Tools
🪛 Ruff (0.14.13)
504-504: Unpacked variable gemm1_scales_global is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
507-507: Unpacked variable gemm2_scales_global is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In `@benchmarks/routines/moe.py` around lines 504 - 509, The tuple unpacking of
quantize_fp4_batched assigns gemm1_scales_global and gemm2_scales_global but
those globals are unused and trigger Ruff RUF059; change the unpacking to either
drop or prefix these variables with an underscore (e.g., _gemm1_scales_global,
_gemm2_scales_global) in the calls to quantize_fp4_batched so the linter no
longer flags them (look for the lines using quantize_fp4_batched that assign
gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global and the
corresponding gemm2_* assignment).
Force-pushed 491d262 to d15d55f
Squashed commits:
- …py benchmark fp8_block_scale; real moe math wip
- add fp8 block scale routed moe; should work
- resolve oom
- wip: add fused topk id/weight padding
- remove extra d2d copy
- extract utility funcs
- add moe tflops/bw calculation
- add guard
- add example testlist
- tidy
- wip: minor refactor
- extract common utils
- rename
- tidy; all should pass
Force-pushed 86adf34 to 5d66242
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
benchmarks/routines/moe.py (1)
1299-1313: Use fp8_block_scale for bandwidth accounting in block-scale mode.
Lines 1308-1309 use `"fp8"` formats, which omit block-scale overhead. Use `"fp8_block_scale"` to reflect scales.

🛠️ Suggested fix

```diff
-        input_format="fp8",
-        weight_format="fp8",
+        input_format="fp8_block_scale",
+        weight_format="fp8_block_scale",
```

flashinfer/fused_moe/core.py (2)
1547-1635: Avoid passing `None` into AutoTuner inputs when `routing_logits` is absent.
When `routing_logits` is `None`, `inputs` currently includes a `None` entry, which can break AutoTuner shape handling. Provide a meta tensor placeholder (as done in FP4) to keep tuner inputs tensor-only.

🛠️ Suggested fix

```diff
-        inputs = [
+        routing_logits_for_tuning = (
+            routing_logits
+            if routing_logits is not None
+            else torch.empty(
+                num_tokens, num_experts, dtype=routing_dtype, device="meta"
+            )
+        )
+        inputs = [
             output,
-            routing_logits,
+            routing_logits_for_tuning,
             topk_ids,
             expert_weights,
             hidden_states,
             hidden_states_scale,
         ]
```
1690-1714: Silence Ruff ARG001 for unused fake-op parameters.
The fake op signature now includes unused args; Ruff will flag these. Prefix unused parameters with `_` or add `# noqa: ARG001` on the definition.
🤖 Fix all issues with AI agents
In `@benchmarks/README.md`:
- Around line 32-33: The nested list under the "MOE Communication:" bullet has
incorrect indentation; adjust the nested bullet for `moe_a2a_dispatch_combine`
to use a 2-space indent (align the nested dash two spaces under the parent list
marker) so the nested item conforms to MD007; ensure the code span and
description remain on the same line after the nested dash.
In `@benchmarks/routines/moe.py`:
- Around line 389-393: The code calls compute_routing(routing_logits.float(),
top_k) and ignores routing_bias; update the non‑DeepSeek path to add
routing_bias into routing_logits before selection so selected_experts and
bandwidth metrics reflect the bias. Specifically, where compute_routing is used
(function/variable names: compute_routing, routing_logits, routing_bias,
selected_experts, top_k), add routing_bias (with proper dtype/device
broadcasting and a None check) to routing_logits (e.g., routing_logits +
routing_bias) and pass the biased logits to compute_routing, then return
selected_experts.
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py`:
- Around line 279-282: The tests generate routing_logits as torch.bfloat16 which
is incompatible with the FP8 block-scale routing kernel; change the tensor
creation for routing_logits to use torch.float32 so the FP8 block-scale routing
receives float32 inputs. Locate the variable routing_logits in the test (the
torch.rand(...) call that currently casts to torch.bfloat16) and remove or
replace the .to(torch.bfloat16) with .to(torch.float32) (or simply create with
dtype=torch.float32) so downstream code like the FP8 block-scale routing path
receives float32 logits.
♻️ Duplicate comments (2)
benchmarks/routines/moe.py (1)
504-509: Prefix unused FP4 global scales to avoid Ruff RUF059.
The globals are unpacked but unused; Ruff will flag this.
♻️ Suggested cleanup

```diff
-    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, gemm1_scales_global = (
+    gemm1_weights_fp4_bytes, gemm1_scales_fp4_bytes, _gemm1_scales_global = (
         quantize_fp4_batched(gemm1_weights, num_experts, use_ue8m0, True)
     )
@@
-    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, gemm2_scales_global = (
+    gemm2_weights_fp4_bytes, gemm2_scales_fp4_bytes, _gemm2_scales_global = (
         quantize_fp4_batched(gemm2_weights, num_experts, use_ue8m0, True)
     )
```

benchmarks/routines/moe_comm.py (1)
253-257: Comment/behavior mismatch for `intermediate_size`.
The comment at line 254 says "Default intermediate_size to hidden_size * 4 if not specified" (based on the past review), but the code asserts that `intermediate_size` must be provided when `real_math=True`. Either implement the default or update the comment to reflect the actual behavior.
🧹 Nitpick comments (1)
benchmarks/routines/moe_comm.py (1)
1317-1318: Potentially redundant FP8 conversion.
When `quant_dtype == "fp8_block_scale"`, the hidden states are already quantized to `float8_e4m3fn` in `_create_moe_inputs` (line 539). The conversion at line 1318 appears redundant. While it's a no-op for already-FP8 data, consider removing it or adding a comment clarifying intent.

Suggested change

```diff
-        # Convert hidden states to FP8
-        hidden_fp8 = hidden_flat.to(torch.float8_e4m3fn)
+        # hidden_flat is already FP8 from quantize_fp8_block_scale
+        hidden_fp8 = hidden_flat
```
```markdown
- MOE Communication:
    - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
```
Fix nested list indentation to satisfy MD007.
Line 33 is flagged by markdownlint for list indent. Align the nested bullet with the configured 2‑space indent.
🛠️ Suggested fix
```diff
-    - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
+  - `moe_a2a_dispatch_combine` - MoE All-to-All dispatch + combine benchmark for multi-GPU expert-parallel inference. Requires `mpirun` for multi-GPU execution. Supports optional quantization (FP8, NVFP4, FP8 block-scale) and real MoE kernel computation.
```
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
33-33: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
🤖 Prompt for AI Agents
In `@benchmarks/README.md` around lines 32 - 33, The nested list under the "MOE
Communication:" bullet has incorrect indentation; adjust the nested bullet for
`moe_a2a_dispatch_combine` to use a 2-space indent (align the nested dash two
spaces under the parent list marker) so the nested item conforms to MD007;
ensure the code span and description remain on the same line after the nested
dash.
```diff
     # For other routing methods, use simple top-k as approximation
     # This is accurate for Default, Renormalize, RenormalizeNaive, TopK
     # and approximate for Llama4
-    _, selected_experts = _compute_routing(routing_logits.float(), top_k)
+    _, selected_experts = compute_routing(routing_logits.float(), top_k)
     return selected_experts
```
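For readers unfamiliar with the helper, a toy stand-in for `compute_routing` (assumed semantics: softmax over the logits, then keep the `top_k` experts and their weights) behaves like:

```python
import math

def compute_routing_ref(logits_row, top_k):
    """Softmax the logits, then return (weights, expert ids) of the top_k."""
    m = max(logits_row)
    exps = [math.exp(x - m) for x in logits_row]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    experts = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
    return [probs[e] for e in experts], experts

weights, experts = compute_routing_ref([0.1, 2.0, -1.0, 1.0], top_k=2)
print(experts)  # [1, 3]
```

Because expert selection is just an argmax over (possibly biased) scores, any additive bias applied before this step can change which experts land in the top_k, which is the substance of the `routing_bias` comment that follows.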
Apply routing_bias for accurate expert selection in non‑DeepSeek paths.
Line 392 uses compute_routing on raw logits, which ignores routing_bias when it’s provided. That skews selected_experts and bandwidth metrics. Consider adding the bias before routing.
🛠️ Suggested fix
```diff
-    _, selected_experts = compute_routing(routing_logits.float(), top_k)
+    logits = routing_logits.float()
+    if routing_bias is not None:
+        logits = logits + routing_bias.float()
+    _, selected_experts = compute_routing(logits, top_k)
```
🤖 Prompt for AI Agents
In `@benchmarks/routines/moe.py` around lines 389 - 393, The code calls
compute_routing(routing_logits.float(), top_k) and ignores routing_bias; update
the non‑DeepSeek path to add routing_bias into routing_logits before selection
so selected_experts and bandwidth metrics reflect the bias. Specifically, where
compute_routing is used (function/variable names: compute_routing,
routing_logits, routing_bias, selected_experts, top_k), add routing_bias (with
proper dtype/device broadcasting and a None check) to routing_logits (e.g.,
routing_logits + routing_bias) and pass the biased logits to compute_routing,
then return selected_experts.
```python
# Generate random routing logits for reference
routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
    torch.bfloat16
)
```
Use float32 routing_logits for FP8 block‑scale routing.
Line 279–282 uses bfloat16 logits, but FP8 block‑scale routing expects float32 (per kernel constraints noted elsewhere in this PR). Using bfloat16 risks incorrect routing or runtime errors.
🛠️ Suggested fix
```diff
-    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
-        torch.bfloat16
-    )
+    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
+        torch.float32
+    )
```

🤖 Prompt for AI Agents
In `@tests/moe/test_trtllm_gen_routed_fused_moe.py` around lines 279 - 282, The
tests generate routing_logits as torch.bfloat16 which is incompatible with the
FP8 block-scale routing kernel; change the tensor creation for routing_logits to
use torch.float32 so the FP8 block-scale routing receives float32 inputs. Locate
the variable routing_logits in the test (the torch.rand(...) call that currently
casts to torch.bfloat16) and remove or replace the .to(torch.bfloat16) with
.to(torch.float32) (or simply create with dtype=torch.float32) so downstream
code like the FP8 block-scale routing path receives float32 logits.
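The dtype requirement matters beyond a hard kernel check: bfloat16 keeps only 8 mantissa bits, so logits that are distinct in float32 can collapse to the same value and perturb top-k selection. A small illustration, modeling bf16 as float32 bit truncation (the conversion detail is an assumption for the sketch):

```python
import struct

def to_bf16(x):
    """Model bfloat16 by zeroing the low 16 bits of the float32 encoding."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]

a, b = 0.7501, 0.7504            # distinct float32 logits
print(to_bf16(a) == to_bf16(b))  # True - indistinguishable in bf16
```

Near 0.75 the bf16 spacing is about 0.002, so both values round to exactly 0.75; two experts whose logits differ by less than that become a tie, which is why the test should feed the routing kernel float32 logits.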
Selected unit tests (see below) passed locally after the change to CUDA/C source files.
/bot run |
[FAILED] Pipeline #42365843: 9/20 passed |
@aleozlx I see most of the CI errors are unrelated: And: |
agree. this is good to go |
there was a pipeline blockage, i did a |
everything passed now |
could you help merge this pls |
📌 Description
- `--real_math` flag for `moe_comm.py` bench script to run A2A + MoE: `a2a dispatch -> trtllm moe (nvfp4 or fp8_block_scale) -> a2a combine`
- `trtllm_gen_fp8_routed_fused_moe` specifically for A2A + MoE benchmark

Example:
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- …unittest, etc.).
- Selected unit tests (see below) passed locally after the change to CUDA/C source files.
Reviewer Notes
Summary by CodeRabbit
New Features
Documentation
Tests