Skip to content
33 changes: 33 additions & 0 deletions test/srt/cpu/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import itertools
import unittest

import torch
import torch.nn.functional as F
from sgl_kernel.common_ops import silu_and_mul_cpu as silu_and_mul
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change this

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

@yanbing-j yanbing-j May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhyncs Thanks, this has been updated according to #6404 (comment).

from utils import SiluAndMul, precision

from sglang.test.test_utils import CustomTestCase


class TestActivation(CustomTestCase):
    """Compare the fused CPU silu_and_mul kernel against the native reference."""

    M = [128, 129, 257]
    N = [22016, 22018]
    dtype = [torch.float16, torch.bfloat16]

    def _activation_test(self, m, n, dtype):
        # Random input; the kernel splits the last dimension in half internally.
        inp = torch.randn([m, n], dtype=dtype)

        fused = silu_and_mul(inp)
        reference = SiluAndMul(inp)

        # Tolerance is keyed on the reference output dtype (see utils.precision).
        tol = precision[reference.dtype]
        self.assertTrue(torch.allclose(reference, fused, atol=tol, rtol=tol))

    def test_activation(self):
        for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dtype):
                self._activation_test(m, n, dtype)


if __name__ == "__main__":
    unittest.main()
72 changes: 72 additions & 0 deletions test/srt/cpu/test_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import itertools
import unittest
from typing import Optional, Tuple, Union

import torch
from sgl_kernel.common_ops import fused_add_rmsnorm_cpu as fused_add_rmsnorm
from sgl_kernel.common_ops import rmsnorm_cpu as rmsnorm
from utils import precision

from sglang.test.test_utils import CustomTestCase


class TestNorm(CustomTestCase):
    """Compare fused CPU rmsnorm / fused_add_rmsnorm kernels with a native reference."""

    M = [4096, 1024]
    N = [4096, 4096 + 13]
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm computed in float32, optionally fused with a residual add.

        Returns the normalized tensor, plus the post-add residual when one is given.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            # The residual add happens before normalization; the updated
            # residual is returned in the original dtype.
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype) * weight
        return x if residual is None else (x, residual)

    def _norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype)
        eps = 1e-6

        # Plain rmsnorm: fused kernel vs reference.
        out = rmsnorm(x, weight, eps)
        ref_out = self._forward_native(x, weight, eps)

        tol = precision[ref_out.dtype]
        self.assertTrue(torch.allclose(ref_out, out, atol=tol, rtol=tol))

        # fused_add_rmsnorm updates x and residual in place; compare both
        # against the reference computed on cloned inputs.
        ref_x = x.clone()
        residual = torch.randn([m, n], dtype=dtype)
        ref_residual = residual.clone()

        fused_add_rmsnorm(x, residual, weight, eps)
        ref_x, ref_residual = self._forward_native(ref_x, weight, eps, ref_residual)

        self.assertTrue(torch.allclose(x, ref_x, atol=tol, rtol=tol))
        self.assertTrue(torch.allclose(residual, ref_residual, atol=tol, rtol=tol))

    def test_norm(self):
        for m, n, dtype in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dtype):
                self._norm_test(m, n, dtype)


if __name__ == "__main__":
    unittest.main()
103 changes: 103 additions & 0 deletions test/srt/cpu/test_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import unittest

import torch
from sgl_kernel.common_ops import (
rotary_position_embedding_cpu as rotary_position_embedding,
)
from utils import _rotate_gptj, _rotate_neox, precision

from sglang.test.test_utils import CustomTestCase


class TestROPE(CustomTestCase):
    """Compare the fused CPU rotary-embedding kernel with a native reference."""

    def _forward_ref(
        self,
        positions,
        query,
        key,
        cos_sin_cache,
        rotary_dim,
        head_size,
        is_neox_style,
        offsets=None,
    ):
        """Native rotary position embedding reference.

        Only the first `rotary_dim` channels of each head are rotated; any
        remaining channels are passed through unchanged.
        """
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
        if rotary_dim < head_size:
            query_pass = query[..., rotary_dim:]
            key_pass = key[..., rotary_dim:]

        cos_sin = cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if is_neox_style:
            # NeoX style rotates halves: repeat cos/sin across both halves.
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            # GPT-J style rotates interleaved pairs: interleave cos/sin to match.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if rotary_dim < head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

    def test_deepseek_v2_rope(self):
        """DeepSeek-V2 style (GPT-J interleaved) RoPE on the q/k position slices."""
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576

        # Create cos_sin_cache; scale by 0.7 to keep values away from +/-1.
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            # torch.cpu.amp.autocast is deprecated; use the device-generic API.
            with torch.no_grad(), torch.amp.autocast("cpu", enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                # q_pe / k_pe are views of the rope slice of q / k.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # ref kernel
                q_pe, k_pe = self._forward_ref(
                    positions, q_pe, k_pe, cos_sin_cache, 64, 64, False
                )

                # fused rope kernel
                q_pe_clone, k_pe_clone = rotary_position_embedding(
                    positions, q_pe_clone, k_pe_clone, cos_sin_cache
                )

                atol = rtol = precision[q_pe.dtype]
                self.assertTrue(torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol))
                self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))


if __name__ == "__main__":
    unittest.main()
97 changes: 97 additions & 0 deletions test/srt/cpu/test_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import itertools
import unittest

import torch
from sgl_kernel.common_ops import biased_grouped_topk_cpu as biased_grouped_topk
from sgl_kernel.common_ops import grouped_topk_cpu as grouped_topk
from utils import precision

from sglang.srt.layers.moe.topk import (
biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
from sglang.test.test_utils import CustomTestCase


# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
    """Compare the fused CPU grouped_topk kernel against the native implementation."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # topk is an unstable sort: fix the seed so ties (and thus the
        # selected expert ids) are deterministic across runs.
        torch.manual_seed(1998)

        # expand gating_output by M, otherwise bfloat16 values collapse to
        # the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = grouped_topk(
            hidden_states, gating_output, topk, renormalize, G, topk_group
        )

        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is insensitive to the ordering of the selected experts.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)


# DeepSeek V2/V3/R1 uses biased_grouped_topk
class TestBiasedGroupedTopK(CustomTestCase):
    """Compare the fused CPU biased_grouped_topk kernel against the native implementation."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # topk is an unstable sort: fix the seed so ties (and thus the
        # selected expert ids) are deterministic across runs.
        torch.manual_seed(1998)

        # expand gating_output by M, otherwise bfloat16 values collapse to
        # the same value after truncation
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=dtype)

        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = biased_grouped_topk(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is insensitive to the ordering of the selected experts.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()
13 changes: 13 additions & 0 deletions test/srt/cpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,16 @@ def scaled_weight(weight, scales):
.contiguous()
.view(E, N, K)
)


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this.