@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
3636 const size_t stride = blockDim .x * gridDim .x ;
3737
3838 for (size_t i = tid; i < numel; i += stride) {
39- int32_t expert_id = topk_ids[i];
39+ int32_t expert_id = topk_ids[i] + 1 ;
4040 int32_t rank_post_pad = atomicAdd (&cumsum_buffer[expert_id], 1 );
4141 sorted_token_ids[rank_post_pad] = i;
4242 }
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
8282 __syncthreads ();
8383
8484 for (size_t i = tid; i < numel; i += stride) {
85- int expert_id = topk_ids[i];
85+ int expert_id = topk_ids[i] + 1 ;
8686 atomicAdd (&shared_counts[expert_id], 1 );
8787 }
8888
@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
215215 right = mid;
216216 }
217217 }
218- expert_ids[i] = left - 1 ;
218+ expert_ids[i] = left - 2 ;
219219 }
220220
221221 if (pad_sorted_token_ids) {
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
251251 }
252252
253253 for (size_t i = tid; i < numel; i += stride) {
254- ++tokens_cnts[(threadIdx .x + 1 ) * num_experts + topk_ids[i]];
254+ ++tokens_cnts[(threadIdx .x + 1 ) * num_experts + topk_ids[i] + 1 ];
255255 }
256256
257257 __syncthreads ();
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
277277
278278 if (threadIdx .x < num_experts) {
279279 for (int i = cumsum[threadIdx .x ]; i < cumsum[threadIdx .x + 1 ]; i += block_size) {
280- expert_ids[i / block_size] = threadIdx .x ;
280+ expert_ids[i / block_size] = threadIdx .x - 1 ;
281281 }
282282 }
283283
@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
294294 __syncthreads ();
295295
296296 for (size_t i = tid; i < numel; i += stride) {
297- int32_t expert_id = topk_ids[i];
297+ int32_t expert_id = topk_ids[i] + 1 ;
298298 int32_t rank_post_pad = tokens_cnts[threadIdx .x * num_experts + expert_id] + cumsum[expert_id];
299299 sorted_token_ids[rank_post_pad] = i;
300300 ++tokens_cnts[threadIdx .x * num_experts + expert_id];
@@ -308,7 +308,6 @@ void moe_align_block_size(
308308 torch::Tensor sorted_token_ids,
309309 torch::Tensor experts_ids,
310310 torch::Tensor num_tokens_post_pad,
311- torch::Tensor token_cnts_buffer,
312311 torch::Tensor cumsum_buffer,
313312 bool pad_sorted_token_ids) {
314313 const cudaStream_t stream = at::cuda::getCurrentCUDAStream ();
0 commit comments