From 487acf6bdfc0b06777e7ef3ac69f62e1809147a8 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 5 Feb 2025 14:05:40 -0600 Subject: [PATCH] The code assumes WARP_SIZE to be equal to 32, which is not the case on ROCm Signed-off-by: Gregory Shtrasberg --- csrc/moe/moe_align_sum_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index ff74a42d7e81..01dac4044650 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -207,8 +207,8 @@ __global__ void sgl_moe_align_block_size_kernel( __shared__ int32_t shared_counts[32][8]; __shared__ int32_t local_offsets[256]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; const int experts_per_warp = 8; const int my_expert_start = warp_id * experts_per_warp;