From 052499c9e850fdaf74ba4af17cab5e5c62b6b949 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 21 Jan 2025 11:16:39 +0800 Subject: [PATCH 1/2] fix moe_align_block_size error condition Signed-off-by: Jinzhen Lin --- csrc/moe/moe_align_sum_kernels.cu | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 715a1b42841f..9b9390385e1e 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -233,15 +233,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, (num_experts + 1) * sizeof(int32_t); bool use_global_memory = false; - bool use_i16 = false; // Use uint16_t for shared memory token counts - if (shared_mem_i16 > device_max_shared_mem) { - use_global_memory = true; - } else if (shared_mem_i32 > device_max_shared_mem && + bool use_i16 = false; // Use uint16_t for shared memory token counts + if (shared_mem_i32 < device_max_shared_mem) { + // use_global_memory = false, use_i16 = false + } else if (shared_mem_i16 < device_max_shared_mem && topk_ids.numel() <= 65535) { // when nelements of topk_ids is smaller than 65535 (max value of uint16), // element value of token_cnts would also smaller than 65535, // so we can use uint16 as dtype of token_cnts use_i16 = true; + } else { + use_global_memory = true; } if (use_global_memory) { From 64d112b98d52135cd42a7690b4ef9a47f30eee61 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Tue, 21 Jan 2025 12:31:51 +0800 Subject: [PATCH 2/2] add comment for clarity Signed-off-by: Jinzhen Lin --- csrc/moe/moe_align_sum_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 9b9390385e1e..d609ce1697df 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -235,7 +235,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, bool use_global_memory = false; bool use_i16 = false; // Use uint16_t for shared memory token counts if (shared_mem_i32 < device_max_shared_mem) { - // use_global_memory = false, use_i16 = false + // Do nothing in this case. We're all set to use int32_t token counts } else if (shared_mem_i16 < device_max_shared_mem && topk_ids.numel() <= 65535) { // when nelements of topk_ids is smaller than 65535 (max value of uint16),