140 changes: 140 additions & 0 deletions tests/kernels/moe/test_count_expert_num_tokens.py
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests compute_expert_num_tokens kernels
"""

import dataclasses
from typing import Optional

import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens


@dataclasses.dataclass
class TestTensors:

    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor] = None

    def to_device(self, device: str):
        self.topk_ids = self.topk_ids.to(device=device)
        if self.expert_map is not None:
            self.expert_map = self.expert_map.to(device=device)

    @staticmethod
    def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
             topk_ids_dtype: torch.dtype) -> "TestTensors":

        # make topk ids
        topk_ids = torch.empty((num_tokens, num_topk),
                               device=device,
                               dtype=torch.int64)
        for x in range(num_tokens):
            topk_ids[x] = torch.randperm(num_experts)[:num_topk]
        topk_ids = topk_ids.to(dtype=topk_ids_dtype)
        return TestTensors(topk_ids=topk_ids)

    def with_ep_rank(self, ep_rank: int, num_global_experts: int,
                     num_local_experts: int, device: str):
        # make an expert map
        expert_map = torch.empty((num_global_experts),
                                 device=device,
                                 dtype=torch.int32)
        expert_map.fill_(-1)
        s = ep_rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
                                       device=device)

        return TestTensors(topk_ids=self.topk_ids.clone(),
                           expert_map=expert_map)


def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
    # do the reference in cpu
    tt.to_device("cpu")
    expert_ids, counts = tt.topk_ids.unique(return_counts=True)

    for eid, count in zip(expert_ids, counts):
        if eid != -1 and tt.expert_map is not None:
            eid = tt.expert_map[eid]

        if eid == -1:
            continue

        expert_num_tokens[eid] += count


def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                      num_experts: int, ep_size: int,
                                      topk_ids_dtype: torch.dtype):

    assert num_topk <= num_experts

    tt = TestTensors.make(num_tokens,
                          num_topk,
                          num_experts,
                          topk_ids_dtype=topk_ids_dtype,
                          device="cpu")

    num_global_experts = num_experts
    assert num_global_experts % ep_size == 0
    num_local_experts = num_global_experts // ep_size
    for ep_rank in range(ep_size):
        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
                                  num_local_experts, "cpu")

        ref_expert_num_tokens = torch.zeros((num_local_experts),
                                            device="cpu",
                                            dtype=torch.int32)
        ref_impl(tt_rank, ref_expert_num_tokens)
        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")

        tt_rank.to_device("cuda")
        # Test with expert_map
        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)

        # Test without expert map
        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
            topk_ids, num_local_experts, expert_map=None)

        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_w_emap,
                                   atol=0,
                                   rtol=0)
        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_wo_emap,
                                   atol=0,
                                   rtol=0)


@pytest.mark.parametrize(
    "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                   num_experts: int, ep_size: int,
                                   topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
                                      ep_size, topk_ids_dtype)


@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
                                              ep_size: int,
                                              topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens=numel,
                                      num_topk=1,
                                      num_experts=num_experts,
                                      ep_size=ep_size,
                                      topk_ids_dtype=topk_ids_dtype)
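
Note: ref_impl above defines the contract the kernel is checked against: a per-local-expert histogram of topk_ids, where expert_map translates global expert ids to this rank's local ids and -1 marks experts owned by other EP ranks. A vectorized pure-PyTorch sketch of that same contract (illustrative only, not the kernel's actual implementation; the function name here is made up):

import torch
from typing import Optional

def count_expert_num_tokens_ref(topk_ids: torch.Tensor,
                                num_local_experts: int,
                                expert_map: Optional[torch.Tensor]) -> torch.Tensor:
    # Translate global expert ids to local ids; ids owned by other EP ranks become -1.
    ids = expert_map[topk_ids] if expert_map is not None else topk_ids
    ids = ids.reshape(-1)
    # Keep only this rank's experts and histogram them.
    ids = ids[(ids >= 0) & (ids < num_local_experts)]
    return torch.bincount(ids.long(), minlength=num_local_experts).to(torch.int32)

With atol=0, rtol=0 the test asserts that the CUDA kernel reproduces this count exactly for every EP rank, both with an expert_map and with pre-translated topk_ids.
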
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -98,7 +98,7 @@ def workspace_shapes(
        M_sum = round_up(M_sum, block_m)
        workspace1 = (M_sum, max(N * 2, K))
        workspace2 = (M_sum, max(N, K))
-        output = (M * topk, K)
+        output = (M, topk, K)
        return (workspace1, workspace2, output, a.dtype)

    def apply(
@@ -172,7 +172,7 @@ def apply(
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-        torch.index_select(mm2_out, 0, inv_perm, out=output)
+        torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))


def deep_gemm_moe_fp8(
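
The two hunks above change the DeepGemm experts' output buffer from a flat (M * topk, K) tensor to (M, topk, K). Because torch.index_select along dim 0 needs a 2-D destination, the final un-permute now writes through output.view((-1, K)), i.e. the same storage viewed as (M * topk, K). A small standalone sketch of that shape equivalence (toy sizes only; inv_perm here is just a stand-in for the real inverse permutation):

import torch

M, topk, K = 4, 2, 8                      # illustrative sizes only
mm2_out = torch.randn(M * topk, K)        # per-(token, expert) rows in permuted order
inv_perm = torch.randperm(M * topk)       # stand-in for the real inverse permutation

output = torch.empty(M, topk, K)          # new 3-D output layout
torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))

# Row i of the flat view is output[i // topk, i % topk]; the data is identical
# to what the old (M * topk, K) layout would have held.
assert torch.equal(output.view(-1, K), mm2_out[inv_perm])
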