92 changes: 92 additions & 0 deletions sgl-kernel/benchmark/bench_moe_ep_post_reorder.py
@@ -0,0 +1,92 @@
import torch
import triton
from sgl_kernel import ep_moe_post_reorder

from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel

batch_sizes = [64, 128, 256, 512, 640, 768, 1024, 2048, 4096]
configs = [(bs,) for bs in batch_sizes]


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["cuda", "triton"],
line_names=["CUDA Kernel", "Triton Kernel"],
styles=[("green", "-"), ("orange", "-")],
ylabel="us",
plot_name="ep-moe-post-reorder-performance",
args={},
)
)
def benchmark(batch_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
hidden_size, topk, start_expert_id, end_expert_id, block_size = 4096, 8, 0, 255, 512

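    # Tensors are allocated inside each timed closure below, so the reported
    # numbers include allocation overhead equally for both providers.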
def alloc_tensors():
down_output = torch.randn(
batch_size * topk, hidden_size, dtype=dtype, device=device
)
output = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
src2dst = torch.randint(
0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
)
topk_ids = torch.randint(
start_expert_id,
end_expert_id + 1,
(batch_size, topk),
dtype=torch.int32,
device=device,
)
topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)
return down_output, output, src2dst, topk_ids, topk_weights

quantiles = [0.5, 0.2, 0.8]

if provider == "cuda":

def run_cuda():
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
ep_moe_post_reorder(
d_out,
out,
s2d,
tk_ids,
tk_weights,
start_expert_id,
end_expert_id,
topk,
)

ms, min_ms, max_ms = triton.testing.do_bench(run_cuda, quantiles=quantiles)

elif provider == "triton":

def run_triton():
d_out, out, s2d, tk_ids, tk_weights = alloc_tensors()
post_reorder_triton_kernel[(batch_size,)](
d_out.view(-1),
out.view(-1),
s2d.view(-1),
tk_ids.view(-1),
tk_weights.view(-1),
start_expert_id,
end_expert_id,
topk,
hidden_size,
block_size,
)

ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)

else:
raise ValueError(f"Unknown provider: {provider}")

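    # do_bench reports milliseconds; scale to microseconds to match the ylabel.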
return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
benchmark.run(print_data=True)
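
A minimal parity sketch (not part of this diff) for sanity-checking the new CUDA op against the existing Triton kernel; it reuses the calling conventions from the benchmark above, and the tolerances are illustrative values chosen for bf16 outputs with fp32 accumulation:

```python
import torch
from sgl_kernel import ep_moe_post_reorder

from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel


def check_parity(batch_size=64, hidden_size=4096, topk=8,
                 start_expert_id=0, end_expert_id=255, block_size=512):
    device, dtype = torch.device("cuda"), torch.bfloat16
    down_output = torch.randn(
        batch_size * topk, hidden_size, dtype=dtype, device=device
    )
    src2dst = torch.randint(
        0, batch_size * topk, (batch_size, topk), dtype=torch.int32, device=device
    )
    topk_ids = torch.randint(
        start_expert_id, end_expert_id + 1, (batch_size, topk),
        dtype=torch.int32, device=device,
    )
    topk_weights = torch.rand(batch_size, topk, dtype=dtype, device=device)

    out_cuda = torch.zeros(batch_size, hidden_size, dtype=dtype, device=device)
    out_triton = torch.zeros_like(out_cuda)

    ep_moe_post_reorder(
        down_output, out_cuda, src2dst, topk_ids, topk_weights,
        start_expert_id, end_expert_id, topk,
    )
    post_reorder_triton_kernel[(batch_size,)](
        down_output.view(-1), out_triton.view(-1), src2dst.view(-1),
        topk_ids.view(-1), topk_weights.view(-1),
        start_expert_id, end_expert_id, topk, hidden_size, block_size,
    )
    # Loose tolerances: both kernels accumulate in fp32 but store bf16.
    torch.testing.assert_close(out_cuda, out_triton, rtol=2e-2, atol=2e-2)
```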
8 changes: 6 additions & 2 deletions sgl-kernel/csrc/common_extension.cc
@@ -165,9 +165,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
  m.def(
-      "ep_moe_pre_reorder(Tensor input_ptr, Tensor gateup_input_ptr, Tensor src2dst_ptr, Tensor topk_ids_ptr, Tensor "
-      "a1_scales_ptr, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
+      "ep_moe_pre_reorder(Tensor input, Tensor gateup_input, Tensor src2dst, Tensor topk_ids, Tensor "
+      "a1_scales, int start_expert_id, int end_expert_id, int topk, bool use_per_token_if_dynamic) -> ()");
  m.impl("ep_moe_pre_reorder", torch::kCUDA, &ep_moe_pre_reorder);
+  m.def(
+      "ep_moe_post_reorder(Tensor down_output, Tensor output, Tensor src2dst, Tensor topk_ids, Tensor "
+      "topk_weights, int start_expert_id, int end_expert_id, int topk) -> ()");
+  m.impl("ep_moe_post_reorder", torch::kCUDA, &ep_moe_post_reorder);
m.def(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
85 changes: 83 additions & 2 deletions sgl-kernel/csrc/moe/ep_moe_reorder_kernel.cu
@@ -67,6 +67,57 @@ __global__ void ep_pre_reorder_cuda_kernel(
}
}

template <typename scalar_t>
__global__ void ep_post_reorder_cuda_kernel(
const scalar_t* __restrict__ down_output_ptr,
scalar_t* __restrict__ output_ptr,
const int* __restrict__ src2dst_ptr,
const int* __restrict__ topk_ids_ptr,
const scalar_t* __restrict__ topk_weights_ptr,
int start_expert_id,
int end_expert_id,
int topk,
int hidden_size) {
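  // One thread block per destination token: threads stride over the hidden
  // dimension and accumulate the token's in-range topk expert rows in fp32.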
const int token_idx = blockIdx.x;
const int tid = threadIdx.x;

const int* token_src2dst = src2dst_ptr + token_idx * topk;
const int* token_topk_ids = topk_ids_ptr + token_idx * topk;
const scalar_t* token_topk_weights = topk_weights_ptr + token_idx * topk;

scalar_t* dst_ptr = output_ptr + static_cast<int64_t>(token_idx) * hidden_size;

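  // 16-byte vectorized loads/stores (e.g. 8 lanes for bf16); assumes
  // hidden_size is a multiple of vec_size.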
constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

const int vec_iters = hidden_size / vec_size;
for (int idx = tid; idx < vec_iters; idx += blockDim.x) {
float acc[vec_size] = {0};

for (int k = 0; k < topk; ++k) {
const int expert_id = token_topk_ids[k];
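      // Skip experts outside this rank's [start_expert_id, end_expert_id] range.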
if (expert_id < start_expert_id || expert_id > end_expert_id) continue;
const int src_row = token_src2dst[k];
const scalar_t* src_ptr = down_output_ptr + static_cast<int64_t>(src_row) * hidden_size;
const float weight = static_cast<float>(token_topk_weights[k]);

vec_t src_vec;
src_vec.cast_load(src_ptr + idx * vec_size);

#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
acc[i] += static_cast<float>(src_vec[i]) * weight;
}
}
vec_t out_vec;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i)
out_vec[i] = static_cast<scalar_t>(acc[i]);

out_vec.cast_store(dst_ptr + idx * vec_size);
}
}

void ep_moe_pre_reorder(
torch::Tensor input,
torch::Tensor gateup_input,
@@ -77,8 +77,8 @@ void ep_moe_pre_reorder(
int64_t end_expert_id,
int64_t topk,
bool use_per_token_if_dynamic) {
-  int total_blocks = input.size(0);
-  int block_size = 512;
+  const int total_blocks = input.size(0);
+  const int block_size = 512;
dim3 grid(total_blocks);
dim3 block(block_size);
int hidden_size = input.size(1);
@@ -98,3 +98,33 @@
return true;
});
}

void ep_moe_post_reorder(
torch::Tensor down_output,
torch::Tensor output,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor topk_weights,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk) {
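  // Launch one 512-thread block per output token; the kernel strides over
  // hidden_size internally.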
const int total_tokens = output.size(0);
const int block_size = 512;
dim3 grid(total_tokens);
dim3 block(block_size);
const int hidden_size = output.size(1);

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(down_output.scalar_type(), scalar_t, [&] {
ep_post_reorder_cuda_kernel<scalar_t><<<grid, block>>>(
static_cast<scalar_t*>(down_output.data_ptr()),
static_cast<scalar_t*>(output.data_ptr()),
src2dst.data_ptr<int>(),
topk_ids.data_ptr<int>(),
static_cast<scalar_t*>(topk_weights.data_ptr()),
static_cast<int>(start_expert_id),
static_cast<int>(end_expert_id),
static_cast<int>(topk),
hidden_size);
return true;
});
}
10 changes: 10 additions & 0 deletions sgl-kernel/include/sgl_kernel_ops.h
@@ -252,6 +252,16 @@ void ep_moe_pre_reorder(
int64_t topk,
bool use_per_token_if_dynamic);

void ep_moe_post_reorder(
torch::Tensor down_output,
torch::Tensor output,
torch::Tensor src2dst,
torch::Tensor topk_ids,
torch::Tensor topk_weights,
int64_t start_expert_id,
int64_t end_expert_id,
int64_t topk);

void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor);

void cutlass_fp4_group_mm(
1 change: 1 addition & 0 deletions sgl-kernel/python/sgl_kernel/__init__.py
@@ -49,6 +49,7 @@
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.moe import (
cutlass_fp4_group_mm,
ep_moe_post_reorder,
ep_moe_pre_reorder,
fp8_blockwise_scaled_grouped_mm,
moe_align_block_size,
22 changes: 22 additions & 0 deletions sgl-kernel/python/sgl_kernel/moe.py
@@ -88,6 +88,28 @@ def ep_moe_pre_reorder(
)


def ep_moe_post_reorder(
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
topk,
):
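    """Scatter-reduce expert outputs back to original token order.

    Expected shapes (as exercised by bench_moe_ep_post_reorder.py):
    down_output [num_tokens * topk, hidden_size], output [num_tokens,
    hidden_size], and src2dst / topk_ids / topk_weights [num_tokens, topk].
    """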
return torch.ops.sgl_kernel.ep_moe_post_reorder.default(
down_output,
output,
src2dst,
topk_ids,
topk_weights,
start_expert_id,
end_expert_id,
topk,
)


def fp8_blockwise_scaled_grouped_mm(
output,
a_ptrs,