
Commit 55f95f0

fix autotune and apply fp8 auto tune to test
Signed-off-by: jiahanc <[email protected]>
1 parent 0cd4848 commit 55f95f0

File tree

flashinfer/fused_moe/core.py
tests/moe/test_trtllm_gen_fused_moe.py

2 files changed: +74 -54 lines

flashinfer/fused_moe/core.py

Lines changed: 21 additions & 5 deletions
@@ -926,6 +926,15 @@ def __init__(
         self.gated_act_type = GatedActType(gated_act_type)
         self.use_shuffled_weight = use_shuffled_weight
         self.weight_layout = WeightLayout(weight_layout)
+        if (
+            not self.use_shuffled_weight
+            or self.weight_layout != WeightLayout.MajorK
+        ):
+            assert (
+                self.use_deepseek_fp8 and self.dtype_weights == DtypeTrtllmGen.E4m3
+            ), (
+                "use_shuffled_weight is False or weight_layout is not MajorK is only supported for FP8 block scale"
+            )

     def get_valid_tactics(
         self,
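The guard added to the runner's __init__ above rejects non-shuffled weights or a non-MajorK weight layout unless the runner is on the FP8 block-scale (DeepSeek E4m3) path. A minimal standalone sketch of the same check, assuming WeightLayout and DtypeTrtllmGen are importable from flashinfer.fused_moe.core, where the diff references them; the helper function itself is illustrative, not part of the library:

from flashinfer.fused_moe.core import DtypeTrtllmGen, WeightLayout

# Illustrative helper, not part of flashinfer: it mirrors the guard added to
# the runner's __init__ above. A non-shuffled weight or a layout other than
# MajorK is only accepted on the FP8 block-scale (DeepSeek E4m3) path.
def check_weight_layout(
    use_shuffled_weight: bool,
    weight_layout: WeightLayout,
    use_deepseek_fp8: bool,
    dtype_weights: DtypeTrtllmGen,
) -> None:
    if not use_shuffled_weight or weight_layout != WeightLayout.MajorK:
        assert use_deepseek_fp8 and dtype_weights == DtypeTrtllmGen.E4m3, (
            "use_shuffled_weight is False or weight_layout is not MajorK "
            "is only supported for FP8 block scale"
        )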
@@ -1022,7 +1031,7 @@ def forward(
                 dtype=torch.float,
                 device=hidden_states.device,
             )
-            return moe_op.trtllm_fp8_block_scale_moe(
+            moe_op.trtllm_fp8_block_scale_moe(
                 routing_logits,
                 kwargs["routing_bias"],
                 hidden_states,
@@ -1031,6 +1040,7 @@ def forward(
                 kwargs["gemm1_weights_scale"],
                 kwargs["gemm2_weights"],
                 kwargs["gemm2_weights_scale"],
+                output,
                 kwargs["num_experts"],
                 self.top_k,
                 kwargs["n_group"],
@@ -1047,7 +1057,7 @@ def forward(
             )
         else:
             # FP8 per tensor scale
-            return moe_op.trtllm_fp8_per_tensor_scale_moe(
+            moe_op.trtllm_fp8_per_tensor_scale_moe(
                 routing_logits,
                 kwargs["routing_bias"],
                 hidden_states,
@@ -1056,6 +1066,7 @@ def forward(
                 kwargs["output1_scales_gate_scalar"],
                 kwargs["gemm2_weights"],
                 kwargs["output2_scales_scalar"],
+                output,
                 kwargs["num_experts"],
                 self.top_k,
                 kwargs["n_group"],
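These forward() hunks drop the return value and instead pass the preallocated output tensor into the kernels, so the op writes its result in place and the wrapper owns the buffer. A minimal sketch of that out-parameter pattern; fused_moe_op, the bfloat16 dtype, and the shapes are placeholders rather than the flashinfer API:

import torch

# Sketch of the out-parameter convention the forward() hunks switch to:
# the wrapper allocates the result once, the op fills it in place and
# returns nothing. `fused_moe_op`, the dtype, and the shape are placeholders.
def run_moe(hidden_states: torch.Tensor, hidden_size: int, fused_moe_op):
    output = torch.empty(
        hidden_states.shape[0],
        hidden_size,
        dtype=torch.bfloat16,
        device=hidden_states.device,
    )
    fused_moe_op(hidden_states, output)  # kernel writes into `output`
    return output                        # the wrapper, not the op, returns it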
@@ -1188,6 +1199,8 @@ def trtllm_fp8_per_tensor_scale_moe_op(
         use_deepseek_fp8=False,  # per_tensor mode
         hidden_size=hidden_size,
         intermediate_size=intermediate_size,
+        weight_layout=WeightLayout.MajorK,
+        use_shuffled_weight=True,
     )

     inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states]
@@ -1203,6 +1216,7 @@ def trtllm_fp8_per_tensor_scale_moe_op(
         output1_scales_gate_scalar=output1_scales_gate_scalar,
         gemm2_weights=gemm2_weights,
         output2_scales_scalar=output2_scales_scalar,
+        num_experts=num_experts,
         n_group=n_group,
         topk_group=topk_group,
         local_expert_offset=local_expert_offset,
@@ -1325,6 +1339,8 @@ def trtllm_fp8_block_scale_moe_op(
         use_deepseek_fp8=True,  # block_scale mode
         hidden_size=hidden_size,
         intermediate_size=intermediate_size,
+        weight_layout=weight_layout,
+        use_shuffled_weight=use_shuffled_weight,
     )

     inputs = [
@@ -1346,6 +1362,7 @@ def trtllm_fp8_block_scale_moe_op(
         gemm1_weights_scale=gemm1_weights_scale,
         gemm2_weights=gemm2_weights,
         gemm2_weights_scale=gemm2_weights_scale,
+        num_experts=num_experts,
         n_group=n_group,
         topk_group=topk_group,
         local_expert_offset=local_expert_offset,
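Both FP8 op wrappers now hand weight_layout and use_shuffled_weight to the runner and add num_experts to the kwargs the autotuner sees. A toy illustration of why every shape-affecting argument has to reach the tuner (this is not flashinfer's AutoTuner): the cached tactic is keyed on the configuration, so a missing field would let two different MoE shapes share one cache entry.

import functools

# Toy tactic cache keyed on the full configuration -- not flashinfer's
# AutoTuner. If num_experts were left out of the key, a tactic tuned for
# one expert count would silently be reused for a different one.
@functools.lru_cache(maxsize=None)
def pick_tactic(num_experts, hidden_size, intermediate_size,
                weight_layout, use_shuffled_weight):
    candidates = range(4)  # stand-in tactic ids; a real tuner would time kernels
    return max(candidates, key=lambda t: (t * num_experts + hidden_size) % 7)

pick_tactic(256, 7168, 2048, "MajorK", True)  # tunes and caches this config
pick_tactic(8, 7168, 2048, "MajorK", True)    # different config, new cache entry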
@@ -1498,9 +1515,8 @@ def trtllm_fp4_block_scale_moe_op(
         hidden_size=hidden_size,
         intermediate_size=intermediate_size,
         gated_act_type=gated_act_type,
-        # NOTE(siyuan): do not fix the tile_tokens_dim to let tunnable runner decide the tile_tokens_dim itself.
-        # however, when the user chooses a different heuristic for tile_tokens_dim, the autotuner will fail to find the correct cached tactics.
-        # tile_tokens_dim=tile_tokens_dim,
+        weight_layout=WeightLayout.MajorK,
+        use_shuffled_weight=True,
     )
     tunning_config = (
         MoERunner.tuning_config_no_hidden_states_scales

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 53 additions & 49 deletions
@@ -769,29 +769,31 @@ def call_moe(
             "NaN detected in hidden_states_fp8"
         )

-        output = trtllm_fp8_block_scale_moe(
-            expert_logits,
-            routing_bias,
-            hidden_states_fp8,
-            hidden_states_scale,
-            static_data["gemm1_weights"],
-            static_data["gemm1_scales"],
-            static_data["gemm2_weights"],
-            static_data["gemm2_scales"],
-            num_experts,
-            top_k,
-            n_groups,
-            top_k_groups,
-            intermediate_size,
-            0,
-            num_experts,
-            routed_scaling,
-            None,
-            routing_method_type,
-            use_shuffled_weight=static_data["use_shuffled_weight"],
-            weight_layout=static_data["weight_layout"],
-            enable_pdl=enable_pdl,
-        )
+        # Use autotuner for optimal kernel selection
+        with autotune(True):
+            output = trtllm_fp8_block_scale_moe(
+                expert_logits,
+                routing_bias,
+                hidden_states_fp8,
+                hidden_states_scale,
+                static_data["gemm1_weights"],
+                static_data["gemm1_scales"],
+                static_data["gemm2_weights"],
+                static_data["gemm2_scales"],
+                num_experts,
+                top_k,
+                n_groups,
+                top_k_groups,
+                intermediate_size,
+                0,
+                num_experts,
+                routed_scaling,
+                None,
+                routing_method_type,
+                use_shuffled_weight=static_data["use_shuffled_weight"],
+                weight_layout=static_data["weight_layout"],
+                enable_pdl=enable_pdl,
+            )

         return output.to(torch.float)

@@ -940,32 +942,34 @@ def call_moe(
             hidden_states_orig, hidden_states_scale_global
         )

-        output = trtllm_fp8_per_tensor_scale_moe(
-            (
-                expert_logits.to(torch.bfloat16)
-                if routing_method_type == RoutingMethodType.Llama4
-                else expert_logits
-            ),
-            routing_bias,
-            hidden_states_fp8,
-            static_data["gemm1_weights"],
-            static_data["scale_c_fc1"],
-            static_data["scale_gate_fc1"],
-            static_data["gemm2_weights"],
-            static_data["scale_c_fc2"],
-            num_experts,
-            top_k,
-            n_groups,
-            top_k_groups,
-            intermediate_size,
-            0,
-            num_experts,
-            routed_scaling,
-            routing_method_type
-            == RoutingMethodType.Llama4,  # Use_routing_scales_on_input
-            None,
-            routing_method_type,
-        )
+        # Use autotuner for optimal kernel selection
+        with autotune(True):
+            output = trtllm_fp8_per_tensor_scale_moe(
+                (
+                    expert_logits.to(torch.bfloat16)
+                    if routing_method_type == RoutingMethodType.Llama4
+                    else expert_logits
+                ),
+                routing_bias,
+                hidden_states_fp8,
+                static_data["gemm1_weights"],
+                static_data["scale_c_fc1"],
+                static_data["scale_gate_fc1"],
+                static_data["gemm2_weights"],
+                static_data["scale_c_fc2"],
+                num_experts,
+                top_k,
+                n_groups,
+                top_k_groups,
+                intermediate_size,
+                0,
+                num_experts,
+                routed_scaling,
+                routing_method_type
+                == RoutingMethodType.Llama4,  # Use_routing_scales_on_input
+                None,
+                routing_method_type,
+            )

         return output.to(torch.float)
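Both test paths now run the kernel under autotune(True), so candidate tactics are profiled on the first call. A hedged usage sketch of the intended pattern; the import path is an assumption (the test diff only shows autotune used by name), and moe_args/moe_kwargs stand in for the argument lists shown above. The idea, consistent with the cached-tactics note removed from core.py, is to tune once inside the context and let later calls reuse the selected tactic.

# Assumed import path; the test diff only shows `autotune` used by name.
from flashinfer.autotuner import autotune

# `moe_args` / `moe_kwargs` stand in for the argument list shown above.
with autotune(True):  # tuning pass: candidate tactics are benchmarked and cached
    output = trtllm_fp8_block_scale_moe(*moe_args, **moe_kwargs)

for _ in range(3):    # steady state: reuse whatever tactic the tuner selected
    output = trtllm_fp8_block_scale_moe(*moe_args, **moe_kwargs)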
