
Commit 83f0926

init buffer in kernel
1 parent f4cac8d commit 83f0926


6 files changed: +16 −32 lines changed


csrc/moe/moe_align_sum_kernels.cu

Lines changed: 9 additions & 3 deletions
@@ -222,22 +222,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad,
-                          torch::Tensor token_cnts_buffer,
-                          torch::Tensor cumsum_buffer) {
+                          bool use_global_memory) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   // If we have very large number of experts, we can no longer use shared
   // memory.
   // TODO(simon): the right solution should be calculating the exact right
   // amount of shared memory and use that. The num_experts >= 256 is just a
   // temporary solution to unblock Deepseek V3.
-  if (token_cnts_buffer.numel() > 0 && token_cnts_buffer.numel() > 0) {
+  if (use_global_memory) {
     VLLM_DISPATCH_INTEGRAL_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
 
+          auto options_int =
+              torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
+          torch::Tensor token_cnts_buffer =
+              torch::empty({(num_experts + 1) * num_experts}, options_int);
+          torch::Tensor cumsum_buffer =
+              torch::empty({num_experts + 1}, options_int);
+
           auto kernel =
               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
           kernel<<<1, num_thread, 0, stream>>>(
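
The commit moves the two scratch allocations from the Python caller into the op itself; the shapes are unchanged. A minimal PyTorch sketch of just those shapes, with an illustrative num_experts value (not vLLM code, only the sizes the torch::empty calls above produce):

import torch

num_experts = 256  # illustrative value in the DeepSeek-V3 range
device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirrors torch::empty({(num_experts + 1) * num_experts}, options_int):
# per-expert token counts used by the global-memory kernel.
token_cnts_buffer = torch.empty((num_experts + 1) * num_experts,
                                dtype=torch.int32, device=device)
# Mirrors torch::empty({num_experts + 1}, options_int): cumulative sum over experts.
cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=device)

print(token_cnts_buffer.numel(), cumsum_buffer.numel())  # 65792 257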

csrc/moe/moe_ops.h

Lines changed: 1 addition & 2 deletions
@@ -12,5 +12,4 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad,
-                          torch::Tensor token_cnts_buffer,
-                          torch::Tensor cumsum_buffer);
+                          bool use_global_memory);

csrc/moe/torch_bindings.cpp

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " int block_size, Tensor! sorted_token_ids,"
       " Tensor! experts_ids,"
       " Tensor! num_tokens_post_pad,"
-      " Tensor! token_cnts_buffer,"
-      " Tensor! cumsum_buffer) -> ()");
+      " bool use_global_memory) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
 #ifndef USE_ROCM

vllm/_custom_ops.py

Lines changed: 2 additions & 8 deletions
@@ -918,17 +918,11 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
                          block_size: int, sorted_token_ids: torch.Tensor,
                          experts_ids: torch.Tensor,
                          num_tokens_post_pad: torch.Tensor,
-                         token_cnts_buffer: Optional[torch.Tensor] = None,
-                         cumsum_buffer: Optional[torch.Tensor] = None,
-                         ) -> None:
-    if token_cnts_buffer is None:
-        token_cnts_buffer = torch.empty((0,), device=topk_ids.device)
-    if cumsum_buffer is None:
-        cumsum_buffer = torch.empty((0,), device=topk_ids.device)
+                         use_global_memory: bool = False) -> None:
     torch.ops._moe_C.moe_align_block_size(topk_ids, num_experts, block_size,
                                           sorted_token_ids, experts_ids,
                                           num_tokens_post_pad,
-                                          token_cnts_buffer, cumsum_buffer)
+                                          use_global_memory)
 
 
 def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
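
For the caller, the only visible change is the boolean flag replacing the two optional buffer arguments. A hedged sketch of a call through the updated wrapper; the output-tensor sizing below is patterned on fused_moe.moe_align_block_size rather than shown in this hunk, so treat the exact shape formulas as assumptions, and note that running it needs a vLLM build with the compiled _moe_C extension:

import torch
from vllm import _custom_ops as ops  # requires the compiled _moe_C extension

num_experts, block_size, top_k = 256, 16, 8
topk_ids = torch.randint(0, num_experts, (4, top_k),
                         dtype=torch.int32, device="cuda")

# Worst case: every expert's segment is padded by up to block_size - 1 slots.
max_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_token_ids = torch.full((max_padded,), topk_ids.numel(),
                              dtype=torch.int32, device="cuda")  # sentinel-filled
experts_ids = torch.empty((max_padded + block_size - 1) // block_size,
                          dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids,
                         experts_ids, num_tokens_post_pad,
                         use_global_memory=num_experts >= 256)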

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -604,6 +604,7 @@ def _verify_cuda_graph(self) -> None:
                                               self.max_model_len)
 
         if (self.hf_config.model_type == 'deepseek_v3'
+                and self.quantization == "fp8"
                 and not self.enforce_eager):
             logger.warning("CUDA graph is not supported for Deepseek V3 yet, "
                            "fallback to the eager mode.")

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 17 deletions
@@ -256,25 +256,10 @@ def moe_align_block_size(
     num_tokens_post_pad = torch.empty((1),
                                       dtype=torch.int32,
                                       device=topk_ids.device)
-    if num_experts >= 256:
-        # For DeepSeek-V3
-        token_cnts_buffer = torch.empty((num_experts + 1) * num_experts,
-                                        dtype=torch.int32,
-                                        device=topk_ids.device)
-        cumsum_buffer = torch.empty(num_experts + 1,
-                                    dtype=torch.int32,
-                                    device=topk_ids.device)
-    else:
-        token_cnts_buffer = torch.empty((0, ),
-                                        dtype=torch.int32,
-                                        device=topk_ids.device)
-        cumsum_buffer = torch.empty((0, ),
-                                    dtype=torch.int32,
-                                    device=topk_ids.device)
-
+    use_global_memory = num_experts >= 256  # for deepseek-v3
     ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                              expert_ids, num_tokens_post_pad,
-                             token_cnts_buffer, cumsum_buffer)
+                             use_global_memory)
     return sorted_ids, expert_ids, num_tokens_post_pad
 

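For context on what the op behind this flag computes, here is a small pure-PyTorch reference of the aligned layout it produces. This is a simplified sketch, not the kernel: the helper name is made up for illustration, it returns only the valid prefix whose length the real op reports via num_tokens_post_pad, and it uses topk_ids.numel() as the padding sentinel, as the caller above does.

import torch

def ref_moe_align(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Group token indices by expert and pad each expert's group up to a
    multiple of block_size; padding slots hold the sentinel topk_ids.numel()."""
    flat = topk_ids.flatten()
    pad_id = flat.numel()
    sorted_ids, expert_ids = [], []
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0].tolist()
        padded = -(-len(idx) // block_size) * block_size  # round up to block_size
        sorted_ids += idx + [pad_id] * (padded - len(idx))
        expert_ids += [e] * (padded // block_size)
    return (torch.tensor(sorted_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            torch.tensor([len(sorted_ids)], dtype=torch.int32))

topk_ids = torch.tensor([[0, 2], [1, 2], [2, 0]])  # 3 tokens, top-2 routing
print(ref_moe_align(topk_ids, num_experts=4, block_size=4))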