Skip to content

Commit a5f5ab4

Browse files
ch-wan, BBuf, ispobock
authored
update sgl-kernel for EP: kernel part (#8514)
Co-authored-by: Xiaoyu Zhang <[email protected]>
Co-authored-by: Ke Bao <[email protected]>
1 parent 59aab76 commit a5f5ab4

File tree

7 files changed

+12
-32
lines changed

7 files changed

+12
-32
lines changed

sgl-kernel/benchmark/bench_moe_align_block_size.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
164164
num_tokens_post_pad_cuda = torch.empty(
165165
(1), dtype=torch.int32, device=topk_ids.device
166166
)
167-
token_cnts_buffer = torch.zeros(
168-
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
169-
)
170167
cumsum_buffer = torch.zeros(
171168
num_experts + 1, dtype=torch.int32, device=topk_ids.device
172169
)
@@ -189,7 +186,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
189186
sorted_ids_cuda,
190187
expert_ids_cuda,
191188
num_tokens_post_pad_cuda,
192-
token_cnts_buffer,
193189
cumsum_buffer,
194190
)
195191
moe_align_block_size_triton(
@@ -273,11 +269,6 @@ def sgl_moe_align_block_size_with_empty(
273269
if not pad_sorted_token_ids:
274270
sorted_ids.fill_(topk_ids.numel())
275271

276-
token_cnts_buffer = torch.empty(
277-
(num_experts + 1) * num_experts,
278-
dtype=torch.int32,
279-
device=topk_ids.device,
280-
)
281272
cumsum_buffer = torch.empty(
282273
num_experts + 1, dtype=torch.int32, device=topk_ids.device
283274
)
@@ -289,7 +280,6 @@ def sgl_moe_align_block_size_with_empty(
289280
sorted_ids.clone(),
290281
expert_ids.clone(),
291282
num_tokens_post_pad.clone(),
292-
token_cnts_buffer,
293283
cumsum_buffer,
294284
pad_sorted_token_ids,
295285
)

sgl-kernel/csrc/common_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
165165
*/
166166
m.def(
167167
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
168-
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
168+
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
169169
"pad_sorted_token_ids) -> ()");
170170
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
171171

sgl-kernel/csrc/moe/moe_align_kernel.cu

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
3636
const size_t stride = blockDim.x * gridDim.x;
3737

3838
for (size_t i = tid; i < numel; i += stride) {
39-
int32_t expert_id = topk_ids[i];
39+
int32_t expert_id = topk_ids[i] + 1;
4040
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
4141
sorted_token_ids[rank_post_pad] = i;
4242
}
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
8282
__syncthreads();
8383

8484
for (size_t i = tid; i < numel; i += stride) {
85-
int expert_id = topk_ids[i];
85+
int expert_id = topk_ids[i] + 1;
8686
atomicAdd(&shared_counts[expert_id], 1);
8787
}
8888

@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
215215
right = mid;
216216
}
217217
}
218-
expert_ids[i] = left - 1;
218+
expert_ids[i] = left - 2;
219219
}
220220

221221
if (pad_sorted_token_ids) {
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
251251
}
252252

253253
for (size_t i = tid; i < numel; i += stride) {
254-
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
254+
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
255255
}
256256

257257
__syncthreads();
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
277277

278278
if (threadIdx.x < num_experts) {
279279
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
280-
expert_ids[i / block_size] = threadIdx.x;
280+
expert_ids[i / block_size] = threadIdx.x - 1;
281281
}
282282
}
283283

@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
294294
__syncthreads();
295295

296296
for (size_t i = tid; i < numel; i += stride) {
297-
int32_t expert_id = topk_ids[i];
297+
int32_t expert_id = topk_ids[i] + 1;
298298
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
299299
sorted_token_ids[rank_post_pad] = i;
300300
++tokens_cnts[threadIdx.x * num_experts + expert_id];
@@ -308,7 +308,6 @@ void moe_align_block_size(
308308
torch::Tensor sorted_token_ids,
309309
torch::Tensor experts_ids,
310310
torch::Tensor num_tokens_post_pad,
311-
torch::Tensor token_cnts_buffer,
312311
torch::Tensor cumsum_buffer,
313312
bool pad_sorted_token_ids) {
314313
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

sgl-kernel/csrc/torch_extension_rocm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
9292
*/
9393
m.def(
9494
"moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
95-
"experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
95+
"experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
9696
"pad_sorted_token_ids) -> ()");
9797
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
9898

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,6 @@ void moe_align_block_size(
230230
torch::Tensor sorted_token_ids,
231231
torch::Tensor experts_ids,
232232
torch::Tensor num_tokens_post_pad,
233-
torch::Tensor token_cnts_buffer,
234233
torch::Tensor cumsum_buffer,
235234
bool pad_sorted_token_ids);
236235

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def moe_align_block_size(
1010
sorted_token_ids,
1111
experts_ids,
1212
num_tokens_post_pad,
13-
token_cnts_buffer,
1413
cumsum_buffer,
1514
pad_sorted_token_ids=False,
1615
):
@@ -21,7 +20,6 @@ def moe_align_block_size(
2120
sorted_token_ids,
2221
experts_ids,
2322
num_tokens_post_pad,
24-
token_cnts_buffer,
2523
cumsum_buffer,
2624
pad_sorted_token_ids,
2725
)

sgl-kernel/tests/test_moe_align.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_moe_align_block_size_compare_implementations(
157157
:, :topk
158158
]
159159

160-
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
160+
max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
161161

162162
sorted_ids_cuda = torch.empty(
163163
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
@@ -171,13 +171,8 @@ def test_moe_align_block_size_compare_implementations(
171171
num_tokens_post_pad_cuda = torch.empty(
172172
(1), dtype=torch.int32, device=topk_ids.device
173173
)
174-
token_cnts_buffer = torch.empty(
175-
(num_experts + 1) * num_experts,
176-
dtype=torch.int32,
177-
device=topk_ids.device,
178-
)
179174
cumsum_buffer = torch.empty(
180-
num_experts + 1, dtype=torch.int32, device=topk_ids.device
175+
num_experts + 2, dtype=torch.int32, device=topk_ids.device
181176
)
182177

183178
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
@@ -187,19 +182,18 @@ def test_moe_align_block_size_compare_implementations(
187182

188183
moe_align_block_size(
189184
topk_ids,
190-
num_experts,
185+
num_experts + 1,
191186
block_size,
192187
sorted_ids_cuda,
193188
expert_ids_cuda,
194189
num_tokens_post_pad_cuda,
195-
token_cnts_buffer,
196190
cumsum_buffer,
197191
pad_sorted_token_ids,
198192
)
199193

200194
moe_align_block_size_triton(
201195
topk_ids,
202-
num_experts,
196+
num_experts + 1,
203197
block_size,
204198
sorted_ids_triton,
205199
expert_ids_triton,

0 commit comments

Comments (0)