Skip to content
29 changes: 29 additions & 0 deletions test/srt/cpu/test_activation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest

import torch
import torch.nn.functional as F
from sgl_kernel.common_ops import silu_and_mul_cpu as silu_and_mul
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change this

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

@yanbing-j yanbing-j May 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhyncs Thanks, has updated according to #6404 (comment).


from sglang.test.test_utils import CustomTestCase


class TestActivation(CustomTestCase):
    """Tests the fused silu_and_mul CPU kernel against a native reference."""

    def _forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Reference SiLU-and-mul: split ``x`` in half along the last dim,
        apply SiLU to the first half and multiply by the second half."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def _run_single_test(self, shape, dtype, device):
        x = torch.randn(shape, dtype=dtype).to(device=device)

        out = silu_and_mul(x)
        ref_out = self._forward_native(x)

        torch.testing.assert_close(out, ref_out)

    def test_activation(self):
        # Review feedback: the kernel is cheap, so cover the full
        # shape x dtype cross-product instead of two hand-picked combos.
        import itertools

        shapes = [[128, 22016], [129, 22016]]
        dtypes = [torch.bfloat16, torch.float16]
        for shape, dtype in itertools.product(shapes, dtypes):
            self._run_single_test(shape, dtype, "cpu")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case doesn't take much time to run; suggest using itertools.product to cover more shape/dtype combinations.



if __name__ == "__main__":
unittest.main()
69 changes: 69 additions & 0 deletions test/srt/cpu/test_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import unittest
from typing import Optional, Tuple, Union

import torch
from sgl_kernel.common_ops import fused_add_rmsnorm_cpu as fused_add_rmsnorm
from sgl_kernel.common_ops import rmsnorm_cpu as rmsnorm

from sglang.test.test_utils import CustomTestCase


class TestNorm(CustomTestCase):
    """Tests the rmsnorm / fused_add_rmsnorm CPU kernels against a native
    float32 reference implementation."""

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm computed in float32.

        When ``residual`` is given, it is added to ``x`` first and the
        updated residual (cast back to the original dtype) is returned
        alongside the normalized output, mirroring the in-place contract
        of the fused kernel.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype) * weight
        if residual is None:
            return x
        else:
            return x, residual

    def _run_single_test(self, shape, dtype, device="cpu"):
        # Default device is "cpu": these are CPU-kernel tests; the previous
        # "cuda" default was never a valid target here.
        x = torch.randn(shape, dtype=dtype).to(device=device)
        hidden_size = x.size(-1)
        weight = torch.randn(hidden_size, dtype=dtype).to(device=device)

        variance_epsilon = 1e-6

        # TEST: rmsnorm (out-of-place)
        out = rmsnorm(x, weight, variance_epsilon)
        ref_out = self._forward_native(x, weight, variance_epsilon)

        torch.testing.assert_close(out, ref_out)

        # TEST: fused_add_rmsnorm updates x and residual in place, so keep
        # pristine clones around for the reference computation.
        ref_x = x.clone()

        residual = torch.randn(shape, dtype=dtype).to(device=device)
        ref_residual = residual.clone()

        fused_add_rmsnorm(x, residual, weight, variance_epsilon)

        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )

        torch.testing.assert_close(x, ref_x)
        torch.testing.assert_close(residual, ref_residual)

    def test_norm(self):
        # Review feedback: cover the shape x dtype cross-product rather than
        # a few hand-picked combinations; the kernel is fast enough.
        import itertools

        shapes = [[4096, 4096], [1024, 4096], [1024, 4096 + 13]]
        dtypes = [torch.bfloat16, torch.float16]
        for shape, dtype in itertools.product(shapes, dtypes):
            self._run_single_test(shape, dtype)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use itertools.product and remove 'cpu'



if __name__ == "__main__":
unittest.main()
101 changes: 101 additions & 0 deletions test/srt/cpu/test_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import unittest

import torch
from sgl_kernel.common_ops import (
rotary_position_embedding_cpu as rotary_position_embedding,
)

from sglang.test.test_utils import CustomTestCase


class TestROPE(CustomTestCase):
    """Tests the fused CPU rotary-position-embedding kernel against a
    reference implementation (DeepSeek-V2 style partial RoPE)."""

    def _rotate_neox(self, x: torch.Tensor) -> torch.Tensor:
        # NeoX style: rotate half-blocks, [x1, x2] -> [-x2, x1].
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _rotate_gptj(self, x: torch.Tensor) -> torch.Tensor:
        # GPT-J style: rotate interleaved pairs, (even, odd) -> (-odd, even).
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x = torch.stack((-x2, x1), dim=-1)
        return x.flatten(-2)

    def _forward_ref(
        self,
        positions,
        query,
        key,
        cos_sin_cache,
        offsets=None,
        rotary_dim: int = 64,
        head_size: int = 64,
        is_neox_style: bool = False,
    ):
        """Reference rotary embedding.

        Review feedback: ``rotary_dim``, ``head_size`` and ``is_neox_style``
        are now keyword parameters (defaults preserve the previous
        hard-coded values) instead of attributes mutated on ``self``.
        """
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
        if rotary_dim < head_size:
            query_pass = query[..., rotary_dim:]
            key_pass = key[..., rotary_dim:]

        cos_sin = cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if is_neox_style:
            # NeoX: each half of the cache is tiled to span rotary_dim.
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            # GPT-J: cache values are interleaved pairwise.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = self._rotate_neox if is_neox_style else self._rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if rotary_dim < head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

    def test_deepseek_v2_rope(self):
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576

        # Build cos_sin_cache: first half cos, second half sin; scaled by 0.7
        # to keep values away from saturation under bfloat16 rounding.
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            with torch.no_grad(), torch.cpu.amp.autocast(enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                # Only the trailing rope slice of q/k is rotated.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # ref kernel
                q_pe, k_pe = self._forward_ref(positions, q_pe, k_pe, cos_sin_cache)

                # fused rope kernel
                q_pe_clone, k_pe_clone = rotary_position_embedding(
                    positions, q_pe_clone, k_pe_clone, cos_sin_cache
                )

                torch.testing.assert_close(q_pe, q_pe_clone)
                torch.testing.assert_close(k_pe, k_pe_clone)


if __name__ == "__main__":
unittest.main()
173 changes: 173 additions & 0 deletions test/srt/cpu/test_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import unittest

import torch
from sgl_kernel.common_ops import biased_grouped_topk_cpu as biased_grouped_topk
from sgl_kernel.common_ops import grouped_topk_cpu as grouped_topk

from sglang.test.test_utils import CustomTestCase


# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
    """Tests the fused grouped_topk CPU kernel (DeepSeek-V2 routing) against
    a native PyTorch reference."""

    def _grouped_topk_native(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
    ):
        """Reference grouped top-k: softmax the gating logits, keep the top
        ``topk_group`` expert groups (ranked by their max score), then pick
        ``topk`` experts from the surviving groups."""
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), "Number of tokens mismatch"

        scores = torch.softmax(gating_output, dim=-1)
        num_token = scores.shape[0]
        group_scores = (
            scores.view(num_token, num_expert_group, -1).max(dim=-1).values
        )  # [n, n_group]
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
            1
        ]  # [n, top_k_group]

        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
            .reshape(num_token, -1)
        )  # [n, e]

        tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
        topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # Review feedback: seed for reproducibility — topk uses an unstable
        # sort, so tied scores could otherwise make the fused and native
        # picks diverge between runs.
        torch.manual_seed(1998)

        # Scale gating_output by M, otherwise bfloat16 values collapse into
        # the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = self._grouped_topk_native(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = grouped_topk(
            hidden_states, gating_output, topk, renormalize, G, topk_group
        )

        # Compare order-independently by scattering weights into dense
        # [M, E] maps (topk returns ids in unspecified order).
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)


# DeepSeek V2/V3/R1 uses biased_grouped_topk
class TestBiasedGroupedTopK(CustomTestCase):
    """Tests the fused biased_grouped_topk CPU kernel (DeepSeek V2/V3/R1
    routing with per-expert correction bias) against a native reference."""

    def _biased_grouped_topk(
        self,
        hidden_states: torch.Tensor,
        gating_output: torch.Tensor,
        correction_bias: torch.Tensor,
        topk: int,
        renormalize: bool,
        num_expert_group: int = 0,
        topk_group: int = 0,
    ):
        """Reference biased grouped top-k: sigmoid scores plus bias drive
        selection, but the returned weights use the unbiased scores."""
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), "Number of tokens mismatch"

        scores = gating_output.sigmoid()
        num_token = scores.shape[0]
        scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
        # Group score = sum of the top-2 biased scores inside the group.
        group_scores = (
            scores_for_choice.view(num_token, num_expert_group, -1)
            .topk(2, dim=-1)[0]
            .sum(dim=-1)
        )  # [n, n_group]
        group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
            1
        ]  # [n, top_k_group]
        group_mask = torch.zeros_like(group_scores)  # [n, n_group]
        group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
            .reshape(num_token, -1)
        )  # [n, e]
        tmp_scores = scores_for_choice.masked_fill(
            ~score_mask.bool(), float("-inf")
        )  # [n, e]
        _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
        # Weights come from the *unbiased* sigmoid scores.
        topk_weights = scores.gather(1, topk_ids)

        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        return topk_weights.to(torch.float32), topk_ids.to(torch.int32)

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # Review feedback: seed for reproducibility — topk is an unstable
        # sort, so ties could otherwise flake the comparison.
        torch.manual_seed(1998)

        # Scale gating_output by M, otherwise bfloat16 values collapse into
        # the same value after truncation.
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=dtype)

        ref_topk_weights, ref_topk_ids = self._biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # fused version
        topk_weights, topk_ids = biased_grouped_topk(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Compare order-independently via dense [M, E] scatter maps.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)


if __name__ == "__main__":
unittest.main()