From 5fde7cb352c403962018880b8c84f5fe511e1762 Mon Sep 17 00:00:00 2001 From: phantomlei Date: Mon, 1 Dec 2025 17:42:18 +0800 Subject: [PATCH] refactor: breakdown fused moe kernel for deepseek all2all setup. --- xllm/core/kernels/mlu/fused_moe.cpp | 262 ++++------------------ xllm/core/kernels/mlu/group_gemm.cpp | 54 +++++ xllm/core/kernels/mlu/mlu_ops_api.h | 58 +++-- xllm/core/kernels/mlu/scaled_quantize.cpp | 4 +- xllm/core/kernels/ops_api.cpp | 102 ++++++--- xllm/core/kernels/ops_api.h | 11 +- xllm/core/kernels/param.h | 218 ++++++++++++------ xllm/core/layers/common/fused_moe.cpp | 253 +++++++++++++++++++-- xllm/core/layers/common/fused_moe.h | 15 ++ xllm/core/layers/common/indexer.h | 2 +- 10 files changed, 629 insertions(+), 350 deletions(-) create mode 100644 xllm/core/kernels/mlu/group_gemm.cpp diff --git a/xllm/core/kernels/mlu/fused_moe.cpp b/xllm/core/kernels/mlu/fused_moe.cpp index 49edeac89..9ab041a8f 100644 --- a/xllm/core/kernels/mlu/fused_moe.cpp +++ b/xllm/core/kernels/mlu/fused_moe.cpp @@ -17,240 +17,76 @@ limitations under the License. #include "mlu_ops_api.h" -namespace { -torch::Tensor create_group_gemm_output( - const torch::Tensor& a, - const torch::Tensor& b, - const torch::Tensor& group_list, - torch::ScalarType dtype = torch::ScalarType::BFloat16) { - torch::TensorOptions target_options = a.options().dtype(dtype); - if (b.dim() != 2) { - return torch::empty({a.size(0), b.size(1)}, target_options); - } - return torch::empty({group_list.size(0), a.size(0), b.size(0)}, - target_options); -} -} // namespace - namespace xllm::kernel::mlu { -torch::Tensor fused_moe( - const torch::Tensor& hidden_states, - const torch::Tensor& gating_output, - const torch::Tensor& w1, - const torch::Tensor& w2, - const std::optional& bias1, - const std::optional& bias2, - const std::optional& residual, - const std::optional& input_smooth, - const std::optional& act_smooth, - const std::optional& w1_scale, - const std::optional& w2_scale, - const std::optional& e_score_correction_bias, + +std::tuple moe_active_topk( + const torch::Tensor& input, int64_t topk, - bool renormalize, - bool gated, - const std::string& act_mode, - const std::string& scoring_func, int64_t num_expert_group, int64_t topk_group, + bool normalize, + const std::optional& mask, + const std::string& normed_by, + const std::string& scoring_func, double route_scale, - int64_t start_expert_id, - bool avg_moe, - const std::optional>& w1_quant_flag, - const std::optional>& w2_quant_flag) { - auto dtype = hidden_states.dtype(); - auto ori_input_shape = hidden_states.sizes(); - - auto hidden_states_2d = hidden_states.reshape({-1, hidden_states.size(-1)}); - int64_t tokens = hidden_states_2d.size(0); - auto gating_output_2d = gating_output.reshape({-1, gating_output.size(-1)}); - - std::optional residual_2d = std::nullopt; - if (residual.has_value()) { - residual_2d = residual.value().reshape({-1, residual.value().size(-1)}); - } - - // check smooth quant variables - bool all_present = input_smooth && act_smooth && w1_scale && w2_scale; - bool all_none = !input_smooth && !act_smooth && !w1_scale && !w2_scale; - CHECK(all_none || all_present) - << "input_smooth, act_smooth, w1_scale and w2_scale must be present or " - "absent at the same time."; - bool is_smoothquant = all_present; - int64_t expert_num = gating_output_2d.size(-1); - int64_t expert_size = w1.size(0); - - // apply softmax_topk or sigmoid_topk - auto reduce_weight = torch::empty( - {gating_output_2d.size(0), topk}, - 
torch::dtype(torch::kFloat).device(gating_output_2d.device())); - auto expert_id = torch::empty( - {gating_output_2d.size(0), topk}, - torch::dtype(torch::kInt32).device(gating_output_2d.device())); - - tmo::torch_api::moe_active_topk(gating_output_2d, + const std::optional& e_score_correction_bias) { + auto reduce_weight = + torch::empty({input.size(0), topk}, + torch::dtype(torch::kFloat).device(input.device())); + auto expert_id = + torch::empty({input.size(0), topk}, + torch::dtype(torch::kInt32).device(input.device())); + tmo::torch_api::moe_active_topk(input, topk, num_expert_group, topk_group, - renormalize, - /*mask=*/std::nullopt, - /*normed_by=*/"topk_logit", + normalize, + mask, + normed_by, scoring_func, route_scale, e_score_correction_bias, reduce_weight, expert_id); + return std::make_tuple(reduce_weight, expert_id); +} - auto output_vec = tmo::torch_api::moe_gen_idx(expert_id, expert_num); - auto expand_idx = output_vec[0]; - auto combine_idx = output_vec[1]; - auto token_count = output_vec[2]; - auto cusum_token_count = output_vec[3]; - - // prepare the parameters for the first group gemm - auto token_count_slice = - token_count.slice(0, start_expert_id, start_expert_id + expert_size); - auto gather_index_start_position = - cusum_token_count.index({start_expert_id}).unsqueeze(0); - torch::Tensor expand_hidden_states; - torch::Tensor input_scale; - - if (is_smoothquant) { - // w8a8 path: quantize input hidden states directly (fused with - // moe_expand_input) - std::tie(expand_hidden_states, input_scale) = - scaled_quantize(hidden_states_2d, // Use original hidden_states_2d - // instead of expand_hidden_states - input_smooth.value(), - /*zero=*/std::nullopt, - token_count_slice, - expand_idx, - gather_index_start_position, - /*output=*/std::nullopt, - /*output_scale=*/std::nullopt, - /*act_mode=*/"none", - /*active_coef=*/1.0, - /*is_gated=*/false, - /*quant_type=*/torch::kChar); - } else { - // bf16/fp32 path: expand input hidden states - expand_hidden_states = tmo::torch_api::moe_expand_input(hidden_states_2d, - expand_idx, - cusum_token_count, - start_expert_id, - expert_size); - } - - torch::Tensor gemm1_out = create_group_gemm_output( - expand_hidden_states, w1, token_count_slice, dtype.toScalarType()); - - // Unified group_gemm call using input_scale/w1_scale/quant_flag only if - // present - tmo::torch_api::group_gemm( - expand_hidden_states, - w1, - token_count_slice, - gemm1_out, - /*gather_idx=*/std::nullopt, - /*c=*/std::nullopt, - /*alpha=*/std::nullopt, - /*beta=*/std::nullopt, - /*a_scale=*/input_scale.defined() ? std::make_optional(input_scale) - : std::nullopt, - /*b_scale=*/w1_scale.has_value() ? std::make_optional(w1_scale.value()) - : std::nullopt, - /*bias=*/std::nullopt, - /*a_calibration=*/std::nullopt, - /*b_calibration=*/std::nullopt, - /*quant_flag=*/w1_quant_flag.has_value() ? w1_quant_flag : std::nullopt, - /*b_offset=*/std::nullopt, - /*tile_config=*/std::nullopt, - /*max_dim=*/tokens, - /*trans_a=*/false, - /*trans_b=*/true, - /*a_quant_bit=*/is_smoothquant ? 8 : -1); - - // prepare the parameters for the second group gemm - torch::Tensor act_out; - torch::Tensor act_out_scale; - if (is_smoothquant) { - // w8a8 path: reuse quantized_input and input_scale from first group_gemm - act_out = gated ? 
expand_hidden_states.slice(1, 0, gemm1_out.size(1) / 2) - : expand_hidden_states.slice(1, 0, gemm1_out.size(1)); - act_out_scale = input_scale.slice(0, 0, gemm1_out.size(0)); - - // Quantize gemm1_out directly (fused with active operation) using reused - // tensors - auto [quantized_activation, activation_scale] = - scaled_quantize(gemm1_out, - act_smooth.value(), - /*zero=*/std::nullopt, - /*token_count=*/token_count_slice, - /*gather_index=*/std::nullopt, - /*gather_index_start_position=*/std::nullopt, - act_out, // output - reuse from quantized_input - act_out_scale, // output_scale - reuse from input_scale - /*act_mode=*/act_mode, - /*active_coef=*/1.0, - /*is_gated=*/gated, - /*quant_type=*/torch::kChar); - act_out = quantized_activation; - act_out_scale = activation_scale; - } else { - // bf16/fp32 path: apply activation function first - act_out = gated ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2) : gemm1_out; - tmo::torch_api::active(gemm1_out, - act_out, - bias1, - cusum_token_count, - act_mode, - gated, - start_expert_id, - expert_size); - } - - torch::Tensor gemm2_out = create_group_gemm_output( - act_out, w2, token_count_slice, dtype.toScalarType()); +std::vector moe_gen_idx(const torch::Tensor& expert_id, + int64_t expert_num) { + return tmo::torch_api::moe_gen_idx(expert_id, expert_num); +} - // Unified group_gemm call, now only checks the existance of - // input_scale/w1_scale for smoothquant - tmo::torch_api::group_gemm( - act_out, - w2, - token_count_slice, - gemm2_out, - /*gather_idx=*/std::nullopt, - /*c=*/std::nullopt, - /*alpha=*/std::nullopt, - /*beta=*/std::nullopt, - act_out_scale.defined() ? std::make_optional(act_out_scale) - : std::nullopt, // a_scale - w2_scale.has_value() ? std::make_optional(w2_scale.value()) - : std::nullopt, // b_scale - /*bias=*/std::nullopt, - /*a_calibration=*/std::nullopt, - /*b_calibration=*/std::nullopt, - w2_quant_flag.has_value() ? w2_quant_flag : std::nullopt, // quant_flag - /*b_offset=*/std::nullopt, - /*tile_config=*/std::nullopt, - /*max_dim=*/tokens, - /*trans_a=*/false, - /*trans_b=*/true, - /*a_quant_bit=*/is_smoothquant ? 8 : -1); +torch::Tensor moe_expand_input( + const torch::Tensor& input, + const torch::Tensor& gather_index, + const std::optional& cusum_token_count, + int64_t start_expert_id, + int64_t expert_size) { + return tmo::torch_api::moe_expand_input( + input, gather_index, cusum_token_count, start_expert_id, expert_size); +} - auto output = torch::empty({reduce_weight.size(0), gemm2_out.size(1)}, - gemm2_out.options()); - tmo::torch_api::moe_combine_result(gemm2_out, +torch::Tensor moe_combine_result( + const torch::Tensor& input, + const torch::Tensor& reduce_weight, + const torch::Tensor& gather_ids, + const std::optional& residual, + const std::optional& cusum_token_count, + const int64_t start_expert_id, + const int64_t expert_size, + const std::optional& bias) { + auto output = + torch::empty({reduce_weight.size(0), input.size(1)}, input.options()); + tmo::torch_api::moe_combine_result(input, output, reduce_weight, - combine_idx, - residual_2d, + gather_ids, + residual, cusum_token_count, start_expert_id, expert_size, - bias2); - - return output.reshape(ori_input_shape); + bias); + return output; } } // namespace xllm::kernel::mlu diff --git a/xllm/core/kernels/mlu/group_gemm.cpp b/xllm/core/kernels/mlu/group_gemm.cpp new file mode 100644 index 000000000..fade635dd --- /dev/null +++ b/xllm/core/kernels/mlu/group_gemm.cpp @@ -0,0 +1,54 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlu_ops_api.h" + +namespace xllm::kernel::mlu { + +torch::Tensor group_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& token_count, + torch::Tensor& output, + const std::optional& a_scale, + const std::optional& b_scale, + const std::optional>& quant_flag, + const int64_t max_dim, + const bool trans_a, + const bool trans_b, + const int64_t a_quant_bit) { + tmo::torch_api::group_gemm(a, + b, + token_count, + output, + /*gather_idx=*/std::nullopt, + /*c=*/std::nullopt, + /*alpha=*/std::nullopt, + /*beta=*/std::nullopt, + a_scale, + b_scale, + /*bias=*/std::nullopt, + /*a_calibration=*/std::nullopt, + /*b_calibration=*/std::nullopt, + quant_flag, + /*b_offset=*/std::nullopt, + /*tile_config=*/std::nullopt, + max_dim, + trans_a, + trans_b, + a_quant_bit); + return output; +} + +} // namespace xllm::kernel::mlu diff --git a/xllm/core/kernels/mlu/mlu_ops_api.h b/xllm/core/kernels/mlu/mlu_ops_api.h index 442d1ad72..600db9335 100644 --- a/xllm/core/kernels/mlu/mlu_ops_api.h +++ b/xllm/core/kernels/mlu/mlu_ops_api.h @@ -138,31 +138,49 @@ torch::Tensor matmul(const torch::Tensor& a, double alpha, double beta); -torch::Tensor fused_moe( - const torch::Tensor& hidden_states, - const torch::Tensor& gating_output, - const torch::Tensor& w1, - const torch::Tensor& w2, - const std::optional& bias1, - const std::optional& bias2, - const std::optional& residual, - const std::optional& input_smooth, - const std::optional& act_smooth, - const std::optional& w1_scale, - const std::optional& w2_scale, - const std::optional& e_score_correction_bias, +torch::Tensor group_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& token_count, + torch::Tensor& output, + const std::optional& a_scale, + const std::optional& b_scale, + const std::optional>& quant_flag, + const int64_t max_dim, + const bool trans_a, + const bool trans_b, + const int64_t a_quant_bit); + +std::tuple moe_active_topk( + const torch::Tensor& input, int64_t topk, - bool renormalize, - bool gated, - const std::string& act_mode, - const std::string& scoring_func, int64_t num_expert_group, int64_t topk_group, + bool normalize, + const std::optional& mask, + const std::string& normed_by, + const std::string& scoring_func, double route_scale, + const std::optional& e_score_correction_bias); + +std::vector moe_gen_idx(const torch::Tensor& expert_id, + int64_t expert_num); + +torch::Tensor moe_expand_input( + const torch::Tensor& input, + const torch::Tensor& gather_index, + const std::optional& cusum_token_count, int64_t start_expert_id, - bool avg_moe, - const std::optional>& w1_quant_flag, - const std::optional>& w2_quant_flag); + int64_t expert_size); + +torch::Tensor moe_combine_result( + const torch::Tensor& input, + const torch::Tensor& reduce_weight, + const torch::Tensor& gather_ids, + const std::optional& residual, + const std::optional& 
cusum_token_count, + const int64_t start_expert_id, + const int64_t expert_size, + const std::optional& bias); std::tuple scaled_quantize( const torch::Tensor& x, diff --git a/xllm/core/kernels/mlu/scaled_quantize.cpp b/xllm/core/kernels/mlu/scaled_quantize.cpp index 1e53b123b..73c18e9d0 100644 --- a/xllm/core/kernels/mlu/scaled_quantize.cpp +++ b/xllm/core/kernels/mlu/scaled_quantize.cpp @@ -62,14 +62,14 @@ std::tuple scaled_quantize( if (output.has_value()) { result_output = output.value(); } else { - result_output = at::empty(output_shape, x.options().dtype(quant_type)); + result_output = torch::empty(output_shape, x.options().dtype(quant_type)); } if (output_scale.has_value()) { result_output_scale = output_scale.value(); } else { result_output_scale = - at::empty(output_scale_shape, x.options().dtype(at::kFloat)); + torch::empty(output_scale_shape, x.options().dtype(at::kFloat)); } // Call underlying MLU kernel diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 3d694f7e4..5634c5cdc 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -207,36 +207,84 @@ torch::Tensor matmul(MatmulParams& params) { #endif } -torch::Tensor fused_moe(FusedMoEParams& params) { +torch::Tensor group_gemm(GroupGemmParams& params) { #if defined(USE_MLU) - return mlu::fused_moe(params.hidden_states, - params.gating_output, - params.w1, - params.w2, - params.bias1, - params.bias2, - params.residual, - params.input_smooth, - params.act_smooth, - params.w1_scale, - params.w2_scale, - params.e_score_correction_bias, - params.topk, - params.renormalize, - params.gated, - params.act_mode, - params.scoring_func, - params.num_expert_group, - params.topk_group, - params.route_scale, - params.start_expert_id, - params.avg_moe, - params.w1_quant_flag, - params.w2_quant_flag); + return mlu::group_gemm(params.a, + params.b, + params.token_count, + params.output, + params.a_scale, + params.b_scale, + params.quant_flag, + params.max_dim, + params.trans_a, + params.trans_b, + params.a_quant_bit); #elif defined(USE_CUDA) - LOG(FATAL) << "fused_moe for cuda not implemented"; + LOG(FATAL) << "group_gemm for cuda not implemented"; #else - LOG(FATAL) << "fused_moe not implemented"; + LOG(FATAL) << "group_gemm not implemented"; +#endif +} + +std::tuple moe_active_topk( + MoeActiveTopkParams& params) { +#if defined(USE_MLU) + return mlu::moe_active_topk(params.input, + params.topk, + params.num_expert_group, + params.topk_group, + params.normalize, + params.mask, + params.normed_by, + params.scoring_func, + params.route_scale, + params.e_score_correction_bias); +#elif defined(USE_CUDA) + LOG(FATAL) << "moe_active_topk for cuda not implemented"; +#else + LOG(FATAL) << "moe_active_topk not implemented"; +#endif +} + +std::vector moe_gen_idx(MoeGenIdxParams& params) { +#if defined(USE_MLU) + return mlu::moe_gen_idx(params.expert_id, params.expert_num); +#elif defined(USE_CUDA) + LOG(FATAL) << "moe_gen_idx for cuda not implemented"; +#else + LOG(FATAL) << "moe_gen_idx not implemented"; +#endif +} + +torch::Tensor moe_expand_input(MoeExpandInputParams& params) { +#if defined(USE_MLU) + return mlu::moe_expand_input(params.input, + params.gather_index, + params.cusum_token_count, + params.start_expert_id, + params.expert_size); +#elif defined(USE_CUDA) + LOG(FATAL) << "moe_expand_input for cuda not implemented"; +#else + LOG(FATAL) << "moe_expand_input not implemented"; +#endif +} + +torch::Tensor moe_combine_result(MoeCombineResultParams& params) { +#if defined(USE_MLU) + return 
mlu::moe_combine_result(params.input, + params.reduce_weight, + params.gather_ids, + params.residual, + params.cusum_token_count, + params.start_expert_id, + params.expert_size, + params.bias); +#elif defined(USE_CUDA) + LOG(FATAL) << "moe_combine_result for cuda not implemented"; +#else + LOG(FATAL) << "moe_combine_result not implemented"; #endif } diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index 4250b70a1..4f17659f6 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -38,7 +38,16 @@ void fused_layernorm(FusedLayerNormParams& params); torch::Tensor matmul(MatmulParams& params); -torch::Tensor fused_moe(FusedMoEParams& params); +torch::Tensor group_gemm(GroupGemmParams& params); + +std::tuple moe_active_topk( + MoeActiveTopkParams& params); + +std::vector moe_gen_idx(MoeGenIdxParams& params); + +torch::Tensor moe_expand_input(MoeExpandInputParams& params); + +torch::Tensor moe_combine_result(MoeCombineResultParams& params); std::tuple scaled_quantize( ScaledQuantizeParams& params); diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 0eb963fe0..2be0e88aa 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -371,76 +371,164 @@ struct MatmulParams { double beta = 0.0; }; -// Fused MoE parameters -struct FusedMoEParams { - // Input hidden states tensor. Will be reshaped to 2D [tokens, hidden_size] - // internally. tokens = hidden_states.numel() / hidden_states.size(-1) - torch::Tensor hidden_states; - // Gating output tensor for expert selection. Will be reshaped to 2D [tokens, - // expert_num]. expert_num = gating_output.size(-1) - torch::Tensor gating_output; - // First weight matrix W1. Shape: [expert_size, ...]. expert_size = - // w1.size(0). Used in first group_gemm operation (trans_b=true). - torch::Tensor w1; - // Second weight matrix W2. Shape: [expert_size, ...]. - // Used in second group_gemm operation (trans_b=true). - torch::Tensor w2; - // Optional bias for first activation. - std::optional bias1; - // Optional bias for output combination. - std::optional bias2; - // Optional residual tensor. Will be reshaped to 2D [tokens, hidden_size] - // internally. Added to final output after MoE combine result. - std::optional residual; - // Optional input smooth quantization scale. For smooth quant mode. - // Must be present together with act_smooth, w1_scale, w2_scale. - // Used to quantize hidden_states before first group_gemm. - std::optional input_smooth; - // Optional activation smooth quantization scale. For smooth quant mode. - // Must be present together with input_smooth, w1_scale, w2_scale. - // Used to quantize gemm1_out after activation. - std::optional act_smooth; - // Optional W1 quantization scale. For smooth quant mode. - // Must be present together with input_smooth, act_smooth, w2_scale. - // Used in first group_gemm as b_scale. - std::optional w1_scale; - // Optional W2 quantization scale. For smooth quant mode. - // Must be present together with input_smooth, act_smooth, w1_scale. - // Used in second group_gemm as b_scale. - std::optional w2_scale; - // Optional expert score correction bias. - std::optional e_score_correction_bias; +struct GroupGemmParams { + // Input activation tensor. + // Shape: 2D [M, K] if trans_a==false; [K, M] if trans_a==true. + // Must be contiguous. Dtype: float16, bfloat16, or float32. + // Must have same dtype and device as b, output. + torch::Tensor a; + // Weight tensor. 
+ // If trans_b is true, shape is (num_experts, N, K) or (N, K); + // if trans_b is false, shape is (num_experts, K, N) or (K, N). + // Must be contiguous. Dtype and device must match a, output. + torch::Tensor b; + // Per-expert token count tensor. + // Shape: 1D [num_experts]. Type must be int32. + // Controls number of tokens processed per group/expert. + torch::Tensor token_count; + // Output tensor. + // Shape: [num_experts, N] or [num_experts, N, K]. num_experts = + // token_count.size(0). Must be contiguous. Dtype and device must match a. + torch::Tensor output; + // Optional scale tensor for a (input activation), used in quantized mode. + // Shape depends on quantization granularity. + std::optional a_scale; + // Optional scale tensor for b (weight), used in quantized mode. + // Shape depends on quantization granularity. + std::optional b_scale; + // Optional quantization config flag list. + // Used to control per-expert weight quantization mode. + std::optional> quant_flag; + // Maximum workspace dimension (e.g., maximum tokens per expert allowed). + // Used for configuring inner kernel workspace. + int64_t max_dim; + // Whether to transpose a: + // false: [M, K] (default); true: [K, M]. + bool trans_a; + // Whether to transpose b: + // false: [K, N] (default); true: [N, K]. + bool trans_b; + // Quantization bit-width for input a. + // Set -1 to disable quantization. + int64_t a_quant_bit; +}; + +struct MoeActiveTopkParams { + // Input tensor. + // Shape: [*, num_mask, num_expert] (e.g., [batch, num_mask, num_expert]). + // Dtype: float32, float16, bfloat16. + // Must be contiguous. + torch::Tensor input; // Number of top-k experts to select per token. + // Constraint: 0 < topk <= num_expert. int64_t topk; + // Number of expert groups for group-limited top-k selection. + // If > 1, mask must be None, and num_expert % num_expert_group == 0. + int64_t num_expert_group; + // Maximum selected experts per group. + // Constraint: 0 < topk_group <= num_expert_group. + int64_t topk_group; // Whether to renormalize expert weights after top-k selection. - bool renormalize; - // Whether to use gated activation. If true, activation output shape is - // halved. - bool gated; - // Activation mode string. - // Supported: "none", "gelu", "silu". - std::string act_mode = "none"; - // Scoring function for expert selection. Default: "softmax". + bool normalize; + // Optional mask tensor. + // Shape: [1, ..., 1, num_mask, num_expert] (leading dims must be 1). + // Dtype must match input. + // Must be contiguous. + std::optional mask; + // Normalization logic after top-k selection. + // For softmax: "topk_logit" or "softmax_logit". + // For sigmoid: "topk_logit" or "sigmoid_logit". + std::string normed_by; + // Scoring function for expert selection. // Supported: "softmax", "sigmoid". - std::string scoring_func = "softmax"; - // Number of expert groups. Default: -1. - int64_t num_expert_group = -1; - // Top-k group parameter. Default: 0. - int64_t topk_group = 0; - // Route scaling factor. Default: 1.0. - double route_scale = 1.0; - // Starting expert ID. Used to slice token_count and cusum_token_count. - // Processing range: [start_expert_id, start_expert_id + expert_size). + std::string scoring_func; + // Route scaling factor applied to routing scores. + double route_scale; + // Optional expert score correction bias. + // Shape: [num_expert]. + // Dtype: float32, float16, or bfloat16. + // Must be contiguous. 
+ std::optional e_score_correction_bias; +}; + +struct MoeGenIdxParams { + // The input tensor stores the expert id of each token. + // Shape: [num_tokens, topk]. + // Dtype: int32. + torch::Tensor expert_id; + // Expert number. + // Must be >= 0. + int64_t expert_num; +}; + +struct MoeExpandInputParams { + // Input tensor to be expanded. + // Shape: [token_num, hidden_size]. + // Dtype: int8, float, half, or bfloat16. + torch::Tensor input; + // Index tensor for gather operation. + // Shape: [expand_token_num]. + // Dtype: int32. + torch::Tensor gather_index; + // Optional prefix sum of token count per expert. + // Shape: [num_experts + 1]. + // Dtype: int32. + // If provided, adjusts gather range for each expert. + std::optional cusum_token_count; + // Starting expert id to process. + // Must be >= 0. + int64_t start_expert_id; + // Number of experts to process in this call. + // Must be >= 0. + int64_t expert_size; +}; + +struct MoeCombineResultParams { + // Expert output tensor to be combined. + // Shape: [num_tokens * topk, hidden_size]. + // - Must be contiguous. + // - Dtype: float32, float16, or bfloat16. + // - This is the concatenated output from all experts, not yet reordered back + // to the original sequence order. + torch::Tensor input; + // Router/gating weights tensor. Used for weighted combination of expert + // outputs. Shape: [num_tokens, topk]. + // - Must be contiguous at last dimension. + // - Dtype: float32. + // - Constraint: reduce_weight.numel() == input.size(0). + torch::Tensor reduce_weight; + // Gather index tensor that maps combined output to original token positions. + // Shape: [num_tokens * topk]. + // - Must be contiguous. + // - Dtype: int32. + // - Corresponds to permutation/scatter indices for reordering expert outputs. + torch::Tensor gather_ids; + // Optional residual connection input. + // Shape: [num_tokens, hidden_size]. + // - Must have same shape and dtype as output if provided. + // - Must be contiguous if provided. + // - Default: std::nullopt (no residual). + std::optional residual; + // Optional cumulative token count for expert assignment. + // Shape: [num_experts + 1] or deduced by expert_size. + // - Must be contiguous if provided. + // - Dtype: int32. + // - Used to infer num_expert or assist calculation in some kernels. + std::optional cusum_token_count; + // Starting expert ID + // - Must be >= 0. + // - Used to mark the offset of current experts being processed (for + // sharding). int64_t start_expert_id = 0; - // Enforce every expert get equal number of tokens - // This option has not implemented yet. - bool avg_moe = false; - // Optional quantization flag list for W1. - // Used in first group gemm. - std::optional> w1_quant_flag; - // Optional quantization flag list for W2. - // Used in second group gemm. - std::optional> w2_quant_flag; + // Number of experts processed in this step. + // - If cusum_token_count not given, num_expert is set to this value. + // - If cusum_token_count given, deduced num_expert must satisfy: + // num_expert >= start_expert_id + expert_size + int64_t expert_size = 0; + // Optional bias tensor. + // WARNING: Bias addition is NOT supported in current implementation. + // Always keep as std::nullopt unless bias support is added in the future. 
+ std::optional bias; }; // Per token smooth quantize parameters diff --git a/xllm/core/layers/common/fused_moe.cpp b/xllm/core/layers/common/fused_moe.cpp index fbb51a294..9b1a12077 100644 --- a/xllm/core/layers/common/fused_moe.cpp +++ b/xllm/core/layers/common/fused_moe.cpp @@ -20,6 +20,21 @@ limitations under the License. #include "framework/parallel_state/parallel_state.h" #include "kernels/ops_api.h" +namespace { +torch::Tensor create_group_gemm_output( + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& group_list, + torch::ScalarType dtype = torch::ScalarType::BFloat16) { + torch::TensorOptions target_options = a.options().dtype(dtype); + if (b.dim() != 2) { + return torch::empty({a.size(0), b.size(1)}, target_options); + } + return torch::empty({group_list.size(0), a.size(0), b.size(0)}, + target_options); +} +} // namespace + namespace xllm { namespace layer { @@ -168,6 +183,94 @@ FusedMoEImpl::FusedMoEImpl(int64_t num_experts, } } +torch::Tensor FusedMoEImpl::select_experts( + const torch::Tensor& hidden_states_2d, + const torch::Tensor& router_logits_2d, + SelectedExpertInfo& selected_expert_info) { + // prepare the parameters for select_experts + std::optional e_score_correction_bias = std::nullopt; + if (e_score_correction_bias_.defined()) { + e_score_correction_bias = e_score_correction_bias_; + } + int64_t expert_size = w13_.size(0); + + // Step 1: apply softmax topk or sigmoid topk / routing logic + torch::Tensor reduce_weight; + torch::Tensor expert_id; + { + xllm::kernel::MoeActiveTopkParams moe_active_topk_params; + moe_active_topk_params.input = router_logits_2d; + moe_active_topk_params.topk = topk_; + moe_active_topk_params.num_expert_group = num_expert_group_; + moe_active_topk_params.topk_group = topk_group_; + moe_active_topk_params.normalize = renormalize_; + moe_active_topk_params.normed_by = "topk_logit"; + moe_active_topk_params.scoring_func = scoring_func_; + moe_active_topk_params.route_scale = route_scale_; + moe_active_topk_params.e_score_correction_bias = e_score_correction_bias; + std::tie(reduce_weight, expert_id) = + xllm::kernel::moe_active_topk(moe_active_topk_params); + } + + // Step 2: generate expert ids + torch::Tensor gather_idx; + torch::Tensor combine_idx; + torch::Tensor token_count; + torch::Tensor cusum_token_count; + { + xllm::kernel::MoeGenIdxParams moe_gen_idx_params; + moe_gen_idx_params.expert_id = expert_id; + moe_gen_idx_params.expert_num = router_logits_2d.size(-1); + std::vector output_vec = + xllm::kernel::moe_gen_idx(moe_gen_idx_params); + gather_idx = output_vec[0]; + combine_idx = output_vec[1]; + token_count = output_vec[2]; + cusum_token_count = output_vec[3]; + } + + // Step 3: expand and quantize input if needed + torch::Tensor expand_hidden_states; + torch::Tensor hidden_states_scale; + torch::Tensor token_count_slice = + token_count.slice(0, start_expert_id_, start_expert_id_ + expert_size); + if (is_smoothquant_) { + xllm::kernel::ScaledQuantizeParams scaled_quantize_params; + scaled_quantize_params.x = hidden_states_2d; + scaled_quantize_params.smooth = input_smooth_; + scaled_quantize_params.token_count = token_count_slice; + scaled_quantize_params.gather_index = gather_idx; + scaled_quantize_params.gather_index_start_position = + cusum_token_count.index({start_expert_id_}).unsqueeze(0); + scaled_quantize_params.act_mode = "none"; + scaled_quantize_params.active_coef = 1.0; + scaled_quantize_params.is_gated = false; + scaled_quantize_params.quant_type = torch::kChar; + 
std::tie(expand_hidden_states, hidden_states_scale) = + xllm::kernel::scaled_quantize(scaled_quantize_params); + } else { + xllm::kernel::MoeExpandInputParams moe_expand_input_params; + moe_expand_input_params.input = hidden_states_2d; + moe_expand_input_params.gather_index = gather_idx; + moe_expand_input_params.cusum_token_count = cusum_token_count; + moe_expand_input_params.start_expert_id = start_expert_id_; + moe_expand_input_params.expert_size = expert_size; + expand_hidden_states = + xllm::kernel::moe_expand_input(moe_expand_input_params); + } + + // collect the selected tensor + selected_expert_info.reduce_weight = reduce_weight; + selected_expert_info.combine_idx = combine_idx; + selected_expert_info.token_count_slice = token_count_slice; + selected_expert_info.cusum_token_count = cusum_token_count; + if (is_smoothquant_) { + selected_expert_info.input_scale = hidden_states_scale; + } + + return expand_hidden_states; +} + torch::Tensor FusedMoEImpl::forward_expert( const torch::Tensor& hidden_states, const torch::Tensor& router_logits, @@ -177,30 +280,138 @@ torch::Tensor FusedMoEImpl::forward_expert( e_score_correction_bias = e_score_correction_bias_; } - xllm::kernel::FusedMoEParams fused_moe_params; - fused_moe_params.hidden_states = hidden_states; - fused_moe_params.gating_output = router_logits; - fused_moe_params.w1 = w13_; - fused_moe_params.w2 = w2_; - fused_moe_params.residual = shared_output; - fused_moe_params.num_expert_group = num_expert_group_; - fused_moe_params.topk_group = topk_group_; - fused_moe_params.route_scale = route_scale_; - fused_moe_params.e_score_correction_bias = e_score_correction_bias; - fused_moe_params.topk = topk_; - fused_moe_params.renormalize = renormalize_; - fused_moe_params.gated = is_gated_; - fused_moe_params.act_mode = hidden_act_; - fused_moe_params.scoring_func = scoring_func_; - fused_moe_params.start_expert_id = start_expert_id_; + // prepare the parameters for MoE computation + torch::IntArrayRef hidden_states_shape = hidden_states.sizes(); + torch::ScalarType hidden_states_dtype = hidden_states.dtype().toScalarType(); + torch::Tensor hidden_states_2d = + hidden_states.reshape({-1, hidden_states.size(-1)}); + torch::Tensor router_logits_2d = + router_logits.reshape({-1, router_logits.size(-1)}); + int64_t group_gemm_max_dim = hidden_states_2d.size(0); + int64_t expert_size = w13_.size(0); + + // Step 1-3: select experts + SelectedExpertInfo selected_expert_info; + torch::Tensor expand_hidden_states = + select_experts(hidden_states_2d, router_logits_2d, selected_expert_info); + + // Step 4: group gemm 1 + torch::Tensor gemm1_out = + create_group_gemm_output(expand_hidden_states, + w13_, + selected_expert_info.token_count_slice, + hidden_states_dtype); + // ensure the lifespan of these parameters via brace + { + xllm::kernel::GroupGemmParams group_gemm_params; + group_gemm_params.a = expand_hidden_states; + group_gemm_params.b = w13_; + group_gemm_params.token_count = selected_expert_info.token_count_slice; + if (is_smoothquant_) { + group_gemm_params.a_scale = selected_expert_info.input_scale; + group_gemm_params.b_scale = w13_scale_; + } + group_gemm_params.max_dim = group_gemm_max_dim; + group_gemm_params.trans_a = false; + group_gemm_params.trans_b = true; + group_gemm_params.a_quant_bit = is_smoothquant_ ? 
8 : -1; + group_gemm_params.output = gemm1_out; + gemm1_out = xllm::kernel::group_gemm(group_gemm_params); + } + + // Step 5: activation or scaled quantization(fused with activation) + torch::Tensor act_out; + torch::Tensor act_out_scale; if (is_smoothquant_) { - fused_moe_params.w1_scale = w13_scale_; - fused_moe_params.w2_scale = w2_scale_; - fused_moe_params.input_smooth = input_smooth_; - fused_moe_params.act_smooth = act_smooth_; + int64_t slice_dim = gemm1_out.size(1); + if (is_gated_) slice_dim /= 2; + // slice operation is a view, does not take up extra memory, but points to + // the same memory + act_out = expand_hidden_states.slice(1, 0, slice_dim); + act_out_scale = + selected_expert_info.input_scale.value().slice(0, 0, gemm1_out.size(0)); + // call scaled quantization kernel (also fused with activation) + xllm::kernel::ScaledQuantizeParams scaled_quantize_params; + scaled_quantize_params.x = gemm1_out; + scaled_quantize_params.smooth = act_smooth_; + scaled_quantize_params.token_count = selected_expert_info.token_count_slice; + scaled_quantize_params.output = act_out; + scaled_quantize_params.output_scale = act_out_scale; + scaled_quantize_params.act_mode = hidden_act_; + scaled_quantize_params.active_coef = 1.0; + scaled_quantize_params.is_gated = is_gated_; + scaled_quantize_params.quant_type = torch::kChar; + std::tie(act_out, act_out_scale) = + xllm::kernel::scaled_quantize(scaled_quantize_params); + } else { + act_out = + is_gated_ ? gemm1_out.slice(1, 0, gemm1_out.size(1) / 2) : gemm1_out; + // call activation kernel + xllm::kernel::ActivationParams activation_params; + activation_params.input = gemm1_out; + activation_params.output = act_out; + activation_params.cusum_token_count = + selected_expert_info.cusum_token_count; + activation_params.act_mode = hidden_act_; + activation_params.is_gated = is_gated_; + activation_params.start_expert_id = start_expert_id_; + activation_params.expert_size = expert_size; + xllm::kernel::active(activation_params); + } + + // Step 6: group gemm 2 + torch::Tensor gemm2_out = + create_group_gemm_output(act_out, + w2_, + selected_expert_info.token_count_slice, + hidden_states_dtype); + // ensure the lifespan of these parameters via brace + { + xllm::kernel::GroupGemmParams group_gemm_params; + group_gemm_params.a = act_out; + group_gemm_params.b = w2_; + group_gemm_params.token_count = selected_expert_info.token_count_slice; + if (is_smoothquant_) { + group_gemm_params.a_scale = act_out_scale; + group_gemm_params.b_scale = w2_scale_; + } + group_gemm_params.max_dim = group_gemm_max_dim; + group_gemm_params.trans_a = false; + group_gemm_params.trans_b = true; + group_gemm_params.a_quant_bit = is_smoothquant_ ? 8 : -1; + group_gemm_params.output = gemm2_out; + gemm2_out = xllm::kernel::group_gemm(group_gemm_params); + } + // After group gemm is finished, expand_hidden_states and input_scale are no + // longer needed. We must explicitly release the memory. 
+ expand_hidden_states = torch::Tensor(); + selected_expert_info.input_scale = std::nullopt; + + // Step 7: combine the intermediate results and get the final hidden states + torch::Tensor final_hidden_states; + // ensure the lifespan of these parameters via brace + { + xllm::kernel::MoeCombineResultParams moe_combine_result_params; + moe_combine_result_params.input = gemm2_out; + moe_combine_result_params.reduce_weight = + selected_expert_info.reduce_weight; + moe_combine_result_params.gather_ids = selected_expert_info.combine_idx; + moe_combine_result_params.cusum_token_count = + selected_expert_info.cusum_token_count; + moe_combine_result_params.start_expert_id = start_expert_id_; + moe_combine_result_params.expert_size = expert_size; + moe_combine_result_params.bias = std::nullopt; + // make sure residual fits the requirements of moe_combine_result + if (shared_output.has_value()) { + moe_combine_result_params.residual = + shared_output.value().reshape({-1, shared_output.value().size(-1)}); + } + final_hidden_states = + xllm::kernel::moe_combine_result(moe_combine_result_params); } - auto final_hidden_states = xllm::kernel::fused_moe(fused_moe_params); + // reshape the final hidden states to the original shape + final_hidden_states = final_hidden_states.reshape(hidden_states_shape); if (tp_pg_->world_size() > 1) { final_hidden_states = parallel_state::reduce(final_hidden_states, tp_pg_); diff --git a/xllm/core/layers/common/fused_moe.h b/xllm/core/layers/common/fused_moe.h index f50a6124c..25f217f03 100644 --- a/xllm/core/layers/common/fused_moe.h +++ b/xllm/core/layers/common/fused_moe.h @@ -59,6 +59,21 @@ class FusedMoEImpl : public torch::nn::Module { const ModelInputParams& input_params); void load_state_dict(const StateDict& state_dict); + private: + // struct to store the selected expert info + struct SelectedExpertInfo { + torch::Tensor reduce_weight; + torch::Tensor combine_idx; + torch::Tensor token_count_slice; + torch::Tensor cusum_token_count; + std::optional input_scale; + }; + + // initial steps for MoE computation, select the experts for each token + torch::Tensor select_experts(const torch::Tensor& hidden_states_2d, + const torch::Tensor& router_logits_2d, + SelectedExpertInfo& selected_expert_info); + private: int64_t topk_; int64_t num_expert_group_; diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h index c45788c00..992fc9749 100644 --- a/xllm/core/layers/common/indexer.h +++ b/xllm/core/layers/common/indexer.h @@ -25,7 +25,7 @@ limitations under the License. #include "../mlu/attention.h" #elif defined(USE_CUDA) #include "../cuda/attention.h" -#endif #include "framework/kv_cache/kv_cache.h" +#endif #include "framework/model/model_input_params.h" #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h"