@@ -222,22 +222,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad,
-                          torch::Tensor token_cnts_buffer,
-                          torch::Tensor cumsum_buffer) {
+                          bool use_global_memory) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   // If we have very large number of experts, we can no longer use shared
   // memory.
   // TODO(simon): the right solution should be calculating the exact right
   // amount of shared memory and use that. The num_experts >= 256 is just a
   // temporary solution to unblock Deepseek V3.
-  if (token_cnts_buffer.numel() > 0 && token_cnts_buffer.numel() > 0) {
+  if (use_global_memory) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);

+          auto options_int =
+              torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);
+
           auto kernel =
               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
           kernel<<<1, num_thread, 0, stream>>>(
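For a sense of scale (an illustrative aside, not part of the patch): with the buffer shapes allocated above and 4-byte ints, the num_experts >= 256 case mentioned in the TODO needs roughly 258 KiB of workspace, well beyond the default 48 KiB of static shared memory per block, which is why the counts and cumsum buffers are placed in global memory here. A minimal standalone check, assuming num_experts = 256:

#include <cstdint>
#include <cstdio>

int main() {
  // Same sizing as the torch::empty calls in the diff above.
  const int64_t num_experts = 256;  // expert count at the TODO's threshold
  const int64_t token_cnts_elems = (num_experts + 1) * num_experts;  // 65,792
  const int64_t cumsum_elems = num_experts + 1;                      // 257
  const double workspace_kib =
      (token_cnts_elems + cumsum_elems) * sizeof(int32_t) / 1024.0;
  // Prints ~258.0 KiB, far past the 48 KiB default shared-memory budget.
  std::printf("workspace: %.1f KiB\n", workspace_kib);
  return 0;
}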