From 1ae3d450d1543eddb6f56944b49ee6107fa4beca Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 4 Jun 2025 23:16:31 +0000 Subject: [PATCH 01/17] Add chunking logic to modular triton kernel Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 23 ++- .../layers/fused_moe/fused_moe.py | 148 +++++++++++------- 2 files changed, 110 insertions(+), 61 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 7238813a299d..8b68975130f5 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -15,7 +15,8 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_moe from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( @@ -103,7 +104,27 @@ def test_fused_moe( expert_map=e_map, renormalize=False) + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + + m_triton_output = m_fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(m_triton_output, + torch_output, + atol=2e-2, + rtol=0) torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index ba1498e65319..dcf50ac8a3cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1600,12 +1600,18 @@ def apply( if global_num_experts == -1: global_num_experts = E + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, dtype=hidden_states.dtype) - config = try_get_optimal_moe_config( + get_config_func = functools.partial( + try_get_optimal_moe_config, w1.shape, w2.shape, top_k_num, @@ -1614,6 +1620,8 @@ def apply( block_shape=self.block_shape, ) + config = get_config_func(M) + if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1628,67 +1636,87 @@ def apply( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, - (num_tokens, top_k_num, N)) + intermediate_cache1 = _resize_cache(workspace13, (M, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace2, - (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, - (num_tokens, top_k_num, K)) - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) - - invoke_fused_moe_kernel(hidden_states, - w1, - intermediate_cache1, - a1q_scale, - w1_scale, - w1_zp, - None, - 
sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) - - self.activation(activation, intermediate_cache2, - intermediate_cache1.view(-1, N)) - - a2q_scale: Optional[torch.Tensor] = None - - qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, - self.block_shape) + (M * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, (M, top_k_num, K)) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * + topk_ids.shape[1]] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) + + invoke_fused_moe_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) + + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, + self.per_channel_quant, self.block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) return intermediate_cache3 From 193a9cab70e7cc02db52c27a6274d9f4d9e2a80e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 02:51:02 +0000 Subject: [PATCH 02/17] fix Signed-off-by: Bill Nell --- tests/kernels/moe/test_moe.py | 22 ++--- tests/kernels/quantization/test_block_fp8.py | 27 +++++- .../layers/fused_moe/fused_moe.py | 93 +++++++++++++------ 3 files changed, 100 insertions(+), 42 deletions(-) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 8b68975130f5..bed374cf4d56 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -77,6 +77,13 @@ def test_fused_moe( else: e_map = None + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None) + with set_current_vllm_config(vllm_config): torch_output = torch_moe(a, w1, w2, score, topk, e_map) iterative_output = iterative_moe(a, @@ -104,21 +111,14 @@ def test_fused_moe( expert_map=e_map, renormalize=False) - m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None) - + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) m_triton_output = m_fused_moe(a, w1, w2, - score, - topk, + topk_weights, + topk_ids, global_num_experts=e, - expert_map=e_map, - renormalize=False) + expert_map=e_map) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(m_triton_output, diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 8c5ee98743d7..eec59573792d 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -13,7 +13,8 @@ from vllm.model_executor.layers.fused_moe import fused_moe from 
vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( _valid_deep_gemm_shape, deep_gemm_moe_fp8) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, modular_triton_fused_moe) from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -45,7 +46,7 @@ K = [256, 3884, 4096, 13824, 16384] # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. -M_moe = [1, 2, 7, 83, 128, 2048] +M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] M_moe_dg = [128, 192, 1335, 2048] N_moe = [128, 256, 1024, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] @@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) + m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=block_size) + # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): out = fused_moe( @@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) + m_out = m_fused_moe(a, + w1, + w2, + topk_weights, + topk_ids, + global_num_experts=E, + w1_scale=w1_s, + w2_scale=w2_s) + #print(f"{out.sum()=}") #print(f"{ref_out.sum()=}") @@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): torch.mean(torch.abs(ref_out.to(torch.float32)))) assert rel_diff < 0.03 + rel_diff = (torch.mean( + torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) / + torch.mean(torch.abs(ref_out.to(torch.float32)))) + assert rel_diff < 0.03 + def per_block_cast_to_fp8( x: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index dcf50ac8a3cc..642d770bde48 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1517,6 +1517,16 @@ def fused_moe( block_shape=block_shape) +def _chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -1553,8 +1563,15 @@ def workspace_shapes( num_experts: int, ) -> tuple[int, int, torch.dtype]: factor = num_experts if a.dim() == 3 else 1 - workspace1 = M * topk * max(N * 2, K) * factor - workspace2 = M * topk * N * factor + + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + if M <= CHUNK_SIZE: + workspace1 = M * topk * max(N * 2, K) * factor + else: + workspace1 = (M + CHUNK_SIZE) * topk * max(N * 2, K) * factor + + workspace2 = min(M, CHUNK_SIZE) * topk * N * factor + return (workspace1, workspace2, a.dtype) def apply( @@ -1616,7 +1633,6 @@ def apply( w2.shape, top_k_num, config_dtype, - num_tokens, block_shape=self.block_shape, ) @@ -1636,16 +1652,34 @@ def apply( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, (M, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace2, 
(M * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, (M, top_k_num, K)) - for chunk in range((num_tokens // CHUNK_SIZE) + 1): + num_chunks = num_tokens // CHUNK_SIZE + + if num_chunks <= 1: + intermediate_cache1 = _resize_cache(workspace13, (M, top_k_num, N)) + intermediate_cache3 = _resize_cache(workspace13, + (num_tokens, top_k_num, K)) + else: + ws_numel = workspace13.numel() + result_numel = num_tokens * top_k_num * K + ws1 = workspace13[-(ws_numel - result_numel):] + ws3 = workspace13[:result_numel] + intermediate_cache1 = _resize_cache(ws1, + (CHUNK_SIZE, top_k_num, N)) + intermediate_cache3 = _resize_cache(ws3, + (num_tokens, top_k_num, K)) + + for chunk in range(num_chunks + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens)) curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] tokens_in_chunk, _ = curr_hidden_states.shape @@ -1660,7 +1694,7 @@ def apply( intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]] - intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + #intermdiate_cache_3 = intermediate_cache3[:tokens_in_chunk] config = get_config_func(tokens_in_chunk) sorted_token_ids, expert_ids, num_tokens_post_padded = ( @@ -1670,7 +1704,7 @@ def apply( invoke_fused_moe_kernel(curr_hidden_states, w1, intermediate_cache1, - a1q_scale, + curr_a1q_scale, w1_scale, w1_zp, None, @@ -1694,29 +1728,30 @@ def apply( a2q_scale: Optional[torch.Tensor] = None qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, a2_scale, self.qtype, + intermediate_cache2, curr_a2_scale, self.qtype, self.per_channel_quant, self.block_shape) - invoke_fused_moe_kernel(qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - w2_scale, - w2_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) + invoke_fused_moe_kernel( + qintermediate_cache2, + w2, + intermediate_cache3[begin_chunk_idx:end_chunk_idx], + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) return intermediate_cache3 From 24a795ef6172316ad5becf6e5233f5f71fe3ecf3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 19:15:39 +0000 Subject: [PATCH 03/17] more general chunking Signed-off-by: Bill Nell --- .../layers/fused_moe/batched_deep_gemm_moe.py | 12 +- .../batched_triton_or_deep_gemm_moe.py | 11 +- .../layers/fused_moe/cutlass_moe.py | 8 +- .../layers/fused_moe/deep_gemm_moe.py | 15 +- .../layers/fused_moe/fused_batched_moe.py | 24 ++- .../layers/fused_moe/fused_moe.py | 194 ++++++----------- .../layers/fused_moe/modular_kernel.py | 196 ++++++++++-------- .../layers/fused_moe/triton_deep_gemm_moe.py | 12 +- 8 
files changed, 239 insertions(+), 233 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 76d71ca08856..8db924c4ea74 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -36,6 +36,9 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, assert (len(self.block_shape) == 2 and all( [v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape])) + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -45,14 +48,15 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) - workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + output = (num_experts, max_num_tokens * num_dp, K) + return (workspace13, workspace2, K, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index d62d519af8d7..ceb4a02ff98f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -64,6 +64,15 @@ def __init__(self, block_shape=self.block_shape, # type: ignore[arg-type] ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None + assert (self.batched_deep_gemm_experts is not None or + self.batched_triton_experts is not None) + + def supports_chunking(self) -> bool: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_chunking()) and + (bte is None or bte.supports_chunking())) + def workspace_shapes( self, a: torch.Tensor, @@ -73,7 +82,7 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
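For context, the hunks in this commit and the following ones all serve one pattern: the modular kernel splits the token dimension into VLLM_FUSED_MOE_CHUNK_SIZE slices and runs the experts chunk by chunk, slicing any per-token scales alongside the activations and writing each chunk into a preallocated output (see the modular_kernel.py hunks in this and the later commits). A minimal standalone sketch of that pattern, assuming a hypothetical apply_experts callable in place of the real expert kernel:

from typing import Callable, Optional

import torch


def chunked_moe(hidden_states: torch.Tensor,
                topk_ids: torch.Tensor,
                a_scale: Optional[torch.Tensor],
                chunk_size: int,
                apply_experts: Callable) -> torch.Tensor:
    # hidden_states: (M, K), topk_ids: (M, topk)
    M, K = hidden_states.shape
    topk = topk_ids.shape[1]
    # Preallocate the full unpermuted output; each chunk writes its own slice.
    out = torch.empty((M, topk, K),
                      device=hidden_states.device,
                      dtype=hidden_states.dtype)
    num_chunks = (M + chunk_size - 1) // chunk_size  # cdiv(M, chunk_size)
    for c in range(num_chunks):
        s, e = c * chunk_size, min((c + 1) * chunk_size, M)
        # Per-tensor scales (numel == 1) are shared by every chunk;
        # per-token scales are sliced together with the activations.
        if a_scale is not None and a_scale.numel() > 1:
            curr_scale = a_scale[s:e]
        else:
            curr_scale = a_scale
        out[s:e] = apply_experts(hidden_states[s:e], topk_ids[s:e], curr_scale)
    return out

The patches generalize this by also re-deriving the kernel config and workspace shapes from the chunk size rather than the full token count, so the intermediate caches only ever need to hold one chunk.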
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 6e7b1a4f2b6c..bb52046da79f 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -213,6 +213,9 @@ def __init__( self.per_act_token = per_act_token self.per_out_ch = per_out_ch + def supports_chunking(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -222,11 +225,12 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: padded_M = aq.shape[1] workspace1 = self.max_experts_per_worker * padded_M * max(N, K) workspace2 = self.max_experts_per_worker * padded_M * (N // 2) - return (workspace1, workspace2, self.out_dtype) + output = (padded_M, topk, K) + return (workspace1, workspace2, output, self.out_dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 436c632be9c4..084b4fb01e20 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -70,6 +70,10 @@ def __init__(self): super().__init__() self.block_shape = deep_gemm_block_shape() + def supports_chunking(self) -> bool: + # TODO: for now + return False + def workspace_shapes( self, a: torch.Tensor, @@ -79,15 +83,14 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: - + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) - workspace1 = M_sum * max(N * 2, K) - workspace2 = M_sum * max(N, K) - - return (workspace1, workspace2, a.dtype) + workspace1 = (M_sum, max(N * 2, K)) + workspace2 = (M_sum, max(N, K)) + output = (M_sum, K) + return (workspace1, workspace2, output, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 68a3485ff1f6..546d441f76a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -518,6 +518,9 @@ def __init__( self.world_size = world_size self.dp_size = dp_size + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -527,15 +530,14 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - #print(f"WORKSPACE {max_num_tokens} {num_dp}") - workspace13 = num_experts * max_num_tokens * num_dp * K - workspace2 = max_num_tokens * num_dp * N - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, max_num_tokens * num_dp, K) + workspace2 = (max_num_tokens * num_dp, N) + return (workspace13, workspace2, workspace13, a.dtype) def apply( self, @@ -630,6 +632,9 @@ def __init__( assert not use_int4_w4a16, "NYI" assert self.block_shape is None, "NYI" + def supports_chunking(self) -> bool: + return False + def workspace_shapes( self, a: torch.Tensor, @@ -639,14 +644,15 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> 
tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = num_experts * max_num_tokens * num_dp * max(K, N) - workspace2 = num_experts * max_num_tokens * num_dp * (N // 2) - return (workspace13, workspace2, a.dtype) + workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + output = (num_experts, max_num_tokens * num_dp, K) + return (workspace13, workspace2, output, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 642d770bde48..70d47e0c8eb1 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1517,16 +1517,6 @@ def fused_moe( block_shape=block_shape) -def _chunk_scales(scales: Optional[torch.Tensor], start: int, - end: int) -> Optional[torch.Tensor]: - if scales is not None: - if scales.numel() == 1: - return scales - else: - return scales[start:end] - return None - - class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -1552,6 +1542,9 @@ def __init__( use_int4_w4a16=use_int4_w4a16) self.per_channel_quant = per_channel_quant + def supports_chunking(self) -> bool: + return True + def workspace_shapes( self, a: torch.Tensor, @@ -1561,18 +1554,11 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: - factor = num_experts if a.dim() == 3 else 1 - - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - if M <= CHUNK_SIZE: - workspace1 = M * topk * max(N * 2, K) * factor - else: - workspace1 = (M + CHUNK_SIZE) * topk * max(N * 2, K) * factor - - workspace2 = min(M, CHUNK_SIZE) * topk * N * factor - - return (workspace1, workspace2, a.dtype) + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1 = (M, topk, max(N * 2, K)) + workspace2 = (M, topk, N) + output = (M, topk, K) + return (workspace1, workspace2, output, a.dtype) def apply( self, @@ -1617,27 +1603,20 @@ def apply( if global_num_experts == -1: global_num_experts = E - # We execute the fused_moe kernel in chunks to circumvent this issue: - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - M = min(num_tokens, CHUNK_SIZE) - config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, dtype=hidden_states.dtype) - get_config_func = functools.partial( - try_get_optimal_moe_config, + config = try_get_optimal_moe_config( w1.shape, w2.shape, top_k_num, config_dtype, + num_tokens, block_shape=self.block_shape, ) - config = get_config_func(M) - if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 elif hidden_states.dtype == torch.float16: @@ -1652,106 +1631,67 @@ def apply( # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 + intermediate_cache1 = _resize_cache(workspace13, + (num_tokens, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace2, - (M * top_k_num, N // 2)) + (num_tokens * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (num_tokens, top_k_num, K)) - num_chunks = num_tokens // CHUNK_SIZE + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + 
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], + global_num_experts, expert_map)) - if num_chunks <= 1: - intermediate_cache1 = _resize_cache(workspace13, (M, top_k_num, N)) - intermediate_cache3 = _resize_cache(workspace13, - (num_tokens, top_k_num, K)) - else: - ws_numel = workspace13.numel() - result_numel = num_tokens * top_k_num * K - ws1 = workspace13[-(ws_numel - result_numel):] - ws3 = workspace13[:result_numel] - intermediate_cache1 = _resize_cache(ws1, - (CHUNK_SIZE, top_k_num, N)) - intermediate_cache3 = _resize_cache(ws3, - (num_tokens, top_k_num, K)) - - for chunk in range(num_chunks + 1): - begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, - num_tokens)) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, - end_chunk_idx) - curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk, _ = curr_hidden_states.shape - - if tokens_in_chunk == 0: - break + invoke_fused_moe_kernel(hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + self.activation(activation, intermediate_cache2, + intermediate_cache1.view(-1, N)) - if tokens_in_chunk < CHUNK_SIZE and chunk > 0: - # Adjust the intermediate cache size and config for the last - # chunk. Note that in most cases we only have one chunk - # so the cache size and config are already set correctly and - # do not need to be adjusted. 
- intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] - intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * - topk_ids.shape[1]] - #intermdiate_cache_3 = intermediate_cache3[:tokens_in_chunk] - config = get_config_func(tokens_in_chunk) - - sorted_token_ids, expert_ids, num_tokens_post_padded = ( - moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], - global_num_experts, expert_map)) - - invoke_fused_moe_kernel(curr_hidden_states, - w1, - intermediate_cache1, - curr_a1q_scale, - w1_scale, - w1_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) - - self.activation(activation, intermediate_cache2, - intermediate_cache1.view(-1, N)) - - a2q_scale: Optional[torch.Tensor] = None - - qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, curr_a2_scale, self.qtype, - self.per_channel_quant, self.block_shape) - - invoke_fused_moe_kernel( - qintermediate_cache2, - w2, - intermediate_cache3[begin_chunk_idx:end_chunk_idx], - a2q_scale, - w2_scale, - w2_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_channel_quant, - block_shape=self.block_shape) + a2q_scale: Optional[torch.Tensor] = None + + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant, + self.block_shape) + + invoke_fused_moe_kernel(qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + w2_scale, + w2_zp, + None, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index e7aaf62fb340..3c6414432f16 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -2,9 +2,13 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod from typing import Optional +from math import prod import torch +import vllm.envs as envs +from vllm.model_executor.layers.fused_moe.utils import _resize_cache + # # This file defines a set of base classes used to make MoE kernels more modular. # The goal is to be able to utilize different communication mechanisms with @@ -115,9 +119,9 @@ def prepare( - quantized + dispatched a. - quantized + dispatched a1_scales. - Optional tensor as big as number of local experts that contains the - number of tokens assigned to each local expert. + number of tokens assigned to each local expert. 
- Optional dispatched expert topk IDs - - Optional dispatched expert topk weight + - Optional dispatched expert topk weight """ raise NotImplementedError @@ -159,7 +163,7 @@ def max_num_tokens_per_rank(self) -> Optional[int]: Some PrepareFinalize All2All implementations are batched. Meaning, they can processes only as set of tokens at a time. This function returns the batch size i.e the maximum number of tokens - the implementation can process at a time. + the implementation can process at a time. Return None if there are no such restrictions. """ raise NotImplementedError @@ -171,6 +175,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): above. """ + # TODO (bnell): make this return a CHUNK_SIZE or None instead? + @abstractmethod + def supports_chunking(self) -> bool: + """ + A flag indicating whether or not this class supports activation + chunking. + """ + raise NotImplementedError + @abstractmethod def workspace_shapes( self, @@ -181,19 +194,22 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) -> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: """ - Compute the number of elements for the temporary outputs of the two - gemms and activation in the fused expert function. Since the - gemms are independent, the workspace for the first gemm can be shared - with the workspace for the last gemm. + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. Returns a tuple of: - - Number of workspace13 elements: must be large enough to hold the + - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - - Number of workspace2 elements: must be large enough to hold the + - workspace2 shape tuple: must be large enough to hold the result of the activation function. + - output shape tuple: must be exact size of the final gemm output. - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. 
""" raise NotImplementedError @@ -266,6 +282,16 @@ def apply( raise NotImplementedError +def _chunk_scales(scales: Optional[torch.Tensor], start: int, + end: int) -> Optional[torch.Tensor]: + if scales is not None: + if scales.numel() == 1: + return scales + else: + return scales[start:end] + return None + + class FusedMoEModularKernel(torch.nn.Module): """ This class combines a FusedMoEPrepareAndFinalize instance and @@ -288,61 +314,6 @@ def __init__( self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts - def _do_fused_experts( - self, - a1: torch.Tensor, # input to forward fn - a1q: torch.Tensor, # output of prepare fn - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - expert_num_tokens: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor]) -> torch.Tensor: - - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) - - # Use a1 here to decipher the correct workspace datatype - workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k, - global_num_experts)) - - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.zeros(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.zeros(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) - - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) - - return fused_out - def forward( self, hidden_states: torch.Tensor, @@ -408,12 +379,14 @@ def forward( _expert_topk_weights) = self.prepare_finalize.prepare( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_weights = (topk_weights if _expert_topk_weights is None else _expert_topk_weights) fused_out = None + if a1q.numel() == 0: # This happens when none of the tokens from the all2all reach this # EP rank. Also, note that this is only relevant for CUDAGraph @@ -423,22 +396,83 @@ def forward( # and can never run into the tensor.numel() == 0 case. 
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) else: - fused_out = self._do_fused_experts( - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_ids=topk_ids, - expert_num_tokens=expert_num_tokens, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale) + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + if self.fused_experts.supports_chunking(): + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_chunks = (M // CHUNK_SIZE) + 1 + else: + num_chunks = 1 + + if num_chunks == 1: + workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes( + a1, M, N, K, top_k, + global_num_experts) + ) + else: + # Use the full M to get the final output shape. + _, _, fused_out_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes( + a1, M, N, K, top_k, + global_num_experts) + ) + # Use the CHUNK_SIZE to get the workspace shapes. + workspace13_shape, workspace2_shape, _, workspace_dtype = ( + self.fused_experts.workspace_shapes( + a1, CHUNK_SIZE, N, K, top_k, + global_num_experts) + ) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.zeros(prod(workspace13_shape), + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.zeros(prod(workspace2_shape), + device=a1.device, + dtype=workspace_dtype) + + if num_chunks == 1: + fused_out = _resize_cache(workspace13, fused_out_shape) + else: + fused_out = torch.empty(fused_out_shape, device=a1q.device, + dtype=workspace_dtype) + + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk = end_chunk_idx - begin_chunk_idx + + #print(f"CHUNK {chunk}/{num_chunks}: {a1q.shape}/{fused_out_shape} {begin_chunk_idx}:{end_chunk_idx}") + + if tokens_in_chunk == 0: + break + + fused_out[begin_chunk_idx:end_chunk_idx] = self.fused_experts.apply( + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 87de29444c01..6935827eedaf 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -34,6 +34,12 @@ def __init__(self, self.deep_gemm_expert = DeepGemmExperts( ) if self.allow_deep_gemm else None + def supports_chunking(self) -> bool: + dge = self.deep_gemm_expert + te = self.triton_expert + return ((dge is None or dge.supports_chunking()) and + (te is None or te.supports_chunking())) + def workspace_shapes( self, a: torch.Tensor, @@ -43,7 +49,7 @@ def workspace_shapes( K: int, topk: int, num_experts: int, - ) 
-> tuple[int, int, torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. @@ -52,8 +58,8 @@ def workspace_shapes( return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, - num_experts) + return self.triton_expert.workspace_shapes( + a, aq, M, N, K, topk, num_experts) def apply( self, From f1540e4e4ee9f5d2101fcf727b2086fef4640a3a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 19:27:22 +0000 Subject: [PATCH 04/17] lint Signed-off-by: Bill Nell --- .../layers/fused_moe/batched_deep_gemm_moe.py | 4 +- .../batched_triton_or_deep_gemm_moe.py | 10 +-- .../layers/fused_moe/modular_kernel.py | 64 +++++++++---------- .../layers/fused_moe/triton_deep_gemm_moe.py | 6 +- 4 files changed, 39 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 8db924c4ea74..c1b9bed46474 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -54,9 +54,9 @@ def workspace_shapes( max_num_tokens = a.size( 0) if self.max_num_tokens is None else self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) - workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) + workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) - return (workspace13, workspace2, K, a.dtype) + return (workspace13, workspace2, output, a.dtype) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index ceb4a02ff98f..8ceba310fab7 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -64,14 +64,14 @@ def __init__(self, block_shape=self.block_shape, # type: ignore[arg-type] ) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None - assert (self.batched_deep_gemm_experts is not None or - self.batched_triton_experts is not None) + assert (self.batched_deep_gemm_experts is not None + or self.batched_triton_experts is not None) def supports_chunking(self) -> bool: bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - return ((bdge is None or bdge.supports_chunking()) and - (bte is None or bte.supports_chunking())) + bte = self.batched_triton_experts + return ((bdge is None or bdge.supports_chunking()) + and (bte is None or bte.supports_chunking())) def workspace_shapes( self, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3c6414432f16..394610826fab 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Optional from math import prod +from typing import Optional import torch @@ -405,27 +405,21 @@ def forward( num_chunks = 
1 if num_chunks == 1: - workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes( - a1, M, N, K, top_k, - global_num_experts) - ) + (workspace13_shape, workspace2_shape, fused_out_shape, + workspace_dtype) = self.fused_experts.workspace_shapes( + a1, M, N, K, top_k, global_num_experts) else: # Use the full M to get the final output shape. _, _, fused_out_shape, workspace_dtype = ( self.fused_experts.workspace_shapes( - a1, M, N, K, top_k, - global_num_experts) - ) + a1, M, N, K, top_k, global_num_experts)) # Use the CHUNK_SIZE to get the workspace shapes. workspace13_shape, workspace2_shape, _, workspace_dtype = ( self.fused_experts.workspace_shapes( - a1, CHUNK_SIZE, N, K, top_k, - global_num_experts) - ) + a1, CHUNK_SIZE, N, K, top_k, global_num_experts)) - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. workspace13 = torch.zeros(prod(workspace13_shape), device=a1.device, dtype=workspace_dtype) @@ -436,7 +430,8 @@ def forward( if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) else: - fused_out = torch.empty(fused_out_shape, device=a1q.device, + fused_out = torch.empty(fused_out_shape, + device=a1q.device, dtype=workspace_dtype) for chunk in range(num_chunks): @@ -450,29 +445,28 @@ def forward( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] tokens_in_chunk = end_chunk_idx - begin_chunk_idx - #print(f"CHUNK {chunk}/{num_chunks}: {a1q.shape}/{fused_out_shape} {begin_chunk_idx}:{end_chunk_idx}") - if tokens_in_chunk == 0: break - fused_out[begin_chunk_idx:end_chunk_idx] = self.fused_experts.apply( - curr_a1q, - w1, - w2, - curr_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + fused_out[ + begin_chunk_idx:end_chunk_idx] = self.fused_experts.apply( + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 6935827eedaf..6e717f277a91 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -36,9 +36,9 @@ def __init__(self, def supports_chunking(self) -> bool: dge = self.deep_gemm_expert - te = self.triton_expert - return ((dge is None or dge.supports_chunking()) and - (te is None or te.supports_chunking())) + te = self.triton_expert + return ((dge is None or dge.supports_chunking()) + and (te is None or te.supports_chunking())) def workspace_shapes( self, From f7aa67cf39e404e44870f1d8f00131aa13b61fd1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 20:00:40 +0000 Subject: [PATCH 05/17] lint Signed-off-by: Bill Nell --- 
.../layers/fused_moe/modular_kernel.py | 82 +++++++++++-------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 394610826fab..8e3e4fb0f27f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -402,6 +402,7 @@ def forward( CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE num_chunks = (M // CHUNK_SIZE) + 1 else: + CHUNK_SIZE = M num_chunks = 1 if num_chunks == 1: @@ -429,44 +430,59 @@ def forward( if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) + + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) else: fused_out = torch.empty(fused_out_shape, device=a1q.device, dtype=workspace_dtype) - for chunk in range(num_chunks): - begin_chunk_idx = chunk * CHUNK_SIZE - end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) - curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] - curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, - end_chunk_idx) - curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - tokens_in_chunk = end_chunk_idx - begin_chunk_idx - - if tokens_in_chunk == 0: - break - - fused_out[ - begin_chunk_idx:end_chunk_idx] = self.fused_experts.apply( - curr_a1q, - w1, - w2, - curr_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + + fused_out[begin_chunk_idx: + end_chunk_idx] = self.fused_experts.apply( + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) From ed37b5bb43e27c5020064d2c362d33522161d9c1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 21:41:32 +0000 Subject: [PATCH 06/17] support activation chunking for cutlass + deep gemm kernels Signed-off-by: Bill Nell --- .../layers/fused_moe/cutlass_moe.py | 16 ++++ .../layers/fused_moe/deep_gemm_moe.py | 5 +- .../layers/fused_moe/modular_kernel.py | 93 +++++++++---------- 3 files changed, 62 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 
bb52046da79f..38b3f613dc62 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -216,6 +216,22 @@ def __init__( def supports_chunking(self) -> bool: return True + def workspace_shapes( + self, + a: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # Note that K, N are transposed + N, K = K, N + workspace1 = (M, topk, max(2 * N, K)) + workspace2 = (M, topk, N) + output = (M * topk, K) + return (workspace1, workspace2, output, self.out_dtype) + def workspace_shapes( self, a: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 084b4fb01e20..e9a442040e4a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -71,8 +71,7 @@ def __init__(self): self.block_shape = deep_gemm_block_shape() def supports_chunking(self) -> bool: - # TODO: for now - return False + return True def workspace_shapes( self, @@ -89,7 +88,7 @@ def workspace_shapes( M_sum = round_up(M_sum, block_m) workspace1 = (M_sum, max(N * 2, K)) workspace2 = (M_sum, max(N, K)) - output = (M_sum, K) + output = (M * topk, K) return (workspace1, workspace2, output, a.dtype) def apply( diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 8e3e4fb0f27f..306ce8e3608a 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -8,6 +8,7 @@ import vllm.envs as envs from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.utils import cdiv # # This file defines a set of base classes used to make MoE kernels more modular. @@ -400,7 +401,7 @@ def forward( if self.fused_experts.supports_chunking(): CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - num_chunks = (M // CHUNK_SIZE) + 1 + num_chunks = cdiv(M, CHUNK_SIZE) else: CHUNK_SIZE = M num_chunks = 1 @@ -428,61 +429,55 @@ def forward( device=a1.device, dtype=workspace_dtype) + # The leading output dimension may not be equal to M, so + # we compute output indices separately. 
+ M_out = fused_out_shape[0] + assert M_out >= M + factor = M_out // M + assert factor > 0 + OUT_CHUNK_SIZE = CHUNK_SIZE * factor + + assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( + f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") + if num_chunks == 1: fused_out = _resize_cache(workspace13, fused_out_shape) - - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) else: fused_out = torch.empty(fused_out_shape, device=a1q.device, dtype=workspace_dtype) - for chunk in range(num_chunks): - begin_chunk_idx = chunk * CHUNK_SIZE - end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) - curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] - curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, - end_chunk_idx) - curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - - fused_out[begin_chunk_idx: - end_chunk_idx] = self.fused_experts.apply( - curr_a1q, - w1, - w2, - curr_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + begin_out_idx = chunk * OUT_CHUNK_SIZE + end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + + fused_out[begin_out_idx:end_out_idx] = ( + self.fused_experts.apply( + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + )) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) From 398e8d40396ef15231f76d005413e2d02da041ac Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 21:49:11 +0000 Subject: [PATCH 07/17] add chunking sized tests for cutlass + deep gemm Signed-off-by: Bill Nell --- tests/kernels/moe/test_cutlass_moe.py | 1 + tests/kernels/quantization/test_block_fp8.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 474745f94815..ce420901e317 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -29,6 +29,7 @@ (224, 1024, 1536), (224, 3072, 1024), (224, 3072, 1536), + (1024 * 128, 1024, 1024), ] vllm_config = VllmConfig(parallel_config=ParallelConfig( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792d..fa3c4e80967f 100644 --- a/tests/kernels/quantization/test_block_fp8.py 
+++ b/tests/kernels/quantization/test_block_fp8.py @@ -47,7 +47,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] -M_moe_dg = [128, 192, 1335, 2048] +M_moe_dg = [128, 192, 1335, 2048, 1024 * 128] N_moe = [128, 256, 1024, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] From d7c5c7aa46b964d4009bef6afeb4d4d5a163141b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 5 Jun 2025 21:58:47 +0000 Subject: [PATCH 08/17] revert deep gemm chunking test since it triggers cuda error in quantization Signed-off-by: Bill Nell --- tests/kernels/quantization/test_block_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index fa3c4e80967f..eec59573792d 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -47,7 +47,7 @@ # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8 # and its hidden size is 7168. M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128] -M_moe_dg = [128, 192, 1335, 2048, 1024 * 128] +M_moe_dg = [128, 192, 1335, 2048] N_moe = [128, 256, 1024, 4608] # [13824] K_moe = [256, 512, 7168] # [13824] BLOCK_SIZE = [[128, 128]] From b73a409317537003981bf08c0114d5881964664e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 8 Jun 2025 00:37:23 +0000 Subject: [PATCH 09/17] review comments Signed-off-by: Bill Nell --- .../layers/fused_moe/batched_deep_gemm_moe.py | 8 +-- .../batched_triton_or_deep_gemm_moe.py | 11 +-- .../layers/fused_moe/deep_gemm_moe.py | 7 +- .../layers/fused_moe/fused_batched_moe.py | 17 ++--- .../layers/fused_moe/fused_moe.py | 9 +-- .../layers/fused_moe/modular_kernel.py | 69 ++++++++++--------- .../layers/fused_moe/triton_deep_gemm_moe.py | 68 ++++++++---------- 7 files changed, 83 insertions(+), 106 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c1b9bed46474..30b74165657e 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -60,6 +60,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -76,7 +77,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): import deep_gemm as dg assert hidden_states.ndim == 3 @@ -93,7 +94,6 @@ def apply( workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K)) # (from deepgemm docs) : A value hint (which is a value on CPU) # for the M expectation of each batch, correctly setting this value @@ -122,8 +122,6 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), (w2, w2_scale), - out=workspace3, + out=output, masked_m=expert_num_tokens, expected_m=expected_m) - - return workspace3 diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 8ceba310fab7..e58cee888deb 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -96,6 +96,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -112,7 +113,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): use_batched_deep_gemm_experts = (self.allow_deep_gemm and self.batched_deep_gemm_experts is not None) @@ -120,7 +121,7 @@ def apply( if use_batched_deep_gemm_experts else self.batched_triton_experts) assert experts is not None - return experts.apply(hidden_states, w1, w2, topk_ids, activation, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, - workspace13, workspace2, expert_num_tokens) + experts.apply(output, hidden_states, w1, w2, topk_ids, activation, + global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, + workspace13, workspace2, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index e9a442040e4a..f83d9593c138 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -93,6 +93,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -109,7 +110,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): import deep_gemm as dg a1q = hidden_states @@ -161,9 +162,7 @@ def apply( dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) - torch.index_select(mm2_out, 0, inv_perm, out=out) - - return out + torch.index_select(mm2_out, 0, inv_perm, out=output) def deep_gemm_moe_fp8( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 546d441f76a0..4488c85fc9fb 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -541,6 +541,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -557,7 +558,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): assert hidden_states.dim() == 3 assert expert_num_tokens is not None hidden_dim = hidden_states.size(-1) @@ -569,8 +570,6 @@ def apply( num_dp = self.world_size // self.dp_size num_experts = global_num_experts - out = _resize_cache(workspace13, - (num_experts, max_num_tokens * num_dp, hidden_dim)) num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") @@ -593,9 +592,7 @@ def apply( tmp = _resize_cache(workspace2, (num, N)) input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1) self.activation(activation, tmp, input) - out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) - - return out + output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1) class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -656,6 +653,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -672,7 +670,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: 
Optional[torch.Tensor], - ) -> torch.Tensor: + ): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -729,8 +727,6 @@ def apply( (E, max_num_tokens, N)) intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, - (E, max_num_tokens, K)) # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, @@ -767,7 +763,7 @@ def apply( invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, - C=intermediate_cache3, + C=output, expert_num_tokens=expert_num_tokens, compute_type=compute_type, A_scale=a2q_scale, @@ -778,4 +774,3 @@ def apply( use_int4_w4a16=self.use_int4_w4a16, config=config, block_shape=self.block_shape) - return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 70d47e0c8eb1..d9b1ba132671 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1562,6 +1562,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -1578,7 +1579,7 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): # Check constraints. if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -1635,8 +1636,6 @@ def apply( (num_tokens, top_k_num, N)) intermediate_cache2 = _resize_cache(workspace2, (num_tokens * top_k_num, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, - (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], @@ -1674,7 +1673,7 @@ def apply( invoke_fused_moe_kernel(qintermediate_cache2, w2, - intermediate_cache3, + output, a2q_scale, w2_scale, w2_zp, @@ -1693,8 +1692,6 @@ def apply( per_channel_quant=self.per_channel_quant, block_shape=self.block_shape) - return intermediate_cache3 - def modular_triton_fused_moe( use_fp8_w8a8: bool, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 306ce8e3608a..08c2409cb44f 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -227,6 +227,7 @@ def activation(self, activation: str, output: torch.Tensor, @abstractmethod def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -243,12 +244,13 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. Parameters: + - output: (torch.Tensor): The unweighted, unreduced output tensor. - hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. @@ -276,9 +278,6 @@ def apply( function. - expert_num_tokens: An optional tensor containing the number of tokens assigned to each expert when using batched experts format input. - - Returns: - - torch.Tensor: The unweighted, unreduced output tensor """ raise NotImplementedError @@ -412,7 +411,7 @@ def forward( a1, M, N, K, top_k, global_num_experts) else: # Use the full M to get the final output shape. 
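A rough sketch (not part of the patch) of the two-call pattern this hunk sets up: the full token count M determines the shape of the final unreduced output, while CHUNK_SIZE determines how large the reusable scratch buffers must be. Names mirror the surrounding modular_kernel.forward code; `experts` stands in for self.fused_experts.

    # Full M: only the output shape is taken from this call.
    _, _, fused_out_shape, _ = experts.workspace_shapes(
        a1, M, N, K, top_k, global_num_experts)
    # One chunk: workspace sizes and dtype are taken from this call.
    workspace13_shape, workspace2_shape, _, workspace_dtype = (
        experts.workspace_shapes(a1, CHUNK_SIZE, N, K, top_k,
                                 global_num_experts))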
- _, _, fused_out_shape, workspace_dtype = ( + _, _, fused_out_shape, _ = ( self.fused_experts.workspace_shapes( a1, M, N, K, top_k, global_num_experts)) # Use the CHUNK_SIZE to get the workspace shapes. @@ -429,24 +428,26 @@ def forward( device=a1.device, dtype=workspace_dtype) - # The leading output dimension may not be equal to M, so - # we compute output indices separately. M_out = fused_out_shape[0] - assert M_out >= M - factor = M_out // M - assert factor > 0 - OUT_CHUNK_SIZE = CHUNK_SIZE * factor - - assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( - f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") if num_chunks == 1: + OUT_CHUNK_SIZE = M_out fused_out = _resize_cache(workspace13, fused_out_shape) else: + # The leading output dimension may not be equal to M, so + # we compute output indices separately. + assert M_out >= M + factor = M_out // M + assert factor > 0 + OUT_CHUNK_SIZE = CHUNK_SIZE * factor + fused_out = torch.empty(fused_out_shape, device=a1q.device, dtype=workspace_dtype) + assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( + f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") + for chunk in range(num_chunks): begin_chunk_idx = chunk * CHUNK_SIZE end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) @@ -456,28 +457,28 @@ def forward( curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, end_chunk_idx) curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) + end_chunk_idx) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - fused_out[begin_out_idx:end_out_idx] = ( - self.fused_experts.apply( - curr_a1q, - w1, - w2, - curr_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - )) + self.fused_experts.apply( + fused_out[begin_out_idx:end_out_idx], + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 6e717f277a91..ae454dc30772 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -63,6 +63,7 @@ def workspace_shapes( def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -79,45 +80,30 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): N = w1.size(1) - if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)): - assert self.deep_gemm_expert is not None - return self.deep_gemm_expert.apply( - hidden_states, - w1, - w2, - topk_ids, - activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_num_tokens, - ) - else: - return self.triton_expert.apply( - hidden_states, - w1, - w2, - topk_ids, 
- activation, - global_num_experts, - expert_map, - w1_scale, - w2_scale, - w1_zp, - w2_zp, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_num_tokens, - ) + + use_deep_gemm =(self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)) + + experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert + + experts.apply( + output, + hidden_states, + w1, + w2, + topk_ids, + activation, + global_num_experts, + expert_map, + w1_scale, + w2_scale, + w1_zp, + w2_zp, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_num_tokens, + ) From 906534bb83b7d56bb9c7e451f31ce78784701fa1 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 8 Jun 2025 02:35:52 +0000 Subject: [PATCH 10/17] fix merge Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_cutlass_moe.py | 5 +- .../layers/fused_moe/cutlass_moe.py | 82 +++++++++---------- .../layers/fused_moe/modular_kernel.py | 8 +- .../compressed_tensors_moe.py | 7 +- 4 files changed, 52 insertions(+), 50 deletions(-) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index ef3e6adcfa36..38973e5dc5d8 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -113,7 +113,10 @@ def pplx_cutlass_moe( ) experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, - out_dtype, per_act_token, per_out_ch) + out_dtype, + per_act_token, + per_out_ch, + use_batched_format=True) fused_cutlass_experts = FusedMoEModularKernel( prepare_finalize, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 38b3f613dc62..2ed77d4f6b06 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -14,6 +14,7 @@ def run_cutlass_moe_fp8( + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -31,6 +32,7 @@ def run_cutlass_moe_fp8( out_dtype: torch.dtype, per_act_token: bool, per_out_ch: bool, + use_batched_format: bool, ) -> torch.Tensor: a1q = hidden_states @@ -61,23 +63,20 @@ def run_cutlass_moe_fp8( if expert_map is not None: assert expert_num_tokens is None - # We have two modes: PPLX and non-PPLX. We differentiate them by checking - # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX - # uses to track the number of tokens per expert). - # In the non-PPLX mode, the input tokens are not padded: thus, the shape + # We have two modes: batched experts and non-batched experts. + # In the non-batched mode, the input tokens are not padded: thus, the shape # of the input is [total_num_tokens, hidden_size]. The input and output # require shuffling by a_map and c_map such that the tokens assigned to # each expert are contiguous. - # In the PPLX mode, the input tokens are padded per expert to ensure that - # the PPLX dispatch and combine functions work correctly: thus, the shape + # In the batched mode, the input tokens are padded per expert to ensure that + # the batched dispatch and combine functions work correctly: thus, the shape # of the input is [num_experts, max_num_tokens_per_expert, hidden_size]. - # The PPLX input and output require no shuffling by a_map and c_map since + # The batched input and output require no shuffling by a_map and c_map since # their tokens are already contiguous for each expert as a result of # the dispatch function. 
- is_pplx = expert_num_tokens is not None - M = a1q.shape[0] # no pplx - padded_M = a1q.shape[1] # pplx + M = a1q.shape[0] # non batched expert M + padded_M = a1q.shape[1] # batched expert M _, K, N = w2.shape device = a1q.device @@ -95,7 +94,7 @@ def run_cutlass_moe_fp8( topk = local_topk_ids.shape[1] local_E = w1.shape[0] - if is_pplx: + if use_batched_format: expert_offsets = torch.empty((local_E), dtype=torch.int32, device=device) @@ -167,7 +166,7 @@ def run_cutlass_moe_fp8( device=device, dtype=torch.int64) - if is_pplx: + if use_batched_format: c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) c2 = _resize_cache(workspace2, (local_E * padded_M, N)) c3 = _resize_cache(workspace13, (local_E * padded_M, K)) @@ -192,12 +191,15 @@ def run_cutlass_moe_fp8( problem_sizes2, ab_strides2, ab_strides2, c_strides2, per_act_token, per_out_ch) - if is_pplx: - return c3.reshape(local_E, padded_M, K) + if use_batched_format: + output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True) else: - return c3[c_map].view(M, topk, K) + # We can't do this inplace because output may point to the same tensor + # as c3. + output.copy_(c3[c_map].view(M * topk, K), non_blocking=True) +# TODO (bnell): split class batched vs. non-batched? class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): def __init__( @@ -206,31 +208,17 @@ def __init__( out_dtype: torch.dtype, per_act_token: bool, per_out_ch: bool, + use_batched_format: bool = False, ): super().__init__() self.max_experts_per_worker = max_experts_per_worker self.out_dtype = out_dtype self.per_act_token = per_act_token self.per_out_ch = per_out_ch + self.use_batched_format = use_batched_format def supports_chunking(self) -> bool: - return True - - def workspace_shapes( - self, - a: torch.Tensor, - M: int, - N: int, - K: int, - topk: int, - num_experts: int, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - # Note that K, N are transposed - N, K = K, N - workspace1 = (M, topk, max(2 * N, K)) - workspace2 = (M, topk, N) - output = (M * topk, K) - return (workspace1, workspace2, output, self.out_dtype) + return not self.use_batched_format def workspace_shapes( self, @@ -242,14 +230,20 @@ def workspace_shapes( topk: int, num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - padded_M = aq.shape[1] - workspace1 = self.max_experts_per_worker * padded_M * max(N, K) - workspace2 = self.max_experts_per_worker * padded_M * (N // 2) - output = (padded_M, topk, K) + if self.use_batched_format: + padded_M = aq.shape[1] + workspace1 = self.max_experts_per_worker * padded_M * max(N, K) + workspace2 = self.max_experts_per_worker * padded_M * (N // 2) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M * topk, K) return (workspace1, workspace2, output, self.out_dtype) def apply( self, + output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -266,16 +260,17 @@ def apply( workspace13: torch.Tensor, workspace2: torch.Tensor, expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: + ): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" activation_callable = lambda i, o: self.activation(activation, i, o) - return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids, - activation_callable, global_num_experts, - expert_map, w1_scale, w2_scale, a1q_scale, - a2_scale, workspace13, 
workspace2, - expert_num_tokens, self.out_dtype, - self.per_act_token, self.per_out_ch) + run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids, + activation_callable, global_num_experts, + expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, + expert_num_tokens, self.out_dtype, + self.per_act_token, self.per_out_ch, + self.use_batched_format) def cutlass_moe_fp8( @@ -345,6 +340,7 @@ def cutlass_moe_fp8( out_dtype=out_dtype, per_act_token=per_act_token, per_out_ch=per_out_ch, + use_batched_format=False, ), ) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 08c2409cb44f..9758ad12280d 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -408,16 +408,16 @@ def forward( if num_chunks == 1: (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = self.fused_experts.workspace_shapes( - a1, M, N, K, top_k, global_num_experts) + a1, a1q, M, N, K, top_k, global_num_experts) else: # Use the full M to get the final output shape. _, _, fused_out_shape, _ = ( self.fused_experts.workspace_shapes( - a1, M, N, K, top_k, global_num_experts)) + a1, a1q, M, N, K, top_k, global_num_experts)) # Use the CHUNK_SIZE to get the workspace shapes. workspace13_shape, workspace2_shape, _, workspace_dtype = ( self.fused_experts.workspace_shapes( - a1, CHUNK_SIZE, N, K, top_k, global_num_experts)) + a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts)) # We can reuse the memory between cache1 and cache3 because by the # time we need cache3, we're done with cache1. @@ -457,7 +457,7 @@ def forward( curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, end_chunk_idx) curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) + end_chunk_idx) curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] self.fused_experts.apply( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bc9d399cf135..f14131c5f05b 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -562,9 +562,12 @@ def select_gemm_impl(self, prepare_finalize, moe): (moe.num_experts + prepare_finalize.world_size - 1) // prepare_finalize.world_size) experts = CutlassExpertsFp8( - max_experts_per_worker, moe.in_dtype, + max_experts_per_worker, + moe.in_dtype, self.input_quant.strategy == QuantizationStrategy.TOKEN, - self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + self.weight_quant.strategy == QuantizationStrategy.CHANNEL, + use_batched_format=True, + ) if has_pplx and isinstance( prepare_finalize, From 8a25be33973111087304071462768870dcd2502e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sun, 8 Jun 2025 22:34:13 +0000 Subject: [PATCH 11/17] lint Signed-off-by: Bill Nell --- .../layers/fused_moe/batched_triton_or_deep_gemm_moe.py | 6 +++--- vllm/model_executor/layers/fused_moe/fused_batched_moe.py | 2 -- .../layers/fused_moe/triton_deep_gemm_moe.py | 8 ++++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index e58cee888deb..d0ce59ba1e62 100644 --- 
a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -122,6 +122,6 @@ def apply( self.batched_triton_experts) assert experts is not None experts.apply(output, hidden_states, w1, w2, topk_ids, activation, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, - workspace13, workspace2, expert_num_tokens) + global_num_experts, expert_map, w1_scale, w2_scale, + w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, + workspace2, expert_num_tokens) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 4488c85fc9fb..c8184969282e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -561,7 +561,6 @@ def apply( ): assert hidden_states.dim() == 3 assert expert_num_tokens is not None - hidden_dim = hidden_states.size(-1) if self.max_num_tokens is None: max_num_tokens = hidden_states.size(1) @@ -569,7 +568,6 @@ def apply( max_num_tokens = self.max_num_tokens num_dp = self.world_size // self.dp_size - num_experts = global_num_experts num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( f"{num_local_experts} == {w1.size(0)}") diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index ae454dc30772..2b19bdab4e32 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -58,8 +58,8 @@ def workspace_shapes( return self.deep_gemm_expert.workspace_shapes( a, aq, M, N, K, topk, num_experts) else: - return self.triton_expert.workspace_shapes( - a, aq, M, N, K, topk, num_experts) + return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, + num_experts) def apply( self, @@ -83,8 +83,8 @@ def apply( ): N = w1.size(1) - use_deep_gemm =(self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)) + use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert From 32704dc8e44be57037611c9d009c304d9b5c8973 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 9 Jun 2025 19:40:41 +0000 Subject: [PATCH 12/17] lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 3 ++- vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 2ed77d4f6b06..eff76bfb57ba 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -234,7 +234,8 @@ def workspace_shapes( padded_M = aq.shape[1] workspace1 = self.max_experts_per_worker * padded_M * max(N, K) workspace2 = self.max_experts_per_worker * padded_M * (N // 2) - output = (self.max_experts_per_worker, padded_M, K) + output: tuple[int, + ...] 
= (self.max_experts_per_worker, padded_M, K) else: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 2b19bdab4e32..d4233c23f531 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -87,6 +87,7 @@ def apply( and _valid_deep_gemm(hidden_states, w1, w2)) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert + assert experts is not None experts.apply( output, From 21a1eca26cbec06224375bcfc60f5154e65f204c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 9 Jun 2025 20:36:55 +0000 Subject: [PATCH 13/17] fix merge lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/deep_gemm_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index f83d9593c138..595e8c99514d 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -146,7 +146,6 @@ def apply( quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - out = _resize_cache(workspace13, (inv_perm.size(0), K)) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) From ee33a2364b9fcca130637dabc61857322c6511eb Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 10 Jun 2025 21:00:36 +0000 Subject: [PATCH 14/17] switch tests to use intranode by default. fix batched format. consolidate test utils Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_cutlass_moe.py | 38 ++++-- tests/kernels/moe/test_pplx_moe.py | 83 +++++++++--- tests/pplx_utils.py | 123 ------------------ .../layers/fused_moe/cutlass_moe.py | 4 +- .../layers/fused_moe/fused_batched_moe.py | 32 ++--- .../layers/fused_moe/modular_kernel.py | 87 ++++++++----- 6 files changed, 155 insertions(+), 212 deletions(-) delete mode 100644 tests/pplx_utils.py diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 38973e5dc5d8..1429caf95df9 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -4,7 +4,9 @@ import pytest import torch -from tests.pplx_utils import ProcessGroupInfo, parallel_launch +from typing import Optional + +from .deepep_utils import ProcessGroupInfo, parallel_launch from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul @@ -64,6 +66,7 @@ def pplx_cutlass_moe( out_dtype, per_act_token: bool, per_out_ch: bool, + group_name: Optional[str], ): from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) @@ -84,7 +87,8 @@ def pplx_cutlass_moe( else: scale_elems = (hidden_dim + block_size - 1) // block_size - ata = AllToAll.internode( + + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -96,6 +100,12 @@ def pplx_cutlass_moe( hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize, ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + w1 = w1.to(device) w2 = w2.to(device) w1_scale = w1_scale.to(device) @@ -187,11 +197,17 
@@ def _pplx_moe( w2_full: torch.Tensor, per_act_token: bool, per_out_ch: bool, + use_internode: bool, ): - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name with set_current_vllm_config(vllm_config): torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, @@ -199,7 +215,7 @@ def _pplx_moe( pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, w2_scale, topk_weights, topk_ids, a1_scale, out_dtype, per_act_token, - per_out_ch) + per_out_ch, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) @@ -210,7 +226,8 @@ def _pplx_moe( torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("m", [2, 224]) @@ -221,6 +238,7 @@ def _pplx_moe( @pytest.mark.parametrize("per_act_token", [True, False]) @pytest.mark.parametrize("per_out_ch", [True, False]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.parametrize("use_internode", [False]) @pytest.mark.skipif( (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( current_platform.get_device_capability()), @@ -235,6 +253,7 @@ def test_cutlass_moe_pplx( per_act_token: bool, per_out_ch: bool, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) @@ -287,4 +306,5 @@ def test_cutlass_moe_pplx( parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, - dtype, a, w1_d, w2_d, per_act_token, per_out_ch) + dtype, a, w1_d, w2_d, per_act_token, per_out_ch, + use_internode) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 0b48bbef6ceb..1f55232feb4e 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,7 +18,7 @@ except ImportError: has_pplx = False -from tests.pplx_utils import ProcessGroupInfo, parallel_launch +from .deepep_utils import ProcessGroupInfo, parallel_launch from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config @@ -153,7 +153,8 @@ def batched_moe( num_experts = w1.shape[0] fused_experts = FusedMoEModularKernel( - BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedPrepareAndFinalize(max_num_tokens=a.shape[0], + world_size=1, dp_size=1, rank=0), BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) @@ -229,9 +230,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: return t[(r * chunk):(r + 1) * chunk] -def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, - topk_weight: torch.Tensor, topk_ids: torch.Tensor, - num_experts: int) -> torch.Tensor: +def pplx_prepare_finalize( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + 
num_experts: int, + group_name: Optional[str], +) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( PplxPrepareAndFinalize) @@ -245,7 +252,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, world_size = pgi.world_size max_num_tokens = rank_chunk(num_tokens, 0, world_size) - ata = AllToAll.internode( + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -259,6 +266,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, torch.float32.itemsize)), ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + topk_ids = topk_ids.to(dtype=torch.uint32) prepare_finalize = PplxPrepareAndFinalize( @@ -318,11 +331,19 @@ def _pplx_prepare_finalize( score: torch.Tensor, topk: torch.Tensor, num_experts: int, + use_internode: bool, ): - uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name + device = pgi.device topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) @@ -335,14 +356,15 @@ def _pplx_prepare_finalize( a.dtype) pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, - num_experts) + num_experts, group_name) torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device) torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() # TODO (bnell): this test point does not work for odd M due to how the test is @@ -353,6 +375,7 @@ def _pplx_prepare_finalize( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_prepare_finalize( mnk: tuple[int, int, int], @@ -360,6 +383,7 @@ def test_pplx_prepare_finalize( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk @@ -369,10 +393,11 @@ def test_pplx_prepare_finalize( score = torch.randn((m, e), device=device, dtype=dtype) parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, - topk, e) + topk, e, use_internode) def pplx_moe( + group_name: Optional[str], rank: int, world_size: int, dp_size: int, @@ -394,7 +419,7 @@ def pplx_moe( topk = topk_ids.shape[1] max_num_tokens = rank_chunk(a.shape[0], 0, world_size) - ata = AllToAll.internode( + args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, experts_per_token=topk, @@ -408,6 +433,12 @@ def pplx_moe( torch.float32.itemsize)), ) + if group_name is None: + ata = AllToAll.internode(**args) + else: + args["group_name"] = group_name + ata = AllToAll.intranode(**args) + topk_ids = topk_ids.to(dtype=torch.uint32) prepare_finalize = PplxPrepareAndFinalize( @@ -522,11 +553,18 @@ def _pplx_moe( w2: torch.Tensor, score: torch.Tensor, topk: int, + use_internode: bool, ): - uid = 
nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() - torch.distributed.broadcast(uid, src=0) - nvshmem_init(uid, pgi.rank, pgi.world_size) + if use_internode: + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + group_name = None + else: + group_ranks = list(range(pgi.world_size)) + cpu_group = torch.distributed.new_group(group_ranks, backend="gloo") + group_name = cpu_group.group_name m, k = a.shape e, _, n = w2.shape @@ -536,7 +574,7 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, + pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, @@ -548,7 +586,8 @@ def _pplx_moe( torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) - nvshmem_finalize() + if use_internode: + nvshmem_finalize() @pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) @@ -556,6 +595,7 @@ def _pplx_moe( @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@pytest.mark.parametrize("use_internode", [False]) @requires_pplx def test_pplx_moe( mnk: tuple[int, int, int], @@ -563,6 +603,7 @@ def test_pplx_moe( topk: int, dtype: torch.dtype, world_dp_size: tuple[int, int], + use_internode: bool, ): current_platform.seed_everything(7) m, n, k = mnk @@ -572,4 +613,4 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, use_internode) diff --git a/tests/pplx_utils.py b/tests/pplx_utils.py deleted file mode 100644 index 2d5d5be80c3f..000000000000 --- a/tests/pplx_utils.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import os -import traceback -from typing import Callable - -import torch -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec - -P = ParamSpec("P") - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - 
world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exc() - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - "tcp://localhost:29500", - worker, - ) + args, - nprocs=world_size, - join=True, - ) - - -def parallel_launch_from_env( - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - """ - Launches a worker function in parallel across all processes in the current - environment. The environment must have the following variables set: - - WORLD_SIZE: The total number of processes. - - WORLD_LOCAL_SIZE: The number of processes on the current node. - - NODE_RANK: The rank of the current - - MASTER_ADDR: The address of the master process. - - MASTER_PORT: The port of the master process. - """ - assert not kwargs - world_size = int(os.environ["WORLD_SIZE"]) - world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) - node_rank = int(os.environ["NODE_RANK"]) - assert "MASTER_ADDR" in os.environ - assert "MASTER_PORT" in os.environ - spawn( - _worker_parallel_launch, - args=( - world_size, - world_local_size, - node_rank, - "env://", - worker, - ) + args, - nprocs=world_local_size, - join=True, - ) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index eff76bfb57ba..8bb0d3d11694 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -232,8 +232,8 @@ def workspace_shapes( ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: if self.use_batched_format: padded_M = aq.shape[1] - workspace1 = self.max_experts_per_worker * padded_M * max(N, K) - workspace2 = self.max_experts_per_worker * padded_M * (N // 2) + workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) output: tuple[int, ...] = (self.max_experts_per_worker, padded_M, K) else: diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index c8184969282e..e1f613a4bf82 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel( BLOCK_M = config['BLOCK_SIZE_M'] BLOCK_N = config['BLOCK_SIZE_N'] BLOCK_K = config['BLOCK_SIZE_K'] - assert (torch.compiler.is_compiling() - or torch.cuda.is_current_stream_capturing() - or max_num_tokens % BLOCK_M == 0) grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) * triton.cdiv(B.size(1), BLOCK_N)) @@ -390,7 +387,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. 
""" - def __init__(self, max_num_tokens: Optional[int], world_size: int, + def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, rank: int): super().__init__() self.world_size = world_size @@ -430,14 +427,9 @@ def prepare( num_tokens, hidden_dim = a1.size() topk = topk_ids.size(1) - if self.max_num_tokens is None: - tokens_per_expert = torch.bincount(topk_ids.view(-1), - minlength=num_experts) - self.max_num_tokens = int(tokens_per_expert.max().item()) - else: - tokens_per_expert = torch.zeros(num_experts, - dtype=torch.int, - device=a1.device) + tokens_per_expert = torch.zeros(num_experts, + dtype=torch.int, + device=a1.device) assert num_experts % self.world_size == 0 @@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + max_num_tokens: int, world_size: int, dp_size: int, - max_num_tokens: Optional[int] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -533,10 +525,8 @@ def workspace_shapes( ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.world_size // self.dp_size - max_num_tokens = a.size( - 0) if self.max_num_tokens is None else self.max_num_tokens - workspace13 = (num_experts, max_num_tokens * num_dp, K) - workspace2 = (max_num_tokens * num_dp, N) + workspace13 = (num_experts, self.max_num_tokens * num_dp, K) + workspace2 = (self.max_num_tokens * num_dp, N) return (workspace13, workspace2, workspace13, a.dtype) def apply( @@ -562,11 +552,7 @@ def apply( assert hidden_states.dim() == 3 assert expert_num_tokens is not None - if self.max_num_tokens is None: - max_num_tokens = hidden_states.size(1) - else: - max_num_tokens = self.max_num_tokens - + max_num_tokens = self.max_num_tokens num_dp = self.world_size // self.dp_size num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( @@ -584,7 +570,7 @@ def apply( # Indexing expert_num_tokens doesn't work w/cudagraphs or inductor if (torch.compiler.is_compiling() or torch.cuda.is_current_stream_capturing()): - num = max_num_tokens * num_dp + num = hidden_states.shape[1] else: num = int(expert_num_tokens[expert].item()) tmp = _resize_cache(workspace2, (num, N)) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9758ad12280d..3e3d92357a6c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -428,44 +428,15 @@ def forward( device=a1.device, dtype=workspace_dtype) - M_out = fused_out_shape[0] - if num_chunks == 1: - OUT_CHUNK_SIZE = M_out fused_out = _resize_cache(workspace13, fused_out_shape) - else: - # The leading output dimension may not be equal to M, so - # we compute output indices separately. 
- assert M_out >= M - factor = M_out // M - assert factor > 0 - OUT_CHUNK_SIZE = CHUNK_SIZE * factor - - fused_out = torch.empty(fused_out_shape, - device=a1q.device, - dtype=workspace_dtype) - - assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( - f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") - - for chunk in range(num_chunks): - begin_chunk_idx = chunk * CHUNK_SIZE - end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) - begin_out_idx = chunk * OUT_CHUNK_SIZE - end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) - curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] - curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, - end_chunk_idx) - curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] self.fused_experts.apply( - fused_out[begin_out_idx:end_out_idx], - curr_a1q, + fused_out, + a1q, w1, w2, - curr_topk_ids, + topk_ids, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, @@ -473,12 +444,60 @@ def forward( w2_scale=w2_scale, w1_zp=w1_zp, w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, + a1q_scale=a1q_scale, + a2_scale=a2_scale, workspace13=workspace13, workspace2=workspace2, expert_num_tokens=expert_num_tokens, ) + else: + # The leading output dimension may not be equal to M, so + # we compute output indices separately. + M_out = fused_out_shape[0] + assert M_out >= M + factor = M_out // M + assert factor > 0 + OUT_CHUNK_SIZE = CHUNK_SIZE * factor + + fused_out = torch.empty(fused_out_shape, + device=a1q.device, + dtype=workspace_dtype) + + assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( + f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") + + + for chunk in range(num_chunks): + begin_chunk_idx = chunk * CHUNK_SIZE + end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) + begin_out_idx = chunk * OUT_CHUNK_SIZE + end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) + curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] + curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, + end_chunk_idx) + curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, + end_chunk_idx) + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + + self.fused_experts.apply( + fused_out[begin_out_idx:end_out_idx], + curr_a1q, + w1, + w2, + curr_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=curr_a1q_scale, + a2_scale=curr_a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) From 726920c9eb4095e1ea3f032d6d861dc37423493f Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 10 Jun 2025 21:13:08 +0000 Subject: [PATCH 15/17] lint Signed-off-by: Bill Nell --- tests/kernels/moe/test_pplx_cutlass_moe.py | 8 ++++---- tests/kernels/moe/test_pplx_moe.py | 16 ++++++++++------ .../layers/fused_moe/fused_batched_moe.py | 4 ++-- .../layers/fused_moe/modular_kernel.py | 1 - 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py index 1429caf95df9..d90202dfcb3b 100644 --- a/tests/kernels/moe/test_pplx_cutlass_moe.py +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional 
+ import pytest import torch -from typing import Optional - -from .deepep_utils import ProcessGroupInfo, parallel_launch from vllm import _custom_ops as ops from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul @@ -16,6 +15,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from .deepep_utils import ProcessGroupInfo, parallel_launch + try: from pplx_kernels import AllToAll from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, @@ -87,7 +88,6 @@ def pplx_cutlass_moe( else: scale_elems = (hidden_dim + block_size - 1) // block_size - args = dict( max_num_tokens=max_num_tokens, num_experts=num_experts, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 1f55232feb4e..2d6a8f39cec5 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -18,7 +18,6 @@ except ImportError: has_pplx = False -from .deepep_utils import ProcessGroupInfo, parallel_launch from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config @@ -30,6 +29,8 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +from .deepep_utils import ProcessGroupInfo, parallel_launch + requires_pplx = pytest.mark.skipif( not has_pplx, reason="Requires PPLX kernels", @@ -154,7 +155,9 @@ def batched_moe( fused_experts = FusedMoEModularKernel( BatchedPrepareAndFinalize(max_num_tokens=a.shape[0], - world_size=1, dp_size=1, rank=0), + world_size=1, + dp_size=1, + rank=0), BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) @@ -557,7 +560,7 @@ def _pplx_moe( ): if use_internode: uid = nvshmem_get_unique_id( - ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() torch.distributed.broadcast(uid, src=0) nvshmem_init(uid, pgi.rank, pgi.world_size) group_name = None @@ -574,8 +577,8 @@ def _pplx_moe( with set_current_vllm_config(vllm_config), override_config(moe_config): topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) - pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, a, w1, w2, - topk_weight, topk_ids) + pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size, + a, w1, w2, topk_weight, topk_ids) # TODO (bnell): fix + re-enable #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, # topk_ids) @@ -613,4 +616,5 @@ def test_pplx_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) - parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, use_internode) + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, + use_internode) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index e1f613a4bf82..fb66e96c7946 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -387,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): that the PPLX dispatch/combine kernels use. 
""" - def __init__(self, max_num_tokens: int, world_size: int, - dp_size: int, rank: int): + def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, + rank: int): super().__init__() self.world_size = world_size self.dp_size = dp_size diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3e3d92357a6c..9ef6a126680c 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -466,7 +466,6 @@ def forward( assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") - for chunk in range(num_chunks): begin_chunk_idx = chunk * CHUNK_SIZE end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) From a4fb363398b0e7f3c24ab11bea9dae52bb2fb87a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 11 Jun 2025 01:41:06 +0000 Subject: [PATCH 16/17] try to fix lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 8bb0d3d11694..745174e40746 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -230,12 +230,14 @@ def workspace_shapes( topk: int, num_experts: int, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1: tuple[int, ...] = () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] = () if self.use_batched_format: padded_M = aq.shape[1] workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) - output: tuple[int, - ...] = (self.max_experts_per_worker, padded_M, K) + output = (self.max_experts_per_worker, padded_M, K) else: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) From a5d4ba1fe8629617ab188cf06ecdbf5ffa30b130 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 11 Jun 2025 01:44:15 +0000 Subject: [PATCH 17/17] fix more lint Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/cutlass_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 745174e40746..f380cb77c7e8 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -33,7 +33,7 @@ def run_cutlass_moe_fp8( per_act_token: bool, per_out_ch: bool, use_batched_format: bool, -) -> torch.Tensor: +): a1q = hidden_states assert w1_scale is not None @@ -95,6 +95,8 @@ def run_cutlass_moe_fp8( local_E = w1.shape[0] if use_batched_format: + assert expert_num_tokens is not None + expert_offsets = torch.empty((local_E), dtype=torch.int32, device=device)