-
Notifications
You must be signed in to change notification settings - Fork 760
feat: add trtllm_fp8_block_scale_routed_moe API #2382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
416fcbf
b2fcd9b
5712ad0
5b1bb34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| """ | ||
| Copyright (c) 2025 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from flashinfer import ( | ||
| RoutingMethodType, | ||
| ) | ||
| from flashinfer.fused_moe import ( | ||
| trtllm_fp8_block_scale_moe, | ||
| trtllm_fp8_block_scale_routed_moe, | ||
| ) | ||
| from flashinfer.utils import device_support_pdl | ||
|
|
||
| from .test_trtllm_gen_fused_moe import ( | ||
| routing_reference_renormalize, | ||
| routing_reference_renormalize_naive, | ||
| routing_reference_topk, | ||
| ) | ||
|
|
||
| from flashinfer.utils import get_compute_capability | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("num_tokens", [1, 8, 1024]) | ||
| @pytest.mark.parametrize("hidden_size", [1024, 2048, 3072, 4096]) | ||
| @pytest.mark.parametrize("intermediate_size", [1024, 2048, 3072, 4096]) | ||
| @pytest.mark.parametrize("num_experts", [128, 256]) | ||
| @pytest.mark.parametrize("top_k", [4, 8]) | ||
| @pytest.mark.parametrize( | ||
| "routing_method_type", | ||
| [ | ||
| RoutingMethodType.Renormalize, | ||
| RoutingMethodType.RenormalizeNaive, | ||
| RoutingMethodType.TopK, | ||
| ], | ||
| ) | ||
| def test_trtllm_fp8_routed_fused_moe( | ||
| num_tokens: int, | ||
| hidden_size: int, | ||
| intermediate_size: int, | ||
| top_k: int, | ||
| num_experts: int, | ||
| routing_method_type: RoutingMethodType, | ||
| ): | ||
| compute_capability = get_compute_capability(torch.device(device="cuda")) | ||
| if compute_capability[0] not in [10]: | ||
| pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.") | ||
| torch.manual_seed(42) | ||
| device = torch.device("cuda:0") | ||
| enable_pdl = device_support_pdl(device) | ||
| routing_logits = torch.rand(num_tokens, num_experts, device=device).to( | ||
| torch.bfloat16 | ||
| ) | ||
|
|
||
| # Create FP8 hidden states and scales | ||
| hidden_states = torch.randn(num_tokens, hidden_size, device=device).to( | ||
| torch.float8_e4m3fn | ||
| ) | ||
| # Block scale: [hidden_size//128, num_tokens] | ||
| hidden_states_scale = torch.rand( | ||
| hidden_size // 128, num_tokens, device=device, dtype=torch.float32 | ||
| ) | ||
|
|
||
| # Create FP8 weights and scales | ||
| gemm1_weights = torch.randn( | ||
| num_experts, intermediate_size * 2, hidden_size, device=device | ||
| ).to(torch.float8_e4m3fn) | ||
| gemm1_weights_scale = torch.rand( | ||
| num_experts, | ||
| intermediate_size * 2 // 128, | ||
| hidden_size // 128, | ||
| device=device, | ||
| dtype=torch.float32, | ||
| ) | ||
|
|
||
| gemm2_weights = torch.randn( | ||
| num_experts, hidden_size, intermediate_size, device=device | ||
| ).to(torch.float8_e4m3fn) | ||
| gemm2_weights_scale = torch.rand( | ||
| num_experts, | ||
| hidden_size // 128, | ||
| intermediate_size // 128, | ||
| device=device, | ||
| dtype=torch.float32, | ||
| ) | ||
|
|
||
| # Run the non-routed version as reference | ||
| reference_output = trtllm_fp8_block_scale_moe( | ||
| routing_logits, | ||
| None, # routing_bias | ||
| hidden_states, | ||
| hidden_states_scale, | ||
| gemm1_weights, | ||
| gemm1_weights_scale, | ||
| gemm2_weights, | ||
| gemm2_weights_scale, | ||
| num_experts, | ||
| top_k, | ||
| None, # n_group | ||
| None, # topk_group | ||
| intermediate_size, | ||
| 0, # local_expert_offset | ||
| num_experts, | ||
| None, # routed_scaling_factor | ||
| routing_method_type.value, | ||
| False, # use_shuffled_weight | ||
| 0, # weight_layout | ||
| enable_pdl, | ||
| ).to(torch.float) | ||
|
|
||
| # Compute routing for routed version | ||
| if routing_method_type == RoutingMethodType.Renormalize: | ||
| permute_info, expert_weights = routing_reference_renormalize( | ||
| routing_logits, top_k, num_experts, 8 | ||
| ) | ||
| elif routing_method_type == RoutingMethodType.RenormalizeNaive: | ||
| permute_info, expert_weights = routing_reference_renormalize_naive( | ||
| routing_logits, top_k, num_experts, 8 | ||
| ) | ||
| elif routing_method_type == RoutingMethodType.TopK: | ||
| permute_info, expert_weights = routing_reference_topk( | ||
| routing_logits, top_k, num_experts, 8 | ||
| ) | ||
| topk_ids = permute_info["topKIndices"].to(torch.int32) | ||
| expert_weights = expert_weights.view(num_tokens, num_experts)[ | ||
| torch.arange(num_tokens).unsqueeze(1), topk_ids | ||
| ].to(torch.bfloat16) | ||
|
|
||
| # Pack topk_ids and expert_weights into a single tensor | ||
| packed_tensor = (topk_ids.to(torch.int32) << 16) | expert_weights.to( | ||
| torch.bfloat16 | ||
| ).view(torch.int16) | ||
|
|
||
|
Comment on lines
+143
to
+147
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: # First, find the flashinfer/fused_moe/core.py file and search for routed API docstring
fd "core.py" --path "*fused_moe*" | head -20Repository: flashinfer-ai/flashinfer Length of output: 283 🏁 Script executed: # Also locate the test file to understand the context better
fd "test_trtllm_fp8_routed_fused_moe.py" | head -10Repository: flashinfer-ai/flashinfer Length of output: 113 🏁 Script executed: # Find core.py files
fd "core\.py" | grep -i moeRepository: flashinfer-ai/flashinfer Length of output: 96 🏁 Script executed: # View the test file around lines 143-147 to get context
head -160 tests/moe/test_trtllm_fp8_routed_fused_moe.py | tail -30Repository: flashinfer-ai/flashinfer Length of output: 1119 🏁 Script executed: # Check the flashinfer/fused_moe/core.py file for routed API docstring
cat flashinfer/fused_moe/core.py | head -150Repository: flashinfer-ai/flashinfer Length of output: 4391 🏁 Script executed: # Search for routed_moe function and its documentation
rg "routed_moe|trtllm_fp8_block_scale_routed_moe" flashinfer/fused_moe/core.py -A 20 | head -80Repository: flashinfer-ai/flashinfer Length of output: 1557 🏁 Script executed: # Search for the docstring of trtllm_fp8_block_scale_routed_moe function
rg -A 50 "def trtllm_fp8_block_scale_routed_moe" flashinfer/fused_moe/core.py | head -100Repository: flashinfer-ai/flashinfer Length of output: 2741 🏁 Script executed: # Also search for any documentation about bit layout or packing format
rg "upper.*bits|lower.*bits|score.*index|packed.*tensor" flashinfer/fused_moe/core.py -B 2 -A 2Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Search for kernel implementation or C++ code related to topk_ids packing
fd -e h -e cu | xargs rg "topk_ids|packed.*index|score.*bits" -l | head -20Repository: flashinfer-ai/flashinfer Length of output: 148 🏁 Script executed: # Search for any other test files that use trtllm_fp8_block_scale_routed_moe to see how they pack
rg "trtllm_fp8_block_scale_routed_moe" tests/ -lRepository: flashinfer-ai/flashinfer Length of output: 113 🏁 Script executed: # Check the kernel launcher to see what it expects
cat csrc/trtllm_fused_moe_kernel_launcher.cu | head -200Repository: flashinfer-ai/flashinfer Length of output: 8174 🏁 Script executed: # Search for bit shift or packing operations in the kernel
rg ">> 16|& 0xFFFF|topk" csrc/trtllm_fused_moe_kernel_launcher.cu -B 2 -A 2Repository: flashinfer-ai/flashinfer Length of output: 6658 🏁 Script executed: # Search for the actual unpacking or usage of topk_ids in the kernel headers
fd -e h | xargs rg "topk_ids|extract.*index|extract.*score" -B 3 -A 3 | head -100Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Look at the RoutingKernel.h to see how topk is structured
cat flashinfer/include/flashinfer/trtllm/fused_moe/RoutingKernel.h 2>/dev/null | head -100Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Find runner.h and look for topk_ids or packing documentation
find . -name "runner.h" | xargs cat 2>/dev/null | head -200Repository: flashinfer-ai/flashinfer Length of output: 8394 🏁 Script executed: # Search more broadly for bit manipulation patterns in fused_moe headers
find . -path "*/trtllm/fused_moe/*.h" -type f | head -10Repository: flashinfer-ai/flashinfer Length of output: 324 🏁 Script executed: # Check DevKernel.h for topk_ids packing/unpacking
cat ./include/flashinfer/trtllm/fused_moe/DevKernel.h | head -300Repository: flashinfer-ai/flashinfer Length of output: 20530 🏁 Script executed: # Search for explicit bit operations on expert indices or weights
rg ">> 16|<< 16|0xFFFF" ./include/flashinfer/trtllm/fused_moe/ -B 3 -A 3Repository: flashinfer-ai/flashinfer Length of output: 2309 🏁 Script executed: # Get more context around the packing/unpacking in the kernel
cat ./include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh | grep -A 15 "unpack"Repository: flashinfer-ai/flashinfer Length of output: 2310 🏁 Script executed: # Also check if there are any other references to the packing format
rg "moveBits|compactTmp" ./include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh -B 5 -A 5 | head -80Repository: flashinfer-ai/flashinfer Length of output: 1537 Bit layout is reversed: test packs indices (upper) and weights (lower), but kernel expects weights (upper) and indices (lower). The kernel implementation in However, the test at lines 143-147 reverses this: 🤖 Prompt for AI Agents |
||
| # Run the routed version | ||
| output = trtllm_fp8_block_scale_routed_moe( | ||
| packed_tensor, | ||
| None, # routing_bias | ||
| hidden_states, | ||
| hidden_states_scale, | ||
| gemm1_weights, | ||
| gemm1_weights_scale, | ||
| gemm2_weights, | ||
| gemm2_weights_scale, | ||
| num_experts, | ||
| top_k, | ||
| None, # n_group | ||
| None, # topk_group | ||
| intermediate_size, | ||
| 0, # local_expert_offset | ||
| num_experts, | ||
| None, # routed_scaling_factor | ||
| routing_method_type.value, | ||
| False, # use_shuffled_weight | ||
| 0, # weight_layout | ||
| enable_pdl, | ||
| ).to(torch.float) | ||
|
|
||
| # Compare outputs | ||
| mask = torch.isclose(output, reference_output, rtol=1e-3, atol=1e-3) | ||
|
|
||
| # mismatch percentage | ||
| mismatch_pct = (~mask).float().mean().item() * 100 | ||
| assert mismatch_pct < 6, f"Mismatch percentage is {mismatch_pct:.2f}" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameter grid is likely to OOM on common SM100 cards.
The combination `num_experts=256, hidden_size=4096, intermediate_size=4096` allocates ~12+ GB of FP8 weights alone (lines 79–99), which will exceed memory on many GPUs. Please reduce the grid or add a pre-allocation skip based on estimated bytes. 🧮 Example guard to avoid OOM
@@ def test_trtllm_fp8_routed_fused_moe(...): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) + bytes_per_elem = torch.tensor([], dtype=torch.float8_e4m3fn, device=device).element_size() + est_bytes = num_experts * ( + 2 * intermediate_size * hidden_size + hidden_size * intermediate_size + ) * bytes_per_elem + if est_bytes > 0.5 * torch.cuda.get_device_properties(device).total_memory: + pytest.skip("Skipping large configuration to avoid OOM.")🤖 Prompt for AI Agents