Commit e50a1f1

[TPU] Add kernel test for moe_pallas (#17496)
Signed-off-by: Michael Goin <[email protected]>
1 parent a17cef7 commit e50a1f1

File tree

4 files changed: +96, -3 lines

.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh

Lines changed: 3 additions & 1 deletion
@@ -47,7 +47,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_10 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \
     && echo TEST_11 \
-    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \
+    && echo TEST_12 \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling

tests/tpu/test_moe_pallas.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the Pallas MOE implementation.
+
+Run `pytest tests/tpu/test_moe_pallas.py`.
+"""
+import pytest
+import torch
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.model_executor.layers.fused_moe.moe_pallas import (
+    fused_moe as pallas_moe)
+from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
+    fused_moe as torch_moe)
+# yapf: enable
+from vllm.platforms import current_platform
+
+if not current_platform.is_tpu():
+    pytest.skip("This test needs a TPU.", allow_module_level=True)
+
+NUM_EXPERTS = [8, 64]
+EP_SIZE = [1]
+TOP_KS = [2, 6]
+
+
+# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16
+@pytest.mark.parametrize("m", [8, 16, 64, 2048])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("ep_size", EP_SIZE)
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+def test_pallas_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    ep_size: int,
+    dtype: torch.dtype,
+):
+    import torch_xla.core.xla_model as xm
+    with torch.device(xm.xla_device()):
+        a = torch.randn((m, k), dtype=dtype) / 10
+        w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
+        w2 = torch.randn((e, k, n), dtype=dtype) / 10
+
+        score = torch.randn((m, e), dtype=dtype)
+
+        # TODO: Support ep
+        if ep_size > 1:
+            pytest.skip("No support for ep_size > 1 yet")
+        else:
+            e_map = None
+
+        # Run both implementations
+        torch_output = torch_moe(
+            hidden_states=a,
+            w1=w1,
+            w2=w2,
+            gating_output=score,
+            topk=topk,
+            global_num_experts=e,
+            expert_map=e_map,
+            renormalize=False,
+        )
+
+        pallas_output = pallas_moe(
+            hidden_states=a,
+            w1=w1,
+            w2=w2,
+            gating_output=score,
+            topk=topk,
+            global_num_experts=e,
+            expert_map=e_map,
+            renormalize=False,
+        )
+        xm.mark_step()
+
+        # Compare outputs
+        torch.testing.assert_close(
+            pallas_output.cpu(),
+            torch_output.cpu(),
+            atol=2e-2,
+            rtol=0,
+        )
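The Pallas GMM constraint noted in the test's comment can be sanity-checked against its parametrization; below is a small standalone sketch (not part of the commit) confirming that every (m, topk) pair above keeps num_tokens * topk a multiple of 16:

# Standalone sketch (not in the commit): verify the Pallas GMM constraint
# that num_tokens * topk must be a multiple of 16 for every parametrized pair.
for m in [8, 16, 64, 2048]:
    for topk in [2, 6]:
        assert (m * topk) % 16 == 0, f"({m}, {topk}) breaks the multiple-of-16 rule"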

vllm/attention/backends/pallas.py

Lines changed: 2 additions & 1 deletion
@@ -123,7 +123,8 @@ def __init__(
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
-            raise NotImplementedError("Head size must be a multiple of 128.")
+            raise NotImplementedError(
+                f"Head size must be a multiple of 128, found {head_size}.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
         if sliding_window is not None:
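The updated message now reports the offending value. A minimal sketch of the guard's behavior, using a hypothetical check_head_size helper that mirrors the constructor check above (it is not part of vLLM):

# Hypothetical helper mirroring the constructor check above; not part of vLLM.
def check_head_size(head_size: int) -> None:
    if head_size % 128 != 0:
        raise NotImplementedError(
            f"Head size must be a multiple of 128, found {head_size}.")

check_head_size(256)  # OK: 256 is a multiple of 128
# check_head_size(96) would raise:
# NotImplementedError: Head size must be a multiple of 128, found 96.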

vllm/model_executor/layers/fused_moe/moe_pallas.py

Lines changed: 4 additions & 1 deletion
@@ -11,7 +11,9 @@ def fused_moe(
     w2: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
-    renormalize: bool,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
 ) -> torch.Tensor:
     """
     Args:
@@ -20,6 +22,7 @@ def fused_moe(
         w2: [num_experts, hidden_size, intermediate_size]
         gating_output: [*, num_experts]
     """
+    assert expert_map is None, "expert_map is not supported for pallas MoE."
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
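With the new signature, call sites must pass global_num_experts, while expert_map and renormalize have defaults. A minimal usage sketch on a TPU host (illustrative shapes only: hidden size 128, intermediate size 256, 8 experts; expert_map must stay None because the Pallas path asserts it is unsupported):

# Sketch of calling the updated Pallas fused_moe signature (illustrative shapes).
import torch
import torch_xla.core.xla_model as xm
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

with torch.device(xm.xla_device()):
    hidden_states = torch.randn((16, 128), dtype=torch.bfloat16)  # [num_tokens, hidden_size]
    w1 = torch.randn((8, 512, 128), dtype=torch.bfloat16)         # [num_experts, 2 * intermediate, hidden]
    w2 = torch.randn((8, 128, 256), dtype=torch.bfloat16)         # [num_experts, hidden, intermediate]
    gating_output = torch.randn((16, 8), dtype=torch.bfloat16)    # [num_tokens, num_experts]
    out = fused_moe(
        hidden_states,
        w1,
        w2,
        gating_output,
        topk=2,                # 16 tokens * topk 2 = 32, a multiple of 16
        global_num_experts=8,
        expert_map=None,       # asserted to be None in the Pallas implementation
        renormalize=False,
    )
    xm.mark_step()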

0 commit comments