@@ -24,12 +24,24 @@ limitations under the License.
#define WARP_SIZE 32

+template <typename scalar_t>
+__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+                                      int32_t* cumsum_buffer, size_t numel) {
+  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t stride = blockDim.x * gridDim.x;
+
+  for (size_t i = tid; i < numel; i += stride) {
+    int32_t expert_id = topk_ids[i];
+    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
+    sorted_token_ids[rank_post_pad] = i;
+  }
+}
+
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
                                            int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
                                            int32_t block_size, size_t numel, int32_t* cumsum) {
  __shared__ int32_t shared_counts[WARP_SIZE][8];
-  __shared__ int32_t local_offsets[256];

  const int warp_id = threadIdx.x / WARP_SIZE;
  const int experts_per_warp = 8;
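The added moe_token_sort_kernel is a grid-stride counting-sort scatter: each token index is appended to its expert's contiguous segment of sorted_token_ids. A serial sketch of what it computes, assuming cumsum_buffer[e] has been pre-filled by the align kernel with expert e's padded start offset (illustration only, not part of the patch):

#include <cstddef>
#include <cstdint>

// Serial reference (sketch) for what moe_token_sort_kernel computes.
// Assumes cumsum_buffer[e] holds expert e's padded start offset on entry.
void serial_token_sort(const int32_t* topk_ids, int32_t* sorted_token_ids,
                       int32_t* cumsum_buffer, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    int32_t expert_id = topk_ids[i];
    sorted_token_ids[cumsum_buffer[expert_id]++] = i;  // next free slot for this expert
  }
}

On the GPU the atomicAdd makes the order of tokens within an expert's segment nondeterministic across threads, which is acceptable as long as consumers only need tokens grouped by expert.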
@@ -72,20 +84,6 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
-    local_offsets[threadIdx.x] = cumsum[threadIdx.x];
-  }
-
-  __syncthreads();
-
-  // Note: for the moe_align_kernel, the primary bottleneck lies in the atomic add and the non-coalesced
-  // memory writes here. If these operations could be performed across multiple blocks, as in the Triton
-  // version, this kernel could achieve state-of-the-art performance for all token counts. However, once
-  // multiple blocks are used, illegal memory access occurs. Even replacing these lines with the stage-4
-  // kernel from the Triton version results in the same issue, and a correct solution has not yet been found.
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
  }
}
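The deleted note above records why the scatter could not simply run on multiple blocks in place: local_offsets lives in __shared__ memory, which is private to each block, and the scatter must also wait for the completed cumsum, which needs a grid-wide barrier. A distilled, hypothetical sketch of the failure mode (not code from this repository):

#include <cstdint>

// Hypothetical: __shared__ arrays are per-block, so every block other
// than the one that computed the offsets reads uninitialized values
// and scatters out of bounds.
__global__ void broken_multiblock_scatter(const int32_t* topk_ids, int32_t* sorted_token_ids, size_t numel) {
  __shared__ int32_t local_offsets[256];  // fresh, uninitialized copy in each block
  const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel) {
    int32_t slot = atomicAdd(&local_offsets[topk_ids[i]], 1);  // garbage in blocks > 0
    sorted_token_ids[slot] = i;  // potential illegal memory access
  }
}

The patch resolves both constraints at once: the new moe_token_sort_kernel keeps its offsets in the global cumsum_buffer, and the kernel-launch boundary between the align kernel and the sort kernel serves as the grid-wide barrier.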
@@ -100,5 +98,15 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
    align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
                                         experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
                                         num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+
+    const int block_threads = 256;
+    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+    const int max_blocks = 65535;
+    const int actual_blocks = std::min(num_blocks, max_blocks);
+
+    auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
+                                                             sorted_token_ids.data_ptr<int32_t>(),
+                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
  });
}
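The grid is capped at 65,535 blocks, and the grid-stride loop in moe_token_sort_kernel absorbs any excess, so coverage does not depend on the cap (gridDim.x can be far larger on modern GPUs; the cap is a conservative choice). A minimal host-side sketch of the launch arithmetic, with a hypothetical token count:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t numel = 1u << 26;  // hypothetical: ~67M topk entries
  const int block_threads = 256;
  const size_t num_blocks = (numel + block_threads - 1) / block_threads;  // 262144
  const size_t actual_blocks = std::min<size_t>(num_blocks, 65535);       // capped at 65535
  // Each thread strides by actual_blocks * block_threads per iteration,
  // so the range [0, numel) is still covered exactly once.
  std::printf("blocks=%zu stride=%zu\n", actual_blocks, actual_blocks * (size_t)block_threads);
  return 0;
}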