Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 25 additions & 19 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1479,14 +1479,16 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {
}
};

Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
TensorView const& hidden_states, TensorView const& gemm1_weights,
TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group,
int64_t intermediate_size, int64_t local_expert_offset,
int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout,
bool enable_pdl, Array<int64_t> moe_tactic) {
Array<Tensor> trtllm_bf16_moe(TensorView const& routing_logits,
Optional<TensorView> const& routing_bias,
TensorView const& hidden_states, TensorView const& gemm1_weights,
TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group,
int64_t intermediate_size, int64_t local_expert_offset,
int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight,
int64_t weight_layout, bool do_finalize, bool enable_pdl,
Array<int64_t> moe_tactic) {
// Just some basic type validation first and leave more checks to the launcher
TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
<< "BF16 MoE: routing_logits must be bfloat16 or float.";
Expand Down Expand Up @@ -1523,6 +1525,7 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
args->local_expert_offset = local_expert_offset;
args->local_num_experts = local_num_experts;
args->intermediate_size = intermediate_size;
args->do_finalize = do_finalize;

// Create and initialize launcher for this tile size
auto launcher = std::make_unique<Bf16MoeLauncher>(routing_logits, routing_bias, hidden_states,
Expand All @@ -1546,19 +1549,19 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
auto& selected_launcher = launchers_map.at(tile_N);

// Run the launcher - it will create its own runner internally
auto result = selected_launcher->run(config, enable_pdl)[0];
auto result = selected_launcher->run(config, enable_pdl);
return result;
}

Tensor trtllm_fp8_per_tensor_scale_moe(
Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
TensorView gemm1_weights, TensorView output1_scales_scalar,
TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl,
Array<int64_t> config_index, int64_t activation_type) {
bool use_routing_scales_on_input, int64_t routing_method_type, bool do_finalize,
bool enable_pdl, Array<int64_t> config_index, int64_t activation_type) {
// Basic type validation
auto dtype = hidden_states.dtype();
auto activation = static_cast<ActivationType>(activation_type);
Expand Down Expand Up @@ -1612,6 +1615,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
args->local_num_experts = local_num_experts;
args->intermediate_size = intermediate_size;
args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
args->do_finalize = do_finalize;

// Create and initialize launcher for this tile size
auto launcher = std::make_unique<Fp8PerTensorLauncher>(
Expand All @@ -1636,20 +1640,20 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
auto& selected_launcher = launchers_map.at(tile_N);

// Run the launcher - it will create its own runner internally
auto result = selected_launcher->run(config, enable_pdl, use_routing_scales_on_input)[0];
auto result = selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
// Return the result tensor
return result;
}

Tensor trtllm_fp8_block_scale_moe(
Array<Tensor> trtllm_fp8_block_scale_moe(
Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights,
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
Array<int64_t> config_index) {
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
bool enable_pdl, Array<int64_t> config_index) {
// Basic type validation
auto dtype = hidden_states.dtype();

Expand Down Expand Up @@ -1709,6 +1713,7 @@ Tensor trtllm_fp8_block_scale_moe(
args->local_num_experts = local_num_experts;
args->intermediate_size = intermediate_size;
args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
args->do_finalize = do_finalize;

// Create and initialize launcher for this tile size
auto launcher = std::make_unique<Fp8BlockScaleLauncher>(
Expand All @@ -1734,7 +1739,7 @@ Tensor trtllm_fp8_block_scale_moe(

// Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally
auto result = selected_launcher->run(config, enable_pdl, false /* use_routing_scales_on_input */,
true /* use_deep_seek_fp8 */)[0];
true /* use_deep_seek_fp8 */);
// Return the result tensor
return result;
}
Expand Down Expand Up @@ -2010,8 +2015,9 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
}

TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "Unsupported data type combination for getValidConfigs: " << "dtype_act="
<< static_cast<int>(dtype_act) << ", dtype_weights=" << static_cast<int>(dtype_weights)
<< "Unsupported data type combination for getValidConfigs: "
<< "dtype_act=" << static_cast<int>(dtype_act)
<< ", dtype_weights=" << static_cast<int>(dtype_weights)
<< ", useDeepSeekFp8=" << useDeepSeekFp8;

// Unreachable code - added to suppress compiler warning
Expand Down
Loading
Loading