-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Update sgl-kernel UTs for activation/topk/norm/rope kernels #6452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
3e0ecb0
c30005e
b3198dd
50edfdc
065d695
ca936e0
fb82a99
813cd22
2addf9d
95d0dde
fc105fc
88aafc5
59da74c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| import itertools | ||
| import unittest | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from sgl_kernel.common_ops import silu_and_mul_cpu as silu_and_mul | ||
| from utils import SiluAndMul, precision | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
class TestActivation(CustomTestCase):
    """Compare the fused CPU silu_and_mul kernel against the reference impl."""

    # Test matrix: row counts, column counts, and floating dtypes.
    M = [128, 129, 257]
    N = [22016, 22018]
    dtype = [torch.float16, torch.bfloat16]

    def _activation_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)

        fused = silu_and_mul(x)
        reference = SiluAndMul(x)

        # Same value is used for both absolute and relative tolerance,
        # chosen per reference dtype.
        tol = precision[reference.dtype]
        self.assertTrue(torch.allclose(reference, fused, atol=tol, rtol=tol))

    def test_activation(self):
        for m, n, dt in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dt):
                self._activation_test(m, n, dt)


if __name__ == "__main__":
    unittest.main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| import itertools | ||
| import unittest | ||
| from typing import Optional, Tuple, Union | ||
|
|
||
| import torch | ||
| from sgl_kernel.common_ops import fused_add_rmsnorm_cpu as fused_add_rmsnorm | ||
| from sgl_kernel.common_ops import rmsnorm_cpu as rmsnorm | ||
| from utils import precision | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
class TestNorm(CustomTestCase):
    """Compare fused CPU rmsnorm / fused_add_rmsnorm kernels to a native impl."""

    M = [4096, 1024]
    N = [4096, 4096 + 13]
    dtype = [torch.float16, torch.bfloat16]

    def _forward_native(
        self,
        x: torch.Tensor,
        weight: torch.Tensor,
        variance_epsilon: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Reference RMSNorm computed in fp32, optionally adding a residual.

        Returns the normalized tensor, or a (normalized, updated_residual)
        pair when ``residual`` is given.
        """
        orig_dtype = x.dtype
        acc = x.to(torch.float32)
        if residual is not None:
            acc = acc + residual.to(torch.float32)
            residual = acc.to(orig_dtype)

        inv_rms = torch.rsqrt(
            acc.pow(2).mean(dim=-1, keepdim=True) + variance_epsilon
        )
        normed = (acc * inv_rms).to(orig_dtype) * weight
        return normed if residual is None else (normed, residual)

    def _norm_test(self, m, n, dtype):
        variance_epsilon = 1e-6
        x = torch.randn([m, n], dtype=dtype)
        weight = torch.randn(x.size(-1), dtype=dtype)

        # Out-of-place rmsnorm against the native reference.
        fused = rmsnorm(x, weight, variance_epsilon)
        reference = self._forward_native(x, weight, variance_epsilon)

        tol = precision[reference.dtype]
        self.assertTrue(torch.allclose(reference, fused, atol=tol, rtol=tol))

        # Fused add + rmsnorm: the assertions below expect the kernel to
        # update both x and residual in place.
        ref_x = x.clone()
        residual = torch.randn([m, n], dtype=dtype)
        ref_residual = residual.clone()

        fused_add_rmsnorm(x, residual, weight, variance_epsilon)

        ref_x, ref_residual = self._forward_native(
            ref_x, weight, variance_epsilon, ref_residual
        )

        self.assertTrue(torch.allclose(x, ref_x, atol=tol, rtol=tol))
        self.assertTrue(torch.allclose(residual, ref_residual, atol=tol, rtol=tol))

    def test_norm(self):
        for m, n, dt in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=m, n=n, dtype=dt):
                self._norm_test(m, n, dt)


if __name__ == "__main__":
    unittest.main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| import unittest | ||
|
|
||
| import torch | ||
| from sgl_kernel.common_ops import ( | ||
| rotary_position_embedding_cpu as rotary_position_embedding, | ||
| ) | ||
| from utils import _rotate_gptj, _rotate_neox, precision | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
class TestROPE(CustomTestCase):
    """Compare the fused CPU rotary-embedding kernel against a reference impl."""

    def _forward_ref(
        self,
        positions,
        query,
        key,
        cos_sin_cache,
        rotary_dim,
        head_size,
        is_neox_style,
        offsets=None,
    ):
        """Reference rotary position embedding.

        Applies RoPE to the first ``rotary_dim`` channels of the last dim of
        ``query``/``key``; remaining channels pass through unchanged.
        """
        query_rot = query[..., :rotary_dim]
        key_rot = key[..., :rotary_dim]
        if rotary_dim < head_size:
            query_pass = query[..., rotary_dim:]
            key_pass = key[..., rotary_dim:]

        cos_sin = cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if is_neox_style:
            # NeoX style rotates the two contiguous halves of the head dim,
            # so tile cos/sin across both halves.
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            # GPT-J style rotates interleaved (even, odd) pairs, so repeat
            # each cos/sin element for both members of its pair.
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = _rotate_neox if is_neox_style else _rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if rotary_dim < head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

    def test_deepseek_v2_rope(self):
        num_head = 16
        seq_len = 1024
        q_head_dim = 192
        qk_nope_head_dim = 128
        qk_rope_head_dim = 64
        max_pos = 256
        k_dim = 576

        # Build the cos/sin cache; the 0.7 factor keeps entries strictly
        # inside (-1, 1).
        freqs = torch.rand(max_pos, qk_rope_head_dim // 2)
        cos = freqs.cos() * 0.7
        sin = freqs.sin() * 0.7
        cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16)
        positions = torch.randint(0, max_pos, (seq_len,))

        for dtype in [torch.bfloat16]:
            enable_autocast = True

            with torch.no_grad(), torch.cpu.amp.autocast(enabled=enable_autocast):
                q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype)
                q_clone = q.clone()
                k = torch.randn(seq_len, 1, k_dim, dtype=dtype)
                k_clone = k.clone()
                # Only the trailing qk_rope_head_dim channels receive RoPE.
                _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
                _, q_pe_clone = q_clone.split(
                    [qk_nope_head_dim, qk_rope_head_dim], dim=-1
                )
                k_pe = k[:, :, k_dim - qk_rope_head_dim :]
                k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :]

                # Reference (non-fused) implementation, GPT-J style.
                q_pe, k_pe = self._forward_ref(
                    positions, q_pe, k_pe, cos_sin_cache, 64, 64, False
                )

                # Fused rope kernel under test.
                q_pe_clone, k_pe_clone = rotary_position_embedding(
                    positions, q_pe_clone, k_pe_clone, cos_sin_cache
                )

                # Compare with dtype-appropriate tolerances. The previous
                # extra torch.testing.assert_close(k_pe, k_pe_clone) used
                # default tolerances, duplicating (and unintentionally
                # tightening) the check below, so it was removed.
                atol = rtol = precision[q_pe.dtype]
                self.assertTrue(
                    torch.allclose(q_pe, q_pe_clone, atol=atol, rtol=rtol)
                )
                self.assertTrue(
                    torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol)
                )


if __name__ == "__main__":
    unittest.main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| import itertools | ||
| import unittest | ||
|
|
||
| import torch | ||
| from sgl_kernel.common_ops import biased_grouped_topk_cpu as biased_grouped_topk | ||
| from sgl_kernel.common_ops import grouped_topk_cpu as grouped_topk | ||
| from utils import precision | ||
|
|
||
| from sglang.srt.layers.moe.topk import ( | ||
| biased_grouped_topk_impl as native_biased_grouped_topk, | ||
| ) | ||
| from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk | ||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
# This is used by the Deepseek-V2 model.
class TestGroupedTopK(CustomTestCase):
    """Compare the fused CPU grouped_topk kernel against the native impl."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # Fix the seed for reproducibility: top-k uses an unstable sort, so
        # ties between equal scores could otherwise be broken differently by
        # the two implementations (raised in review).
        torch.manual_seed(1998)

        # Expand gating_output by M, otherwise bfloat16 values fall into the
        # same value after truncating (which would create top-k ties).
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        # Native reference, evaluated in fp32.
        ref_topk_weights, ref_topk_ids = native_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Fused kernel under test.
        topk_weights, topk_ids = grouped_topk(
            hidden_states, gating_output, topk, renormalize, G, topk_group
        )

        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is independent of the ordering of the selected experts.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16)
# DeepSeek V2/V3/R1 uses biased_grouped_topk.
class TestBiasedGroupedTopK(CustomTestCase):
    """Compare the fused CPU biased_grouped_topk kernel against the native impl."""

    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):
        # Fix the seed for reproducibility: top-k uses an unstable sort, so
        # ties between equal scores could otherwise be broken differently by
        # the two implementations.
        torch.manual_seed(1998)

        # Expand gating_output by M, otherwise bfloat16 values fall into the
        # same value after truncating (which would create top-k ties).
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M
        correction_bias = torch.randn(E, dtype=dtype)

        # Native reference, evaluated in fp32.
        ref_topk_weights, ref_topk_ids = native_biased_grouped_topk(
            hidden_states.float(),
            gating_output.float(),
            correction_bias.float(),
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Fused kernel under test.
        topk_weights, topk_ids = biased_grouped_topk(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            G,
            topk_group,
        )

        # Scatter (id, weight) pairs into dense [M, E] maps so the comparison
        # is independent of the ordering of the selected experts.
        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_biased_grouped_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -148,3 +148,16 @@ def scaled_weight(weight, scales): | |
| .contiguous() | ||
| .view(E, N, K) | ||
| ) | ||
|
|
||
|
|
||
| def _rotate_neox(x: torch.Tensor) -> torch.Tensor: | ||
| x1 = x[..., : x.shape[-1] // 2] | ||
| x2 = x[..., x.shape[-1] // 2 :] | ||
| return torch.cat((-x2, x1), dim=-1) | ||
|
|
||
|
|
||
| def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: | ||
| x1 = x[..., ::2] | ||
| x2 = x[..., 1::2] | ||
| x = torch.stack((-x2, x1), dim=-1) | ||
| return x.flatten(-2) | ||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ref #6404 (comment) @yanbing-j
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhyncs Thanks, I have updated it according to #6404 (comment).