diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py
new file mode 100644
index 000000000000..0872836b6064
--- /dev/null
+++ b/tests/kernels/moe/test_count_expert_num_tokens.py
@@ -0,0 +1,140 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests compute_expert_num_tokens kernels
+"""
+
+import dataclasses
+from typing import Optional
+
+import pytest
+import torch
+
+from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens
+
+
+@dataclasses.dataclass
+class TestTensors:
+
+    topk_ids: torch.Tensor
+    expert_map: Optional[torch.Tensor] = None
+
+    def to_device(self, device: str):
+        self.topk_ids = self.topk_ids.to(device=device)
+        if self.expert_map is not None:
+            self.expert_map = self.expert_map.to(device=device)
+
+    @staticmethod
+    def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
+             topk_ids_dtype: torch.dtype) -> "TestTensors":
+
+        # make topk ids
+        topk_ids = torch.empty((num_tokens, num_topk),
+                               device=device,
+                               dtype=torch.int64)
+        for x in range(num_tokens):
+            topk_ids[x] = torch.randperm(num_experts)[:num_topk]
+        topk_ids = topk_ids.to(dtype=torch.int64)
+        return TestTensors(topk_ids=topk_ids)
+
+    def with_ep_rank(self, ep_rank: int, num_global_experts: int,
+                     num_local_experts: int, device: str):
+        # make an expert map
+        expert_map = torch.empty((num_global_experts),
+                                 device=device,
+                                 dtype=torch.int32)
+        expert_map.fill_(-1)
+        s = ep_rank * num_local_experts
+        e = s + num_local_experts
+        expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
+                                       device=device)
+
+        return TestTensors(topk_ids=self.topk_ids.clone(),
+                           expert_map=expert_map)
+
+
+def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
+    # do the reference in cpu
+    tt.to_device("cpu")
+    expert_ids, counts = tt.topk_ids.unique(return_counts=True)
+
+    for eid, count in zip(expert_ids, counts):
+        if eid != -1 and tt.expert_map is not None:
+            eid = tt.expert_map[eid]
+
+        if eid == -1:
+            continue
+
+        expert_num_tokens[eid] += count
+
+
+def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
+                                      num_experts: int, ep_size: int,
+                                      topk_ids_dtype: torch.dtype):
+
+    assert num_topk <= num_experts
+
+    tt = TestTensors.make(num_tokens,
+                          num_topk,
+                          num_experts,
+                          topk_ids_dtype=topk_ids_dtype,
+                          device="cpu")
+
+    num_global_experts = num_experts
+    assert num_global_experts % ep_size == 0
+    num_local_experts = num_global_experts // ep_size
+    for ep_rank in range(ep_size):
+        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
+                                  num_local_experts, "cpu")
+
+        ref_expert_num_tokens = torch.zeros((num_local_experts),
+                                            device="cpu",
+                                            dtype=torch.int32)
+        ref_impl(tt_rank, ref_expert_num_tokens)
+        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")
+
+        tt_rank.to_device("cuda")
+        # Test with expert_map
+        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
+            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)
+
+        # Test without expert map
+        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
+        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
+            topk_ids, num_local_experts, expert_map=None)
+
+        torch.testing.assert_close(ref_expert_num_tokens,
+                                   triton_expert_num_tokens_w_emap,
+                                   atol=0,
+                                   rtol=0)
+        torch.testing.assert_close(ref_expert_num_tokens,
+                                   triton_expert_num_tokens_wo_emap,
+                                   atol=0,
+                                   rtol=0)
+
+
+@pytest.mark.parametrize(
+    "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
+@pytest.mark.parametrize("num_topk", [2, 6, 8])
+@pytest.mark.parametrize("num_experts", [64])
+@pytest.mark.parametrize("ep_size", [1, 2, 4])
+@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
+def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
+                                   num_experts: int, ep_size: int,
+                                   topk_ids_dtype: torch.dtype):
+    do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
+                                      ep_size, topk_ids_dtype)
+
+
+@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("ep_size", [2])
+@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
+def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
+                                              ep_size: int,
+                                              topk_ids_dtype: torch.dtype):
+    do_test_compute_expert_num_tokens(num_tokens=numel,
+                                      num_topk=1,
+                                      num_experts=num_experts,
+                                      ep_size=ep_size,
+                                      topk_ids_dtype=topk_ids_dtype)
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index c8c02497bb1f..40b58f1a4ad9 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -98,7 +98,7 @@ def workspace_shapes(
         M_sum = round_up(M_sum, block_m)
         workspace1 = (M_sum, max(N * 2, K))
         workspace2 = (M_sum, max(N, K))
-        output = (M * topk, K)
+        output = (M, topk, K)
         return (workspace1, workspace2, output, a.dtype)
 
     def apply(
@@ -172,7 +172,7 @@ def apply(
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
 
-        torch.index_select(mm2_out, 0, inv_perm, out=output)
+        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))
 
 
 def deep_gemm_moe_fp8(
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 29c232afd65e..8453ab0dc951 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -10,7 +10,8 @@
 
 import vllm.envs as envs
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
+    _resize_cache, count_expert_num_tokens)
 from vllm.utils import cdiv
 
 #
@@ -421,6 +422,177 @@ def __init__(
             f"{fused_experts.__class__.__name__}."
             f"{fused_experts.activation_formats[0]}")
 
+    def _do_fused_experts(
+            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
+            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
+            topk_ids: torch.Tensor, activation: str, global_num_experts: int,
+            local_num_experts: int, expert_map: Optional[torch.Tensor],
+            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
+            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
+            a1q_scale: Optional[torch.Tensor],
+            a2_scale: Optional[torch.Tensor],
+            expert_tokens_meta: Optional[ExpertTokensMetadata]
+    ) -> torch.Tensor:
+
+        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+
+        (workspace13_shape, workspace2_shape, fused_out_shape,
+         workspace_dtype) = self.fused_experts.workspace_shapes(
+             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)
+
+        # We can reuse the memory between cache1 and cache3 because by the
+        # time we need cache3, we're done with cache1.
+        workspace13 = torch.empty(prod(workspace13_shape),
+                                  device=a1.device,
+                                  dtype=workspace_dtype)
+        workspace2 = torch.empty(prod(workspace2_shape),
+                                 device=a1.device,
+                                 dtype=workspace_dtype)
+
+        assert fused_out is None or fused_out.shape == fused_out_shape, (
+            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
+        if fused_out is None:
+            # reuse workspace13 for the output
+            fused_out = _resize_cache(workspace13, fused_out_shape)
+
+        self.fused_experts.apply(fused_out,
+                                 a1q,
+                                 w1,
+                                 w2,
+                                 topk_ids=topk_ids,
+                                 activation=activation,
+                                 global_num_experts=global_num_experts,
+                                 expert_map=expert_map,
+                                 w1_scale=w1_scale,
+                                 w2_scale=w2_scale,
+                                 w1_zp=w1_zp,
+                                 w2_zp=w2_zp,
+                                 a1q_scale=a1q_scale,
+                                 a2_scale=a2_scale,
+                                 workspace13=workspace13,
+                                 workspace2=workspace2,
+                                 expert_tokens_meta=expert_tokens_meta)
+
+        return fused_out
+
+    def _maybe_chunk_fused_experts(
+            self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
+            w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
+            global_num_experts: int, local_num_experts: int,
+            expert_map: Optional[torch.Tensor],
+            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
+            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
+            a1q_scale: Optional[torch.Tensor],
+            a2_scale: Optional[torch.Tensor],
+            expert_tokens_meta: Optional[ExpertTokensMetadata]
+    ) -> torch.Tensor:
+
+        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
+
+        CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+        num_chunks = cdiv(M, CHUNK_SIZE)
+
+        if not self.fused_experts.supports_chunking() or num_chunks == 1:
+            return self._do_fused_experts(
+                fused_out=None,
+                a1=a1,
+                a1q=a1q,
+                w1=w1,
+                w2=w2,
+                topk_ids=topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=a1q_scale,
+                a2_scale=a2_scale,
+                expert_tokens_meta=expert_tokens_meta)
+
+        # Chunking required case
+        assert num_chunks > 1
+
+        # Construct the entire output that can then be processed in chunks.
+        (_, _, fused_out_shape,
+         _) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
+                                                  global_num_experts,
+                                                  local_num_experts)
+        fused_out = torch.empty(fused_out_shape,
+                                device=a1q.device,
+                                dtype=a1.dtype)
+
+        def slice_input_tensors(
+                chunk_idx: int
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+                   Optional[torch.Tensor], torch.Tensor]:
+            s = chunk_idx * CHUNK_SIZE
+            e = min(s + CHUNK_SIZE, M)
+            return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
+                    _chunk_scales(a2_scale, s, e), topk_ids[s:e])
+
+        def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
+            assert fused_out.size(0) % M == 0, (
+                f"fused_out shape {fused_out.shape} vs M {M}")
+            factor = fused_out.size(0) // M
+            out_chunk_size = CHUNK_SIZE * factor
+            s = chunk_idx * out_chunk_size
+            e = min(s + out_chunk_size, fused_out.size(0))
+            return fused_out[s:e]
+
+        def slice_expert_tokens_metadata(
+                full_expert_tokens_meta: ExpertTokensMetadata,
+                chunk_topk_ids: torch.Tensor, local_num_experts: int,
+                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
+            # The existing expert_num_tokens is for the entire a1q
+            # input. Chunking forces recomputation of the number
+            # of tokens assigned to each expert.
+            c_expert_num_tokens = count_expert_num_tokens(
+                chunk_topk_ids, local_num_experts, expert_map)
+
+            c_expert_num_tokens_cpu = None
+            need_expert_num_tokens_cpu = (
+                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
+            if need_expert_num_tokens_cpu:
+                c_expert_num_tokens_cpu = c_expert_num_tokens.to(
+                    "cpu", non_blocking=True)
+
+            return ExpertTokensMetadata(
+                expert_num_tokens=c_expert_num_tokens,
+                expert_num_tokens_cpu=c_expert_num_tokens_cpu)
+
+        for chunk_idx in range(num_chunks):
+            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
+                slice_input_tensors(chunk_idx))
+
+            c_expert_tokens_meta = None
+            if expert_tokens_meta is not None:
+                c_expert_tokens_meta = slice_expert_tokens_metadata(
+                    expert_tokens_meta, c_topk_ids, local_num_experts,
+                    expert_map)
+
+            self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
+                                   a1=a1,
+                                   a1q=c_a1q,
+                                   w1=w1,
+                                   w2=w2,
+                                   topk_ids=c_topk_ids,
+                                   activation=activation,
+                                   global_num_experts=global_num_experts,
+                                   local_num_experts=local_num_experts,
+                                   expert_map=expert_map,
+                                   w1_scale=w1_scale,
+                                   w2_scale=w2_scale,
+                                   w1_zp=w1_zp,
+                                   w2_zp=w2_zp,
+                                   a1q_scale=c_a1q_scale,
+                                   a2_scale=c_a2_scale,
+                                   expert_tokens_meta=c_expert_tokens_meta)
+
+        return fused_out
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -512,110 +684,23 @@ def forward(
             # and can never run into the tensor.numel() == 0 case.
             fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
         else:
-            _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
-
-            if self.fused_experts.enable_chunking():
-                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
-                num_chunks = cdiv(M, CHUNK_SIZE)
-            else:
-                CHUNK_SIZE = M
-                num_chunks = 1
-
-            if num_chunks == 1:
-                (workspace13_shape, workspace2_shape, fused_out_shape,
-                 workspace_dtype) = self.fused_experts.workspace_shapes(
-                     a1, a1q, M, N, K, top_k, global_num_experts,
-                     local_num_experts)
-            else:
-                # Use the full M to get the final output shape.
-                _, _, fused_out_shape, _ = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, M, N, K, top_k, global_num_experts,
-                        local_num_experts))
-                # Use the CHUNK_SIZE to get the workspace shapes.
-                workspace13_shape, workspace2_shape, _, workspace_dtype = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
-                        local_num_experts))
-
-            # We can reuse the memory between cache1 and cache3 because by the
-            # time we need cache3, we're done with cache1.
-            workspace13 = torch.empty(prod(workspace13_shape),
-                                      device=a1.device,
-                                      dtype=workspace_dtype)
-            workspace2 = torch.empty(prod(workspace2_shape),
-                                     device=a1.device,
-                                     dtype=workspace_dtype)
-
-            if num_chunks == 1:
-                fused_out = _resize_cache(workspace13, fused_out_shape)
-
-                self.fused_experts.apply(
-                    fused_out,
-                    a1q,
-                    w1,
-                    w2,
-                    topk_ids,
-                    activation=activation,
-                    global_num_experts=global_num_experts,
-                    expert_map=expert_map,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    w1_zp=w1_zp,
-                    w2_zp=w2_zp,
-                    a1q_scale=a1q_scale,
-                    a2_scale=a2_scale,
-                    workspace13=workspace13,
-                    workspace2=workspace2,
-                    expert_tokens_meta=expert_tokens_meta,
-                )
-            else:
-                # The leading output dimension may not be equal to M, so
-                # we compute output indices separately.
-                M_out = fused_out_shape[0]
-                assert M_out >= M
-                factor = M_out // M
-                assert factor > 0
-                OUT_CHUNK_SIZE = CHUNK_SIZE * factor
-
-                fused_out = torch.empty(fused_out_shape,
-                                        device=a1q.device,
-                                        dtype=workspace_dtype)
-
-                assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
-                    f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
-
-                for chunk in range(num_chunks):
-                    begin_chunk_idx = chunk * CHUNK_SIZE
-                    end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
-                    begin_out_idx = chunk * OUT_CHUNK_SIZE
-                    end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
-                    curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
-                    curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
-                                                   end_chunk_idx)
-                    curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
-                                                  end_chunk_idx)
-                    curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
-
-                    self.fused_experts.apply(
-                        fused_out[begin_out_idx:end_out_idx],
-                        curr_a1q,
-                        w1,
-                        w2,
-                        curr_topk_ids,
-                        activation=activation,
-                        global_num_experts=global_num_experts,
-                        expert_map=expert_map,
-                        w1_scale=w1_scale,
-                        w2_scale=w2_scale,
-                        w1_zp=w1_zp,
-                        w2_zp=w2_zp,
-                        a1q_scale=curr_a1q_scale,
-                        a2_scale=curr_a2_scale,
-                        workspace13=workspace13,
-                        workspace2=workspace2,
-                        expert_tokens_meta=expert_tokens_meta,
-                    )
+            fused_out = self._maybe_chunk_fused_experts(
+                a1=a1,
+                a1q=a1q,
+                w1=w1,
+                w2=w2,
+                topk_ids=topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=a1q_scale,
+                a2_scale=a2_scale,
+                expert_tokens_meta=expert_tokens_meta)
 
         self.prepare_finalize.finalize(output, fused_out, topk_weights,
                                        topk_ids, apply_router_weight_on_input)
diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py
index 1eb949790060..b27e99150541 100644
--- a/vllm/model_executor/layers/fused_moe/utils.py
+++ b/vllm/model_executor/layers/fused_moe/utils.py
@@ -13,9 +13,81 @@
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     quant_dequant_mxfp4)
 from vllm.platforms import current_platform
+from vllm.triton_utils import tl, triton
 from vllm.utils import cdiv
 
 
+@triton.jit
+def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
+                             topk_numel, expert_map,
+                             HAS_EXPERT_MAP: tl.constexpr,
+                             BLOCK_SIZE: tl.constexpr):
+
+    curr_expert = tl.program_id(0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    topk_ids_ptrs = topk_ids_ptr + offsets
+
+    acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
+    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
+        mask = offsets < (topk_numel - x * BLOCK_SIZE)
+        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
+        if HAS_EXPERT_MAP:
+            expert_map_ptrs = expert_map + expert_ids
+            expert_map_mask = expert_ids >= 0
+            expert_ids = tl.load(expert_map_ptrs,
+                                 mask=expert_map_mask,
+                                 other=-1)
+
+        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
+        acc = acc + has_curr_expert
+        topk_ids_ptrs += BLOCK_SIZE
+
+    if curr_expert < num_experts:
+        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
+
+
+def count_expert_num_tokens(
+        topk_ids: torch.Tensor, num_local_experts: int,
+        expert_map: Optional[torch.Tensor]) -> torch.Tensor:
+    """
+    Count the number of tokens assigned to each expert.
+
+    Parameters:
+    - topk_ids (torch.Tensor): Tensor mapping each token to its
+      list of experts.
+    - num_local_experts (int): Number of experts in this rank.
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+      from the global expert space to the local expert space of the expert
+      parallel shard.
+
+    Returns:
+    A tensor of size num_local_experts, where tensor[i] holds the number
+    of tokens assigned to the ith expert.
+    """
+    assert topk_ids.dtype.is_signed, (
+        "The kernel uses -1 to represent invalid topk_ids")
+    expert_num_tokens = torch.empty((num_local_experts),
+                                    device=topk_ids.device,
+                                    dtype=torch.int32)
+
+    grid = num_local_experts
+    BLOCK_SIZE = min(topk_ids.numel(), 1024)
+    BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)
+
+    _count_expert_num_tokens[(grid, )](
+        topk_ids,
+        expert_num_tokens,
+        num_local_experts,
+        topk_ids.numel(),
+        expert_map,
+        HAS_EXPERT_MAP=expert_map is not None,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+
+    return expert_num_tokens
+
+
 def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
     """
     Shrink the given tensor and apply the given view to it. This is
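Reviewer note, not part of the patch: the sketch below is a minimal eager PyTorch reference for the semantics of the new count_expert_num_tokens helper, assuming topk_ids holds global expert ids when an expert_map is given and local ids (with -1 marking invalid entries) otherwise; the name count_expert_num_tokens_ref is introduced here only for illustration. It maps global expert ids to local ids, drops entries that map to -1 (experts owned by another EP rank), and histograms the rest over the local experts.

from typing import Optional

import torch


def count_expert_num_tokens_ref(
        topk_ids: torch.Tensor, num_local_experts: int,
        expert_map: Optional[torch.Tensor]) -> torch.Tensor:
    # Flatten (num_tokens, topk) into a 1-D list of expert assignments
    # and drop entries marked invalid with -1.
    ids = topk_ids.reshape(-1).to(torch.long)
    ids = ids[ids >= 0]
    if expert_map is not None:
        # Map global expert ids to local ids; experts that are not local
        # to this rank map to -1 and are dropped.
        ids = expert_map[ids].to(torch.long)
        ids = ids[ids >= 0]
    counts = torch.bincount(ids, minlength=num_local_experts)
    # Ids outside [0, num_local_experts) are ignored, as in the Triton kernel.
    return counts[:num_local_experts].to(torch.int32)

This is also why slice_expert_tokens_metadata in modular_kernel.py recomputes the counts per chunk: once topk_ids is sliced, the per-expert totals computed for the full input no longer describe the chunk.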