Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sgl-kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);

Expand Down
83 changes: 70 additions & 13 deletions sgl-kernel/csrc/moe/moe_fused_gate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ __device__ void moe_fused_gate_impl(
int64_t num_rows,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor,
Params params) {
int tidx = threadIdx.x;
int64_t thread_row =
Expand Down Expand Up @@ -210,21 +212,38 @@ __device__ void moe_fused_gate_impl(
indices_ptr[idx] = static_cast<int32_t>(expert);
}

// accumulate sum
if (thread_group_idx == 0) {
// accumulate sum for first k-1 elements only
if (thread_group_idx == 0 && k_idx < topk - 1) {
output_sum += output_ptr[idx];
}
}

__syncthreads();
}

if (thread_group_idx == 0) {
int64_t last_idx = topk * thread_row + (topk - 1);

if (n_share_experts_fusion > 0) {
// Use round-robin to select expert
int64_t expert_offset = thread_row % n_share_experts_fusion;
indices_ptr[last_idx] = static_cast<int32_t>(params.NUM_EXPERTS + expert_offset);

// Set the weight to the sum of the first k-1 weights divided by routed_scaling_factor
output_ptr[last_idx] = output_sum / routed_scaling_factor;
} else {
// If not using shared experts, add the last weight to output_sum
output_sum += output_ptr[last_idx];
}
}
__syncthreads();

////////////////////// Rescale Output //////////////////////
if (thread_group_idx == 0) {
#pragma unroll
for (int ii = 0; ii < topk; ++ii) {
int64_t const idx = topk * thread_row + ii;
output_ptr[idx] = static_cast<float>(static_cast<T>(output_ptr[idx]) / static_cast<T>(output_sum));
output_ptr[idx] = output_ptr[idx] / output_sum;
}
}
}
Expand Down Expand Up @@ -257,9 +276,21 @@ __global__ void moe_fused_gate_kernel(
int32_t* indices_ptr,
int64_t num_rows,
int64_t topk_group,
int64_t topk) {
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}

// Macro to compute compile-time constants and launch the kernel.
Expand All @@ -277,7 +308,9 @@ __global__ void moe_fused_gate_kernel(
indices.data_ptr<int32_t>(), \
num_rows, \
topk_group, \
topk); \
topk, \
n_share_experts_fusion, \
routed_scaling_factor); \
dispatched = true; \
} while (0)

Expand All @@ -303,7 +336,9 @@ __global__ void moe_fused_gate_kernel_dynamic(
int64_t num_experts,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk) {
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
KernelParamsDynamic params;
params.NUM_EXPERTS = num_experts; // e.g, for deepseek v3, this is 256
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
Expand All @@ -312,14 +347,30 @@ __global__ void moe_fused_gate_kernel_dynamic(
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;

moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
moe_fused_gate_impl<T>(
input,
bias,
output_ptr,
indices_ptr,
num_rows,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
params);
}

//------------------------------------------------------------------------------
// Host Launcher Function
//------------------------------------------------------------------------------
std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor) {
int64_t num_rows = input.size(0);
int32_t num_experts = input.size(1);
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
Expand Down Expand Up @@ -416,7 +467,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kHalf) {
moe_fused_gate_kernel_dynamic<float16_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
Expand All @@ -427,7 +480,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else if (input.scalar_type() == at::kFloat) {
moe_fused_gate_kernel_dynamic<float32_t><<<num_blocks, block_dim, 0, stream>>>(
input.data_ptr(),
Expand All @@ -438,7 +493,9 @@ moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, in
num_experts,
num_expert_group,
topk_group,
topk);
topk,
n_share_experts_fusion,
routed_scaling_factor);
} else {
TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
}
Expand Down
10 changes: 8 additions & 2 deletions sgl-kernel/include/sgl_kernel_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,14 @@ void topk_softmax(
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

std::vector<at::Tensor>
moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
std::vector<at::Tensor> moe_fused_gate(
at::Tensor& input,
at::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);

/*
* From csrc/speculative
Expand Down
20 changes: 18 additions & 2 deletions sgl-kernel/python/sgl_kernel/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,29 @@ def topk_softmax(
)


def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion=0,
routed_scaling_factor=0,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return torch.ops.sgl_kernel.moe_fused_gate.default(
input_tensor, bias, num_expert_group, topk_group, topk
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
)
36 changes: 31 additions & 5 deletions sgl-kernel/tests/test_moe_fused_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,24 @@
(512, 16, 8, 16),
],
)
def test_moe_fused_gate_combined(seq_length, dtype, params):
@pytest.mark.parametrize("n_share_experts_fusion", [0, 8])
def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
num_experts, num_expert_group, topk_group, topk = params

torch.manual_seed(seq_length)
tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
scores = tensor.clone()
bias = torch.rand(num_experts).to(dtype).cuda()
topk = topk + min(1, n_share_experts_fusion)

output, indices = moe_fused_gate(
tensor,
bias,
num_expert_group=num_expert_group,
topk_group=topk_group,
topk=topk,
n_share_experts_fusion=n_share_experts_fusion,
routed_scaling_factor=2.5,
)
ref_output, ref_indices = biased_grouped_topk(
scores,
Expand All @@ -43,8 +47,30 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
num_expert_group=num_expert_group,
topk_group=topk_group,
compiled=False,
n_share_experts_fusion=n_share_experts_fusion,
)

# When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
if n_share_experts_fusion > 0:
original_indices = indices.clone()
original_ref_indices = ref_indices.clone()

indices = indices[:, :-1]
ref_indices = ref_indices[:, :-1]

valid_min = num_experts
valid_max = num_experts + n_share_experts_fusion
shared_indices = original_indices[:, -1]
shared_ref_indices = original_ref_indices[:, -1]
if shared_indices is not None:
assert torch.all(
(shared_indices >= valid_min) & (shared_indices < valid_max)
), f"Shared expert indices out of range: found values outside [{valid_min}, {valid_max})"
if shared_ref_indices is not None:
assert torch.all(
(shared_ref_indices >= valid_min) & (shared_ref_indices < valid_max)
), f"Shared expert reference indices out of range: found values outside [{valid_min}, {valid_max})"

idx_check = torch.allclose(
ref_indices.sort()[0].to(torch.int32),
indices.sort()[0].to(torch.int32),
Expand All @@ -54,17 +80,17 @@ def test_moe_fused_gate_combined(seq_length, dtype, params):
output_check = torch.allclose(
ref_output.sort()[0].to(torch.float32),
output.sort()[0].to(torch.float32),
rtol=1e-04,
atol=1e-05,
rtol=1e-02,
atol=1e-03,
)

assert idx_check, (
f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
f"params {params}"
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
)
assert output_check, (
f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
f"params {params}"
f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
)


Expand Down
Loading