feat: add trtllm_fp8_block_scale_routed_moe API#2382
Conversation
Add FP8 routed MOE API that skips routing computation and accepts pre-computed top-k expert indices and weights, matching the pattern established by trtllm_fp4_block_scale_routed_moe. Changes: - Modified Fp8BlockScaleLauncher to accept Optional routing_logits - Added expert_indices and expert_weights to FP8 launcher - Implemented trtllm_fp8_block_scale_routed_moe Python API - Added comprehensive unit tests - Updated documentation index Fixes #2381 Co-authored-by: Zihao Ye <yzh119@users.noreply.github.com>
📝 WalkthroughWalkthroughAdds a routed FP8 block-scale MoE API and threads pre-computed top-k indices and expert weights through the FP8 launcher/kernel; Python wrappers, docs, and tests for the routed variant are added, and routing inputs are treated as optional with guarded validation. Changes
Sequence Diagram(s)mermaid (Note: colored rectangles not required for this simple flow.) Estimated code review effort🎯 4 (Complex) | ⏱️ ~50 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 3 | ❌ 2❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
csrc/trtllm_fused_moe_kernel_launcher.cu (1)
717-817: Routed FP8 path still dereferencesrouting_logitsand ignores provided routing buffers.
routing_logitsis now Optional, butFp8BlockScaleLauncher::prepare_routing()still callsrouting_logits.value()(Line 808) and the baserun()always executes the routing kernel. When the routed API passesNone, this will throw and thetopk_ids/expert_weightsinputs are never used. Please branch onrouting_logits.has_value()to skip routing and wireworkspace.routing_expert_indexes/workspace.expert_weightsto the provided tensors (similar to the FP4 routed path).🐛 Suggested fix sketch
@@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { void prepare_routing() override { FusedMoeLauncher::prepare_routing_common(); @@ - args->mUseDeepSeekFp8 = true; - args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr()); - expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); - workspace.expert_weights = expert_weights.data_ptr(); + args->mUseDeepSeekFp8 = true; + if (routing_logits.has_value()) { + args->routing_logits = static_cast<float*>(routing_logits.value().data_ptr()); + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + workspace.expert_weights = expert_weights.data_ptr(); + } else { + args->routing_logits = nullptr; + workspace.routing_expert_indexes = static_cast<int*>(expert_indices.data_ptr()); + workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr()); + } } @@ - routing_runner.run(...); + if (routing_logits.has_value()) { + routing_runner.run(...); + }Also applies to: 1574-1656
🤖 Fix all issues with AI agents
In `@tests/moe/test_trtllm_fp8_routed_fused_moe.py`:
- Around line 38-99: The test test_trtllm_fp8_routed_fused_moe currently
parametrizes large combinations (num_experts, hidden_size, intermediate_size)
that can OOM; add a pre-allocation guard that estimates required bytes for
gemm1_weights, gemm2_weights and their scales (use num_experts,
intermediate_size, hidden_size, element sizes: 1 byte for FP8 tensors and 4
bytes for float32 scales) and compare against the GPU total memory
(torch.cuda.get_device_properties(device).total_memory); if estimated_bytes
exceeds a safe fraction (e.g. 60–80%) of total_memory then call pytest.skip with
a clear message, otherwise proceed to allocate gemm1_weights and gemm2_weights
as before. Ensure the check is placed at the start of
test_trtllm_fp8_routed_fused_moe before creating gemm1_weights/gemm2_weights and
references the existing variables num_experts, hidden_size, intermediate_size,
gemm1_weights_scale, gemm2_weights_scale.
- Around line 143-147: The test currently packs topk_ids into the upper 16 bits
and expert_weights into the lower 16 bits (packed_tensor), but the kernel and
docstring expect the opposite (weights in upper 16 bits, indices in lower 16
bits, with index encoded as 65535 - idx). Change the packing in
tests/moe/test_trtllm_fp8_routed_fused_moe.py so packed_tensor places
expert_weights (converted to bfloat16 and interpreted as 16-bit int) into the
most significant 16 bits (shift left 16) and places the transformed index (65535
- topk_ids, as a 16-bit value) in the least significant 16 bits; reference the
variables packed_tensor, topk_ids, expert_weights and align with
RoutingKernelTopK.cuh's (value << 16) | (65535 - idx) layout.
🧹 Nitpick comments (1)
flashinfer/fused_moe/core.py (1)
2330-2422: Confirm whether this new API should be cached.Per the flashinfer/**/*.py guideline, Python API functions should use
@functools.cacheto avoid recompilation, buttrtllm_fp8_block_scale_routed_moe(Line 2330) isn’t cached. If caching isn’t appropriate for tensor inputs, please confirm and document the exception; otherwise add the decorator or a cached wrapper. As per coding guidelines.
| @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) | ||
| @pytest.mark.parametrize("hidden_size", [1024, 2048, 3072, 4096]) | ||
| @pytest.mark.parametrize("intermediate_size", [1024, 2048, 3072, 4096]) | ||
| @pytest.mark.parametrize("num_experts", [128, 256]) | ||
| @pytest.mark.parametrize("top_k", [4, 8]) | ||
| @pytest.mark.parametrize( | ||
| "routing_method_type", | ||
| [ | ||
| RoutingMethodType.Renormalize, | ||
| RoutingMethodType.RenormalizeNaive, | ||
| RoutingMethodType.TopK, | ||
| ], | ||
| ) | ||
| def test_trtllm_fp8_routed_fused_moe( | ||
| num_tokens: int, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| top_k: int, | ||
| num_experts: int, | ||
| routing_method_type: RoutingMethodType, | ||
| ): | ||
| compute_capability = get_compute_capability(torch.device(device="cuda")) | ||
| if compute_capability[0] not in [10]: | ||
| pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") | ||
| torch.manual_seed(42) | ||
| device = torch.device("cuda:0") | ||
| enable_pdl = device_support_pdl(device) | ||
| routing_logits = torch.rand(num_tokens, num_experts, device=device).to( | ||
| torch.bfloat16 | ||
| ) | ||
|
|
||
| # Create FP8 hidden states and scales | ||
| hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( | ||
| torch.float8_e4m3fn | ||
| ) | ||
| # Block scale: [hidden_size//128, num_tokens] | ||
| hidden_states_scale = torch.rand( | ||
| hidden_size // 128, num_tokens, device=device, dtype=torch.float32 | ||
| ) | ||
|
|
||
| # Create FP8 weights and scales | ||
| gemm1_weights = torch.randn( | ||
| num_experts, intermediate_size * 2, hidden_size, device=device | ||
| ).to(torch.float8_e4m3fn) | ||
| gemm1_weights_scale = torch.rand( | ||
| num_experts, | ||
| intermediate_size * 2 // 128, | ||
| hidden_size // 128, | ||
| device=device, | ||
| dtype=torch.float32, | ||
| ) | ||
|
|
||
| gemm2_weights = torch.randn( | ||
| num_experts, hidden_size, intermediate_size, device=device | ||
| ).to(torch.float8_e4m3fn) | ||
| gemm2_weights_scale = torch.rand( | ||
| num_experts, | ||
| hidden_size // 128, | ||
| intermediate_size // 128, | ||
| device=device, | ||
| dtype=torch.float32, | ||
| ) |
There was a problem hiding this comment.
Parameter grid is likely to OOM on common SM100 cards.
The combination num_experts=256, hidden_size=4096, intermediate_size=4096 allocates ~12+ GB of FP8 weights alone (Line 79–99), which will exceed memory on many GPUs. Please reduce the grid or add a pre‑allocation skip based on estimated bytes.
🧮 Example guard to avoid OOM
@@ def test_trtllm_fp8_routed_fused_moe(...):
device = torch.device("cuda:0")
enable_pdl = device_support_pdl(device)
+ bytes_per_elem = torch.tensor([], dtype=torch.float8_e4m3fn, device=device).element_size()
+ est_bytes = num_experts * (
+ 2 * intermediate_size * hidden_size + hidden_size * intermediate_size
+ ) * bytes_per_elem
+ if est_bytes > 0.5 * torch.cuda.get_device_properties(device).total_memory:
+ pytest.skip("Skipping large configuration to avoid OOM.")🤖 Prompt for AI Agents
In `@tests/moe/test_trtllm_fp8_routed_fused_moe.py` around lines 38 - 99, The test
test_trtllm_fp8_routed_fused_moe currently parametrizes large combinations
(num_experts, hidden_size, intermediate_size) that can OOM; add a pre-allocation
guard that estimates required bytes for gemm1_weights, gemm2_weights and their
scales (use num_experts, intermediate_size, hidden_size, element sizes: 1 byte
for FP8 tensors and 4 bytes for float32 scales) and compare against the GPU
total memory (torch.cuda.get_device_properties(device).total_memory); if
estimated_bytes exceeds a safe fraction (e.g. 60–80%) of total_memory then call
pytest.skip with a clear message, otherwise proceed to allocate gemm1_weights
and gemm2_weights as before. Ensure the check is placed at the start of
test_trtllm_fp8_routed_fused_moe before creating gemm1_weights/gemm2_weights and
references the existing variables num_experts, hidden_size, intermediate_size,
gemm1_weights_scale, gemm2_weights_scale.
| # Pack topk_ids and expert_weights into a single tensor | ||
| packed_tensor = (topk_ids.to(torch.int32) << 16) | expert_weights.to( | ||
| torch.bfloat16 | ||
| ).view(torch.int16) | ||
|
|
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, find the flashinfer/fused_moe/core.py file and search for routed API docstring
fd "core.py" --path "*fused_moe*" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 283
🏁 Script executed:
# Also locate the test file to understand the context better
fd "test_trtllm_fp8_routed_fused_moe.py" | head -10Repository: flashinfer-ai/flashinfer
Length of output: 113
🏁 Script executed:
# Find core.py files
fd "core\.py" | grep -i moeRepository: flashinfer-ai/flashinfer
Length of output: 96
🏁 Script executed:
# View the test file around lines 143-147 to get context
head -160 tests/moe/test_trtllm_fp8_routed_fused_moe.py | tail -30Repository: flashinfer-ai/flashinfer
Length of output: 1119
🏁 Script executed:
# Check the flashinfer/fused_moe/core.py file for routed API docstring
cat flashinfer/fused_moe/core.py | head -150Repository: flashinfer-ai/flashinfer
Length of output: 4391
🏁 Script executed:
# Search for routed_moe function and its documentation
rg "routed_moe|trtllm_fp8_block_scale_routed_moe" flashinfer/fused_moe/core.py -A 20 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 1557
🏁 Script executed:
# Search for the docstring of trtllm_fp8_block_scale_routed_moe function
rg -A 50 "def trtllm_fp8_block_scale_routed_moe" flashinfer/fused_moe/core.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 2741
🏁 Script executed:
# Also search for any documentation about bit layout or packing format
rg "upper.*bits|lower.*bits|score.*index|packed.*tensor" flashinfer/fused_moe/core.py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for kernel implementation or C++ code related to topk_ids packing
fd -e h -e cu | xargs rg "topk_ids|packed.*index|score.*bits" -l | head -20Repository: flashinfer-ai/flashinfer
Length of output: 148
🏁 Script executed:
# Search for any other test files that use trtllm_fp8_block_scale_routed_moe to see how they pack
rg "trtllm_fp8_block_scale_routed_moe" tests/ -lRepository: flashinfer-ai/flashinfer
Length of output: 113
🏁 Script executed:
# Check the kernel launcher to see what it expects
cat csrc/trtllm_fused_moe_kernel_launcher.cu | head -200Repository: flashinfer-ai/flashinfer
Length of output: 8174
🏁 Script executed:
# Search for bit shift or packing operations in the kernel
rg ">> 16|& 0xFFFF|topk" csrc/trtllm_fused_moe_kernel_launcher.cu -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 6658
🏁 Script executed:
# Search for the actual unpacking or usage of topk_ids in the kernel headers
fd -e h | xargs rg "topk_ids|extract.*index|extract.*score" -B 3 -A 3 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at the RoutingKernel.h to see how topk is structured
cat flashinfer/include/flashinfer/trtllm/fused_moe/RoutingKernel.h 2>/dev/null | head -100Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Find runner.h and look for topk_ids or packing documentation
find . -name "runner.h" | xargs cat 2>/dev/null | head -200Repository: flashinfer-ai/flashinfer
Length of output: 8394
🏁 Script executed:
# Search more broadly for bit manipulation patterns in fused_moe headers
find . -path "*/trtllm/fused_moe/*.h" -type f | head -10Repository: flashinfer-ai/flashinfer
Length of output: 324
🏁 Script executed:
# Check DevKernel.h for topk_ids packing/unpacking
cat ./include/flashinfer/trtllm/fused_moe/DevKernel.h | head -300Repository: flashinfer-ai/flashinfer
Length of output: 20530
🏁 Script executed:
# Search for explicit bit operations on expert indices or weights
rg ">> 16|<< 16|0xFFFF" ./include/flashinfer/trtllm/fused_moe/ -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 2309
🏁 Script executed:
# Get more context around the packing/unpacking in the kernel
cat ./include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh | grep -A 15 "unpack"Repository: flashinfer-ai/flashinfer
Length of output: 2310
🏁 Script executed:
# Also check if there are any other references to the packing format
rg "moveBits|compactTmp" ./include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh -B 5 -A 5 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 1537
Bit layout is reversed: test packs indices (upper) and weights (lower), but kernel expects weights (upper) and indices (lower).
The kernel implementation in RoutingKernelTopK.cuh packs as (value << 16) | (65535 - idx), placing weights in the upper 16 bits and indices in the lower 16 bits. The API docstring correctly states: "most significant 16 bits represent the score (bfloat16) and the least significant 16 bits represent the index."
However, the test at lines 143-147 reverses this: (topk_ids.to(torch.int32) << 16) | expert_weights... places indices in the upper bits and weights in the lower bits. This will cause incorrect expert routing. Swap the bit positions or add a helper function to ensure consistent packing across the codebase.
🤖 Prompt for AI Agents
In `@tests/moe/test_trtllm_fp8_routed_fused_moe.py` around lines 143 - 147, The
test currently packs topk_ids into the upper 16 bits and expert_weights into the
lower 16 bits (packed_tensor), but the kernel and docstring expect the opposite
(weights in upper 16 bits, indices in lower 16 bits, with index encoded as 65535
- idx). Change the packing in tests/moe/test_trtllm_fp8_routed_fused_moe.py so
packed_tensor places expert_weights (converted to bfloat16 and interpreted as
16-bit int) into the most significant 16 bits (shift left 16) and places the
transformed index (65535 - topk_ids, as a 16-bit value) in the least significant
16 bits; reference the variables packed_tensor, topk_ids, expert_weights and
align with RoutingKernelTopK.cuh's (value << 16) | (65535 - idx) layout.
Code Review SummaryThis PR adds the trtllm_fp8_block_scale_routed_moe API to match the pattern established by trtllm_fp4_block_scale_routed_moe. Overall, the implementation looks solid and follows FlashInfer conventions well. I've identified a few issues that should be addressed: Critical Issues1. Inconsistent packing format in documentation (flashinfer/fused_moe/core.py:2361-2364)The docstring describes a packed format where the most significant 16 bits represent the score (bfloat16) and the least significant 16 bits represent the index. However, the test code (tests/moe/test_trtllm_fp8_routed_fused_moe.py:144-146) implements the OPPOSITE packing: indices are shifted to the most significant 16 bits and weights go in the least significant bits. Recommendation: Update the docstring to accurately describe the packing format as implemented, or verify with the FP4 implementation to ensure consistency. 2. Unused expert_weights tensor allocation (flashinfer/fused_moe/core.py:2393-2396)The routed variant allocates an empty expert_weights tensor that is never populated. This tensor is passed to the C++ backend but appears to be a placeholder since weights are packed in topk_ids. Questions:
Looking at csrc/trtllm_fused_moe_kernel_launcher.cu:814-816, the launcher allocates its own expert_weights tensor and sets workspace.expert_weights, so this parameter might be overwritten anyway. Code Quality Issues3. Missing parameter documentationThe tune_max_num_tokens parameter is present in the signature but not described in the docstring Args section. It should be properly documented in Args for consistency with other APIs. 4. Inconsistent naming: topk_ids vs expert_indicesThe Python API uses topk_ids, but the C++ launcher stores it as expert_indices (csrc/trtllm_fused_moe_kernel_launcher.cu:913, 721). While this works, consistent naming would improve code readability. Positive Observations
SuggestionsPerformance considerationThe test uses a 6% mismatch threshold (tests/moe/test_trtllm_fp8_routed_fused_moe.py:177), which seems quite high. Is this expected for FP8 quantization? Consider adding a comment explaining why this tolerance is necessary. Code organizationThe non-routed trtllm_fp8_block_scale_moe now creates dummy tensors (lines 2294-2301) that weren't there before. This suggests the C++ API signature changed. Consider documenting this architectural decision. RecommendationsMust fix before merge:
Should fix: Nice to have: Overall, this is a well-structured PR that follows FlashInfer conventions. The main concerns are documentation accuracy and clarifying the role of placeholder tensors. Great work matching the established FP4 pattern! |
|
@claude regarding the packing format, we are using little-endian. |
|
Claude finished @yzh119's task in 4m 12s —— View job Clarifying the packing format and little-endianThank you for the clarification! I understand now. After analyzing the code, I see that the packing format in the test is consistent with how little-endian systems work: Test packing (lines 144-146): packed_tensor = (topk_ids.to(torch.int32) << 16) | expert_weights.to(torch.bfloat16).view(torch.int16)This creates a 32-bit int with indices in the upper 16 bits and weights in the lower 16 bits. On little-endian systems, when this int32 is stored in memory and later read by the kernel, the bit positions remain logically consistent regardless of the byte ordering in RAM. Understanding confirmed: The bit-level packing/unpacking operations ( CodeRabbit's concern about the packing format appears to be based on a misunderstanding of how the data flows through the routed path. The current implementation should work correctly for little-endian systems (which includes all modern x86/ARM GPUs). Is there anything specific about the packing format or the PR implementation that you'd like me to verify or improve? |
|
Nice job! |
|
re @xiaoqi35 unfortunately i believe this backend only supports blackwell |
Add FP8 routed MOE API that skips routing computation and accepts pre-computed top-k expert indices and weights, matching the pattern established by trtllm_fp4_block_scale_routed_moe.
Changes:
Fixes #2381
Generated with Claude Code
Summary by CodeRabbit
New Features
Improvements
Documentation
Public API
Tests
✏️ Tip: You can customize this high-level summary in your review settings.