Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -714,16 +714,20 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
public:
static constexpr std::array<int32_t, 5> mSupportedTileNums = {8, 16, 32, 64, 128};

Fp8BlockScaleLauncher(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
TensorView const& hidden_states, TensorView const& hidden_states_scale,
TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale,
TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale)
: FusedMoeLauncher(Optional<TensorView>(routing_logits), routing_bias, hidden_states,
gemm1_weights, Optional<TensorView>(), Optional<TensorView>(),
gemm2_weights, Optional<TensorView>()),
Fp8BlockScaleLauncher(Optional<TensorView> const& routing_logits,
Optional<TensorView> const& routing_bias, TensorView const& hidden_states,
TensorView const& hidden_states_scale, TensorView const& gemm1_weights,
TensorView const& gemm1_weights_scale, TensorView const& gemm2_weights,
TensorView const& gemm2_weights_scale, TensorView const& expert_indices,
TensorView const& expert_weights)
: FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights,
Optional<TensorView>(), Optional<TensorView>(), gemm2_weights,
Optional<TensorView>()),
hidden_states_scale(hidden_states_scale),
gemm1_weights_scale(gemm1_weights_scale),
gemm2_weights_scale(gemm2_weights_scale) {}
gemm2_weights_scale(gemm2_weights_scale),
expert_indices(expert_indices),
expert_weights(expert_weights) {}

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
Expand Down Expand Up @@ -906,6 +910,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TensorView hidden_states_scale;
TensorView gemm1_weights_scale;
TensorView gemm2_weights_scale;
TensorView expert_indices;
TensorView expert_weights;
Tensor gemm1_output_scale;
Tensor activation_output_scale;

Expand Down Expand Up @@ -1565,19 +1571,24 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
}

Tensor trtllm_fp8_block_scale_moe(
TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale,
TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output,
int64_t num_experts, int64_t top_k, Optional<int64_t> n_group, Optional<int64_t> topk_group,
int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts,
Optional<double> routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight,
int64_t weight_layout, bool enable_pdl, Array<int64_t> config_index) {
Optional<TensorView> routing_logits, TensorView topk_ids, TensorView expert_weights,
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
Array<int64_t> config_index) {
// Basic type validation
auto dtype = hidden_states.dtype();
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
} else {
TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
if (routing_logits.has_value()) {
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_float32)
<< "routing_logits must be float.";
} else {
TVM_FFI_ICHECK_EQ(routing_logits.value().dtype(), dl_bfloat16)
<< "routing_logits must be bfloat16.";
}
}
TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn)
<< "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8.";
Expand Down Expand Up @@ -1621,7 +1632,7 @@ Tensor trtllm_fp8_block_scale_moe(
// Create and initialize launcher for this tile size
auto launcher = std::make_unique<Fp8BlockScaleLauncher>(
routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
gemm1_weights_scale, gemm2_weights, gemm2_weights_scale);
gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, topk_ids, expert_weights);
launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
weight_layout);

Expand Down
2 changes: 2 additions & 0 deletions docs/api/fused_moe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ TensorRT-LLM Fused MoE
:toctree: ../generated

trtllm_fp4_block_scale_moe
trtllm_fp4_block_scale_routed_moe
trtllm_fp8_block_scale_moe
trtllm_fp8_block_scale_routed_moe
trtllm_fp8_per_tensor_scale_moe
2 changes: 2 additions & 0 deletions flashinfer/fused_moe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
trtllm_fp4_block_scale_moe,
trtllm_fp4_block_scale_routed_moe,
trtllm_fp8_block_scale_moe,
trtllm_fp8_block_scale_routed_moe,
trtllm_fp8_per_tensor_scale_moe,
trtllm_bf16_moe,
trtllm_mxint4_block_scale_moe,
Expand All @@ -54,6 +55,7 @@
"trtllm_fp4_block_scale_moe",
"trtllm_fp4_block_scale_routed_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_block_scale_routed_moe",
"trtllm_fp8_per_tensor_scale_moe",
"trtllm_mxint4_block_scale_moe",
"fused_topk_deepseek",
Expand Down
105 changes: 105 additions & 0 deletions flashinfer/fused_moe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2291,8 +2291,113 @@ def trtllm_fp8_block_scale_moe(
output = torch.empty(
hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
)
# Create dummy topk_ids and expert_weights for non-routed variant
num_tokens = hidden_states.shape[0]
topk_ids = torch.empty(
num_tokens, top_k, dtype=torch.int32, device=hidden_states.device
)
expert_weights = torch.empty(
num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device
)
return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
routing_logits,
topk_ids,
expert_weights,
routing_bias,
hidden_states,
hidden_states_scale,
gemm1_weights,
gemm1_weights_scale,
gemm2_weights,
gemm2_weights_scale,
output,
num_experts,
top_k,
n_group,
topk_group,
intermediate_size,
local_expert_offset,
local_num_experts,
routed_scaling_factor,
routing_method_type,
use_shuffled_weight,
weight_layout,
enable_pdl,
tune_max_num_tokens,
)


@flashinfer_api
def trtllm_fp8_block_scale_routed_moe(
topk_ids: torch.Tensor,
routing_bias: Optional[torch.Tensor],
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor,
gemm1_weights: torch.Tensor,
gemm1_weights_scale: torch.Tensor,
gemm2_weights: torch.Tensor,
gemm2_weights_scale: torch.Tensor,
num_experts: int,
top_k: int,
n_group: Optional[int],
topk_group: Optional[int],
intermediate_size: int,
local_expert_offset: int,
local_num_experts: int,
routed_scaling_factor: Optional[float],
routing_method_type: int = 0,
use_shuffled_weight: bool = False,
weight_layout: int = 0,
enable_pdl: Optional[bool] = None,
tune_max_num_tokens: int = 8192,
) -> torch.Tensor:
"""FP8 block scale MoE operation with pre-computed routing (routed variant).

This function skips the routing computation and accepts pre-computed top-k expert
indices and weights. This is useful when routing is computed separately or when
integrating with custom routing strategies.

Args:
topk_ids: [seq_len, top_k] tensor of top-k indices and expert weights.
Dtype must be int32. It must represent a packed value where the most
significant 16 bits represent the score (bfloat16) and the least
significant 16 bits represent the index of the chosen expert (unsigned).
routing_bias: [num_experts] tensor of routing bias
hidden_states: [seq_len, hidden_size] tensor of input hidden states
hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales
gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights
gemm1_weights_scale: [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales
gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights
gemm2_weights_scale: [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales
num_experts: Total number of experts
top_k: Number of experts to route to per token
n_group: Number of expert groups
topk_group: Number of groups to consider for top-k routing
intermediate_size: Size of intermediate layer
local_expert_offset: Offset of local experts in global expert space
local_num_experts: Number of experts handled by this device
routed_scaling_factor: Scaling factor for routing
routing_method_type: Type of routing method to use (default: 0)
use_shuffled_weight: Whether to use shuffled weight layout (default: False)
weight_layout: Weight layout type (default: 0)
enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90.
tune_max_num_tokens: Maximum number of tokens for tuning. (default: 8192)

Returns:
torch.Tensor: Output tensor of shape [seq_len, hidden_size]
"""
output = torch.empty(
hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device
)
# Create empty expert_weights tensor (weights are packed in topk_ids)
num_tokens = hidden_states.shape[0]
expert_weights = torch.empty(
num_tokens, top_k, dtype=torch.bfloat16, device=hidden_states.device
)
return get_trtllm_moe_sm100_module().trtllm_fp8_block_scale_moe(
None, # routing_logits (None for routed variant)
topk_ids,
expert_weights, # empty tensor, weights are packed in topk_ids
routing_bias,
hidden_states,
hidden_states_scale,
Expand Down
177 changes: 177 additions & 0 deletions tests/moe/test_trtllm_fp8_routed_fused_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
import torch

from flashinfer import (
RoutingMethodType,
)
from flashinfer.fused_moe import (
trtllm_fp8_block_scale_moe,
trtllm_fp8_block_scale_routed_moe,
)
from flashinfer.utils import device_support_pdl

from .test_trtllm_gen_fused_moe import (
routing_reference_renormalize,
routing_reference_renormalize_naive,
routing_reference_topk,
)

from flashinfer.utils import get_compute_capability


@pytest.mark.parametrize("num_tokens", [1, 8, 1024])
@pytest.mark.parametrize("hidden_size", [1024, 2048, 3072, 4096])
@pytest.mark.parametrize("intermediate_size", [1024, 2048, 3072, 4096])
@pytest.mark.parametrize("num_experts", [128, 256])
@pytest.mark.parametrize("top_k", [4, 8])
@pytest.mark.parametrize(
    "routing_method_type",
    [
        RoutingMethodType.Renormalize,
        RoutingMethodType.RenormalizeNaive,
        RoutingMethodType.TopK,
    ],
)
def test_trtllm_fp8_routed_fused_moe(
    num_tokens: int,
    hidden_size: int,
    intermediate_size: int,
    top_k: int,
    num_experts: int,
    routing_method_type: RoutingMethodType,
):
    """Check the routed FP8 block-scale MoE against the non-routed reference.

    The non-routed API computes routing internally from ``routing_logits``;
    the routed API receives the same routing decision pre-packed into one
    int32 per (token, k) entry. Per the API docstring the most significant
    16 bits hold the bfloat16 score bits and the least significant 16 bits
    hold the unsigned expert index. With identical routing, the two paths
    must produce (near-)identical outputs.
    """
    compute_capability = get_compute_capability(torch.device(device="cuda"))
    if compute_capability[0] not in [10]:
        pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
    device = torch.device("cuda:0")

    # Guard against OOM: the largest grid point (num_experts=256,
    # hidden_size=4096, intermediate_size=4096) needs >12 GB for the FP8
    # weights alone, which exceeds memory on many SM100 cards.
    fp8_weight_bytes = num_experts * (
        2 * intermediate_size * hidden_size + hidden_size * intermediate_size
    )  # 1 byte per float8_e4m3fn element
    scale_bytes = 4 * num_experts * (
        (2 * intermediate_size // 128) * (hidden_size // 128)
        + (hidden_size // 128) * (intermediate_size // 128)
    )  # 4 bytes per float32 scale
    total_memory = torch.cuda.get_device_properties(device).total_memory
    if fp8_weight_bytes + scale_bytes > 0.6 * total_memory:
        pytest.skip("Skipping large configuration to avoid OOM.")

    torch.manual_seed(42)
    enable_pdl = device_support_pdl(device)
    routing_logits = torch.rand(num_tokens, num_experts, device=device).to(
        torch.bfloat16
    )

    # FP8 hidden states; block scales are laid out [hidden_size//128, num_tokens].
    hidden_states = torch.randn(num_tokens, hidden_size, device=device).to(
        torch.float8_e4m3fn
    )
    hidden_states_scale = torch.rand(
        hidden_size // 128, num_tokens, device=device, dtype=torch.float32
    )

    # FP8 expert weights with per-128x128-block float32 scales.
    gemm1_weights = torch.randn(
        num_experts, intermediate_size * 2, hidden_size, device=device
    ).to(torch.float8_e4m3fn)
    gemm1_weights_scale = torch.rand(
        num_experts,
        intermediate_size * 2 // 128,
        hidden_size // 128,
        device=device,
        dtype=torch.float32,
    )

    gemm2_weights = torch.randn(
        num_experts, hidden_size, intermediate_size, device=device
    ).to(torch.float8_e4m3fn)
    gemm2_weights_scale = torch.rand(
        num_experts,
        hidden_size // 128,
        intermediate_size // 128,
        device=device,
        dtype=torch.float32,
    )

    # Run the non-routed version as reference.
    reference_output = trtllm_fp8_block_scale_moe(
        routing_logits,
        None,  # routing_bias
        hidden_states,
        hidden_states_scale,
        gemm1_weights,
        gemm1_weights_scale,
        gemm2_weights,
        gemm2_weights_scale,
        num_experts,
        top_k,
        None,  # n_group
        None,  # topk_group
        intermediate_size,
        0,  # local_expert_offset
        num_experts,
        None,  # routed_scaling_factor
        routing_method_type.value,
        False,  # use_shuffled_weight
        0,  # weight_layout
        enable_pdl,
    ).to(torch.float)

    # Reproduce the routing decision on the host for the routed variant.
    if routing_method_type == RoutingMethodType.Renormalize:
        permute_info, expert_weights = routing_reference_renormalize(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.RenormalizeNaive:
        permute_info, expert_weights = routing_reference_renormalize_naive(
            routing_logits, top_k, num_experts, 8
        )
    elif routing_method_type == RoutingMethodType.TopK:
        permute_info, expert_weights = routing_reference_topk(
            routing_logits, top_k, num_experts, 8
        )
    topk_ids = permute_info["topKIndices"].to(torch.int32)
    expert_weights = expert_weights.view(num_tokens, num_experts)[
        torch.arange(num_tokens).unsqueeze(1), topk_ids
    ].to(torch.bfloat16)

    # Pack (score, index) into one int32 per entry. The API contract puts the
    # bfloat16 score bits in the upper 16 bits and the unsigned expert index
    # in the lower 16 bits; packing them the other way round mis-routes every
    # token. The sign extension from int16 -> int32 is harmless because the
    # subsequent left shift discards the upper 16 bits.
    score_bits = expert_weights.view(torch.int16).to(torch.int32) << 16
    packed_tensor = score_bits | topk_ids

    # Run the routed version.
    output = trtllm_fp8_block_scale_routed_moe(
        packed_tensor,
        None,  # routing_bias
        hidden_states,
        hidden_states_scale,
        gemm1_weights,
        gemm1_weights_scale,
        gemm2_weights,
        gemm2_weights_scale,
        num_experts,
        top_k,
        None,  # n_group
        None,  # topk_group
        intermediate_size,
        0,  # local_expert_offset
        num_experts,
        None,  # routed_scaling_factor
        routing_method_type.value,
        False,  # use_shuffled_weight
        0,  # weight_layout
        enable_pdl,
    ).to(torch.float)

    # Compare outputs: tolerate a small fraction of mismatches from fp8
    # quantization noise.
    mask = torch.isclose(output, reference_output, rtol=1e-3, atol=1e-3)
    mismatch_pct = (~mask).float().mean().item() * 100
    assert mismatch_pct < 6, f"Mismatch percentage is {mismatch_pct:.2f}"
Loading