
Commit 77d1210: fix moe_align_block_size (#2615)
Parent: 70dc2fb
File tree: 4 files changed, +24 -18 lines

sgl-kernel/pyproject.toml (1 addition, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post9"
+version = "0.0.2.post10"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"

sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu (4 additions, 16 deletions)

@@ -118,31 +118,19 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
 }
 
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size,
-                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids,
-                          torch::Tensor num_tokens_post_pad) {
+                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
+                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
     // tensors
     const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
 
-    const int32_t mem_tokens_cnts = ((num_experts + 1) * num_experts) * sizeof(int32_t);
-    const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
-
-    // allocate global memory
-    int32_t* tokens_cnts;
-    int32_t* cumsum;
-    cudaMalloc(&tokens_cnts, mem_tokens_cnts);
-    cudaMalloc(&cumsum, mem_cumsum);
-
-    // set dynamic shared mem
     auto kernel = moe_align_block_size_kernel<scalar_t>;
     kernel<<<1, num_thread, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                          experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
-                                         num_experts, block_size, topk_ids.numel(), tokens_cnts, cumsum);
-
-    cudaFree(tokens_cnts);
-    cudaFree(cumsum);
+                                         num_experts, block_size, topk_ids.numel(),
+                                         token_cnts_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>());
  });
}
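
What the fix does: the old wrapper called cudaMalloc and cudaFree on every invocation to stage the tokens_cnts and cumsum scratch arrays. Those calls are synchronous with respect to the host (cudaFree in particular can synchronize the device), which likely motivated moving ownership of the two scratch tensors to the caller, who can allocate them once and reuse them. The required sizes follow directly from the removed cudaMalloc calls. A minimal caller-side sketch in Python; the helper name make_moe_align_buffers is illustrative, not part of the API:

    import torch

    def make_moe_align_buffers(num_experts: int, device: str = "cuda"):
        # Sizes mirror the removed cudaMalloc calls:
        #   tokens_cnts: (num_experts + 1) * num_experts int32 values
        #   cumsum:      (num_experts + 1) int32 values
        token_cnts_buffer = torch.empty(
            (num_experts + 1) * num_experts, dtype=torch.int32, device=device
        )
        cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=device)
        return token_cnts_buffer, cumsum_buffer

For num_experts = 256 that is 257 * 256 = 65,792 int32 values (about 257 KiB) for token_cnts_buffer plus 257 values for cumsum_buffer, a small footprint for keeping allocation off the per-call path.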

sgl-kernel/src/sgl-kernel/ops/__init__.py (4 additions, 0 deletions)

@@ -8,6 +8,8 @@ def moe_align_block_size(
     sorted_token_ids,
     experts_ids,
     num_tokens_post_pad,
+    token_cnts_buffer,
+    cumsum_buffer,
 ):
     _moe_align_block_size(
         topk_ids,
@@ -16,4 +18,6 @@ def moe_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
     )
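
The Python wrapper just threads the two new buffers through to the kernel, and their shapes depend only on num_experts, so a long-lived caller can allocate once and reuse across calls. A sketch of that reuse pattern, building on the hypothetical make_moe_align_buffers helper above:

    _moe_align_buffers = {}

    def get_moe_align_buffers(num_experts: int, device: str = "cuda"):
        # Cache the scratch tensors per (num_experts, device) so repeated
        # moe_align_block_size calls reuse the same allocations.
        key = (num_experts, device)
        if key not in _moe_align_buffers:
            _moe_align_buffers[key] = make_moe_align_buffers(num_experts, device)
        return _moe_align_buffers[key]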

sgl-kernel/tests/test_moe_align.py (15 additions, 1 deletion)

@@ -18,8 +18,22 @@ def test_moe_align_block_size():
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
 
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum_buffer = torch.empty(
+        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+    )
+
     moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
     )
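
Note that the test fills the scratch tensors with torch.empty rather than torch.zeros. That is consistent with the kernel initializing the counters itself before use: the pre-fix code likewise handed the kernel freshly cudaMalloc'ed, uninitialized memory. A caller who wants a defensive default can zero-fill before the call:

    token_cnts_buffer.zero_()  # optional; the kernel is expected to initialize these
    cumsum_buffer.zero_()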
