87 changes: 49 additions & 38 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
@@ -401,7 +401,7 @@ class FusedMoeLauncher {
     if (args->do_finalize) {
       return {output};
     }
-    return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx};
+    return {gemm2_output, expanded_idx_to_permuted_idx};
   }
 };

@@ -511,9 +511,11 @@ class Bf16MoeLauncher : public FusedMoeLauncher {
     workspace.gemm2_output = gemm2_output.data_ptr();
     workspace.gemm2_output_scale = nullptr;

-    output =
-        alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
-    args->output = output.data_ptr();
+    if (args->output == nullptr) {
+      output =
+          alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
+      args->output = output.data_ptr();
+    }
     args->output_scale = nullptr;
   }

@@ -685,11 +687,12 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
     workspace.gemm2_output = gemm2_output.data_ptr();
     workspace.gemm2_output_scale = nullptr;

-    output =
-        alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
-    args->output = output.data_ptr();
+    if (args->output == nullptr) {
+      output =
+          alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
+      args->output = output.data_ptr();
+    }
     args->output_scale = nullptr;
-    args->do_finalize = true;  // FP8 per-tensor scale always finalizes

     // Set scale pointers
     TVM_FFI_ICHECK(output1_scales_scalar.has_value());
@@ -1012,11 +1015,12 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
     workspace.gemm2_output = gemm2_output.data_ptr();
     workspace.gemm2_output_scale = nullptr;

-    output =
-        alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
-    args->output = output.data_ptr();
+    if (args->output == nullptr) {
+      output =
+          alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
+      args->output = output.data_ptr();
+    }
     args->output_scale = nullptr;
-    args->do_finalize = true;

     args->hidden_states_scale = static_cast<float*>(hidden_states_scale.data_ptr());
     args->gemm1_weights_scale = static_cast<float*>(gemm1_weights_scale.data_ptr());
@@ -1536,7 +1540,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {

     // Match original FP4 behavior for return values
     if (args->do_finalize) {
-      return {};
+      return {output};
     }
     return {gemm2_output, expanded_idx_to_permuted_idx};
   }
@@ -1570,14 +1574,16 @@
   }
 };

-Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
-                       TensorView const& hidden_states, TensorView const& gemm1_weights,
-                       TensorView const& gemm2_weights, int64_t num_experts, int64_t top_k,
-                       Optional<int64_t> n_group, Optional<int64_t> topk_group,
-                       int64_t intermediate_size, int64_t local_expert_offset,
-                       int64_t local_num_experts, Optional<double> routed_scaling_factor,
-                       int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout,
-                       bool enable_pdl, Array<int64_t> moe_tactic) {
+Array<Tensor> trtllm_bf16_moe(TensorView const& routing_logits,
+                              Optional<TensorView> const& routing_bias,
+                              TensorView const& hidden_states, TensorView const& gemm1_weights,
+                              TensorView const& gemm2_weights, TensorView output,
+                              int64_t num_experts, int64_t top_k, Optional<int64_t> n_group,
+                              Optional<int64_t> topk_group, int64_t intermediate_size,
+                              int64_t local_expert_offset, int64_t local_num_experts,
+                              Optional<double> routed_scaling_factor, int64_t routing_method_type,
+                              bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
+                              bool enable_pdl, Array<int64_t> moe_tactic) {
   // Just some basic type validation first and leave more checks to the launcher
   TVM_FFI_ICHECK(routing_logits.dtype() == dl_float32 || routing_logits.dtype() == dl_bfloat16)
       << "BF16 MoE: routing_logits must be bfloat16 or float.";
Expand Down Expand Up @@ -1614,6 +1620,9 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
args->local_expert_offset = local_expert_offset;
args->local_num_experts = local_num_experts;
args->intermediate_size = intermediate_size;
args->do_finalize = do_finalize;
args->output = output.data_ptr();
args->output_scale = nullptr;

// Create and initialize launcher for this tile size
auto launcher = std::make_unique<Bf16MoeLauncher>(routing_logits, routing_bias, hidden_states,
@@ -1637,19 +1646,18 @@ Tensor trtllm_bf16_moe(TensorView const& routing_logits, Optional<TensorView> co
   auto& selected_launcher = launchers_map.at(tile_N);

   // Run the launcher - it will create its own runner internally
-  auto result = selected_launcher->run(config, enable_pdl)[0];
-  return result;
+  return selected_launcher->run(config, enable_pdl);
 }

-Tensor trtllm_fp8_per_tensor_scale_moe(
+Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
     TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
     TensorView gemm1_weights, TensorView output1_scales_scalar,
     TensorView output1_scales_gate_scalar, TensorView gemm2_weights,
     TensorView output2_scales_scalar, TensorView output, int64_t num_experts, int64_t top_k,
     Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl,
-    Array<int64_t> config_index, int64_t activation_type) {
+    bool use_routing_scales_on_input, int64_t routing_method_type, bool do_finalize,
+    bool enable_pdl, Array<int64_t> config_index, int64_t activation_type) {
   // Basic type validation
   auto dtype = hidden_states.dtype();
   auto activation = static_cast<ActivationType>(activation_type);
@@ -1703,6 +1711,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
   args->local_num_experts = local_num_experts;
   args->intermediate_size = intermediate_size;
   args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
+  args->do_finalize = do_finalize;
+  args->output = output.data_ptr();
+  args->output_scale = nullptr;

   // Create and initialize launcher for this tile size
   auto launcher = std::make_unique<Fp8PerTensorLauncher>(
@@ -1727,20 +1738,18 @@ Tensor trtllm_fp8_per_tensor_scale_moe(
   auto& selected_launcher = launchers_map.at(tile_N);

   // Run the launcher - it will create its own runner internally
-  auto result = selected_launcher->run(config, enable_pdl, use_routing_scales_on_input)[0];
-  // Return the result tensor
-  return result;
+  return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
 }

-Tensor trtllm_fp8_block_scale_moe(
+Array<Tensor> trtllm_fp8_block_scale_moe(
     Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights,
     Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
     TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
     TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
     Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl,
-    Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
+    int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
+    bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
   // Basic type validation
   auto dtype = hidden_states.dtype();

@@ -1822,6 +1831,9 @@ Tensor trtllm_fp8_block_scale_moe(
   args->local_num_experts = local_num_experts;
   args->intermediate_size = intermediate_size;
   args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
+  args->do_finalize = do_finalize;
+  args->output = output.data_ptr();
+  args->output_scale = nullptr;

   // Create and initialize launcher for this tile size
   auto launcher = std::make_unique<Fp8BlockScaleLauncher>(
@@ -1847,11 +1859,9 @@ Tensor trtllm_fp8_block_scale_moe(
   auto& selected_launcher = launchers_map.at(tile_N);

   // Run the launcher with DeepSeek FP8 enabled - it will create its own runner internally
-  auto result = selected_launcher->run(
+  return selected_launcher->run(
       config, enable_pdl, false /* use_routing_scales_on_input */,
-      quantization_type == Fp8QuantizationType::DeepSeekFp8 /* use_deep_seek_fp8 */)[0];
-  // Return the result tensor
-  return result;
+      quantization_type == Fp8QuantizationType::DeepSeekFp8 /* use_deep_seek_fp8 */);
 }

 Array<Tensor> trtllm_fp4_block_scale_moe(
@@ -2004,7 +2014,8 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
     TensorView gemm2_weights, TensorView gemm2_weights_scale, int64_t num_experts, int64_t top_k,
     Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
     int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
-    int64_t routing_method_type, bool enable_pdl, TensorView output, Array<int64_t> config_index) {
+    int64_t routing_method_type, bool do_finalize, bool enable_pdl, TensorView output,
+    Array<int64_t> config_index) {
   // Determine data types based on input format
   int const num_tokens = hidden_states.size(0);
   int hidden_size = hidden_states.size(1);
@@ -2055,7 +2066,7 @@ Array<Tensor> trtllm_mxint4_block_scale_moe(
   args->local_num_experts = local_num_experts;
   args->intermediate_size = intermediate_size;
   args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
-  args->do_finalize = true;
+  args->do_finalize = do_finalize;
   args->output = output.data_ptr();
   args->output_scale = nullptr;

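A recurring change in this diff: each launcher now allocates its bf16 output tensor only when the caller has not already bound args->output, so a preallocated buffer can be passed through the new output parameter. A minimal, self-contained sketch of that pattern (not from the PR; MoeArgs, prepare_output, and the host-side vector standing in for alloc_tensor are all illustrative):

#include <cstdint>
#include <vector>

struct MoeArgs {
  int64_t num_tokens = 0;
  int64_t hidden_size = 0;
  void* output = nullptr;  // caller-provided buffer, or nullptr
};

// Allocate a fallback buffer only when the caller did not supply one, then
// point args->output at whichever buffer will receive the result.
void* prepare_output(MoeArgs* args, std::vector<uint16_t>& fallback_storage) {
  if (args->output == nullptr) {
    fallback_storage.resize(static_cast<size_t>(args->num_tokens) *
                            static_cast<size_t>(args->hidden_size));
    args->output = fallback_storage.data();
  }
  return args->output;
}

int main() {
  MoeArgs args{4, 8, nullptr};  // no caller-provided output
  std::vector<uint16_t> storage;
  return prepare_output(&args, storage) == storage.data() ? 0 : 1;
}

Letting the caller own the output buffer means results can be written directly into externally managed memory (for example, a buffer reused across decode steps), while the nullptr branch preserves the old allocate-internally behavior.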
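The companion change threads do_finalize through every FFI entry point and switches their return type to Array<Tensor>: a finalized run returns the single reduced output, an unfinalized run returns the raw GEMM2 output plus the permutation map. A compilable sketch of that contract (not from the PR; run_moe and TensorHandle are stand-ins for the launcher and the TVM-FFI Tensor type):

#include <cassert>
#include <vector>

using TensorHandle = void*;
using ResultArray = std::vector<TensorHandle>;

// Mirrors the launchers' return contract after this PR.
ResultArray run_moe(bool do_finalize, TensorHandle output, TensorHandle gemm2_output,
                    TensorHandle expanded_idx_to_permuted_idx) {
  if (do_finalize) {
    return {output};  // one tensor: the finalized (reduced) MoE output
  }
  // Two tensors: the caller must permute/reduce the expert results itself.
  return {gemm2_output, expanded_idx_to_permuted_idx};
}

int main() {
  int out = 0, g2 = 0, idx = 0;
  assert(run_moe(true, &out, &g2, &idx).size() == 1);
  assert(run_moe(false, &out, &g2, &idx).size() == 2);
  return 0;
}

A caller passing do_finalize=false receives the unfinalized results and is responsible for the finalize step itself, which is what allows higher-level code to fuse or overlap that reduction.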