
Commit ce60993

fix bf16 api
Signed-off-by: jiahanc <[email protected]>
1 parent cf7e8e3

2 files changed: +16 −14 lines changed


csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 8 additions & 6 deletions
@@ -1241,10 +1241,11 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
 Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                        TensorView const& hidden_states, TensorView const& gemm1_weights,
                        TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
-                       int64_t n_group, int64_t topk_group, int64_t intermediate_size,
-                       int64_t local_expert_offset, int64_t local_num_experts,
-                       int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout,
-                       bool enable_pdl, Array<int64_t> moe_tactic) {
+                       Optional<int64_t> n_group, Optional<int64_t> topk_group,
+                       int64_t intermediate_size, int64_t local_expert_offset,
+                       int64_t local_num_experts, int64_t routing_method_type,
+                       bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
+                       Array<int64_t> moe_tactic) {
   // Just some basic type validation first and leave more checks to the launcher
   TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
       << "BF16 MoE: routing_logits must be bfloat16 or float.";
@@ -1275,8 +1276,9 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
   args->hidden_size = hidden_size;
   args->hidden_size_output = args->hidden_size;
   args->top_k = top_k;
-  args->n_group = n_group;
-  args->topk_group = topk_group;
+  args->n_group = n_group.value_or(0);
+  args->topk_group = topk_group.value_or(0);
+  ;
   args->local_expert_offset = local_expert_offset;
   args->local_num_experts = local_num_experts;
   args->intermediate_size = intermediate_size;
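The launcher-side change above makes `n_group` and `topk_group` optional at the FFI boundary and resolves an absent value to 0 with `value_or(0)` before filling `args`. A minimal Python sketch of the intended semantics (an illustration of the None-to-0 defaulting only, not the real TVM-FFI binding; the helper name is hypothetical):

```python
from typing import Optional


def _resolve_group_args(n_group: Optional[int], topk_group: Optional[int]) -> tuple:
    # Hypothetical helper mirroring the C++ launcher: an absent (None) value
    # behaves like value_or(0), i.e. it is treated as "no expert grouping".
    return (
        n_group if n_group is not None else 0,
        topk_group if topk_group is not None else 0,
    )


# Routing methods without expert grouping can now omit both arguments.
assert _resolve_group_args(None, None) == (0, 0)
# Grouped routing still passes real values through unchanged.
assert _resolve_group_args(8, 4) == (8, 4)
```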

flashinfer/fused_moe/core.py

Lines changed: 8 additions & 8 deletions
@@ -1168,8 +1168,8 @@ def trtllm_bf16_moe_op(
     gemm2_weights: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
@@ -1268,8 +1268,8 @@ def _fake_trtllm_bf16_moe(
     gemm2_weights: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
@@ -1808,8 +1808,8 @@ def trtllm_bf16_moe(
     gemm2_weights: torch.Tensor,
     num_experts: int,
     top_k: int,
-    n_group: int,
-    topk_group: int,
+    n_group: Optional[int],
+    topk_group: Optional[int],
     intermediate_size: int,
     local_expert_offset: int,
     local_num_experts: int,
@@ -1867,8 +1867,8 @@ def trtllm_bf16_moe(
         gemm2_weights,
         num_experts,
         top_k,
-        n_group or 0,  # may receive None from test configs, convert to 0
-        topk_group or 0,
+        n_group,
+        topk_group,
         intermediate_size,
         local_expert_offset,
         local_num_experts,
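At the Python layer, the wrapper previously coerced `None` to `0` itself (`n_group or 0`) before reaching the kernel binding; after this change the `Optional[int]` is forwarded unchanged and the defaulting lives in the C++ launcher. A small self-contained sketch of the before/after behaviour (the functions here are stand-ins for illustration, not the actual `trtllm_bf16_moe` signature):

```python
from typing import Optional


def forward_groups_old(n_group: Optional[int], topk_group: Optional[int]):
    # Old behaviour: truthiness coercion at the Python layer.
    return (n_group or 0, topk_group or 0)


def forward_groups_new(n_group: Optional[int], topk_group: Optional[int]):
    # New behaviour: pass None through; the launcher's value_or(0) handles it.
    return (n_group, topk_group)


# Callers that omit grouping still end up at 0 on the C++ side, but the new
# API keeps the "not provided" information visible at the op boundary.
assert forward_groups_old(None, None) == (0, 0)
assert forward_groups_new(None, None) == (None, None)
```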
