-
Notifications
You must be signed in to change notification settings - Fork 5.2k
Update sgl-kernel UTs for activation/topk/norm/rope kernels #6452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
3e0ecb0
c30005e
b3198dd
50edfdc
065d695
ca936e0
fb82a99
813cd22
2addf9d
95d0dde
fc105fc
88aafc5
59da74c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| import unittest | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from sgl_kernel.common_ops import silu_and_mul_cpu as silu_and_mul | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
| class TestActivation(CustomTestCase): | ||
| def _forward_native(self, x: torch.Tensor) -> torch.Tensor: | ||
| d = x.shape[-1] // 2 | ||
| return F.silu(x[..., :d]) * x[..., d:] | ||
|
|
||
| def _run_single_test(self, shape, dtype, device): | ||
| x = torch.randn(shape, dtype=dtype).to(device=device) | ||
|
|
||
| out = silu_and_mul(x) | ||
| ref_out = self._forward_native(x) | ||
|
|
||
| torch.testing.assert_close(out, ref_out) | ||
|
|
||
| def test_activation(self): | ||
| self._run_single_test([128, 22016], torch.bfloat16, "cpu") | ||
| self._run_single_test([129, 22016], torch.float16, "cpu") | ||
|
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,69 @@ | ||||||||||||||||||||||
| import unittest | ||||||||||||||||||||||
| from typing import Optional, Tuple, Union | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
| from sgl_kernel.common_ops import fused_add_rmsnorm_cpu as fused_add_rmsnorm | ||||||||||||||||||||||
| from sgl_kernel.common_ops import rmsnorm_cpu as rmsnorm | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from sglang.test.test_utils import CustomTestCase | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| class TestNorm(CustomTestCase): | ||||||||||||||||||||||
| def _forward_native( | ||||||||||||||||||||||
| self, | ||||||||||||||||||||||
| x: torch.Tensor, | ||||||||||||||||||||||
| weight: torch.Tensor, | ||||||||||||||||||||||
| variance_epsilon: float = 1e-6, | ||||||||||||||||||||||
| residual: Optional[torch.Tensor] = None, | ||||||||||||||||||||||
| ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | ||||||||||||||||||||||
| orig_dtype = x.dtype | ||||||||||||||||||||||
| x = x.to(torch.float32) | ||||||||||||||||||||||
| if residual is not None: | ||||||||||||||||||||||
| x = x + residual.to(torch.float32) | ||||||||||||||||||||||
| residual = x.to(orig_dtype) | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| variance = x.pow(2).mean(dim=-1, keepdim=True) | ||||||||||||||||||||||
| x = x * torch.rsqrt(variance + variance_epsilon) | ||||||||||||||||||||||
| x = x.to(orig_dtype) * weight | ||||||||||||||||||||||
| if residual is None: | ||||||||||||||||||||||
| return x | ||||||||||||||||||||||
| else: | ||||||||||||||||||||||
| return x, residual | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def _run_single_test(self, shape, dtype, device="cuda"): | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| x = torch.randn(shape, dtype=dtype).to(device=device) | ||||||||||||||||||||||
| hidden_size = x.size(-1) | ||||||||||||||||||||||
| weight = torch.randn(hidden_size, dtype=dtype).to(device=device) | ||||||||||||||||||||||
|
||||||||||||||||||||||
| def _run_single_test(self, shape, dtype, device="cuda"): | |
| x = torch.randn(shape, dtype=dtype).to(device=device) | |
| hidden_size = x.size(-1) | |
| weight = torch.randn(hidden_size, dtype=dtype).to(device=device) | |
| def _run_single_test(self, shape, dtype): | |
| x = torch.randn(shape, dtype=dtype) | |
| hidden_size = x.size(-1) | |
| weight = torch.randn(hidden_size, dtype=dtype) |
mingfeima marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # TEST: fused_add_rmsnorm | |
| # flashinfer writes x and residual in place |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| residual = torch.randn(shape, dtype=dtype).to(device=device) | |
| ref_residual = residual.clone() | |
| residual = torch.randn(shape, dtype=dtype) | |
| ref_residual = residual.clone() |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use itertools.product and remove 'cpu'
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| import unittest | ||
|
|
||
| import torch | ||
| from sgl_kernel.common_ops import ( | ||
| rotary_position_embedding_cpu as rotary_position_embedding, | ||
| ) | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
| class TestROPE(CustomTestCase): | ||
| def _rotate_neox(self, x: torch.Tensor) -> torch.Tensor: | ||
| x1 = x[..., : x.shape[-1] // 2] | ||
| x2 = x[..., x.shape[-1] // 2 :] | ||
| return torch.cat((-x2, x1), dim=-1) | ||
|
|
||
| def _rotate_gptj(self, x: torch.Tensor) -> torch.Tensor: | ||
| x1 = x[..., ::2] | ||
| x2 = x[..., 1::2] | ||
| x = torch.stack((-x2, x1), dim=-1) | ||
| return x.flatten(-2) | ||
|
|
||
| def _forward_ref(self, positions, query, key, cos_sin_cache, offsets=None): | ||
| self.rotary_dim = 64 | ||
| self.head_size = 64 | ||
| self.is_neox_style = False | ||
|
||
| query_rot = query[..., : self.rotary_dim] | ||
| key_rot = key[..., : self.rotary_dim] | ||
| if self.rotary_dim < self.head_size: | ||
| query_pass = query[..., self.rotary_dim :] | ||
| key_pass = key[..., self.rotary_dim :] | ||
|
|
||
| cos_sin = cos_sin_cache[ | ||
| torch.add(positions, offsets) if offsets is not None else positions | ||
| ] | ||
| cos, sin = cos_sin.chunk(2, dim=-1) | ||
| if self.is_neox_style: | ||
| # shape [batch_size, seq_len]. | ||
| cos = cos.repeat(1, 1, 2).unsqueeze(-2) | ||
| sin = sin.repeat(1, 1, 2).unsqueeze(-2) | ||
| else: | ||
| cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) | ||
| sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) | ||
|
|
||
| rotate_fn = self._rotate_neox if self.is_neox_style else self._rotate_gptj | ||
| query_rot = query_rot * cos + rotate_fn(query_rot) * sin | ||
| key_rot = key_rot * cos + rotate_fn(key_rot) * sin | ||
|
|
||
| if self.rotary_dim < self.head_size: | ||
| query = torch.cat((query_rot, query_pass), dim=-1) | ||
| key = torch.cat((key_rot, key_pass), dim=-1) | ||
| else: | ||
| query = query_rot | ||
| key = key_rot | ||
| return query, key | ||
|
|
||
| def test_deepseek_v2_rope(self): | ||
| num_head = 16 | ||
| seq_len = 1024 | ||
| q_head_dim = 192 | ||
| qk_nope_head_dim = 128 | ||
| qk_rope_head_dim = 64 | ||
| max_pos = 256 | ||
| k_dim = 576 | ||
|
|
||
| # Create cos_sin_cache | ||
| freqs = torch.rand(max_pos, qk_rope_head_dim // 2) | ||
| cos = freqs.cos() * 0.7 | ||
| sin = freqs.sin() * 0.7 | ||
| cos_sin_cache = torch.cat((cos, sin), dim=-1).to(torch.bfloat16) | ||
| positions = torch.randint(0, max_pos, (seq_len,)) | ||
|
|
||
| for dtype in [torch.bfloat16]: | ||
| enable_autocast = True | ||
|
|
||
| with torch.no_grad(), torch.cpu.amp.autocast(enabled=enable_autocast): | ||
| q = torch.randn(seq_len, num_head, q_head_dim, dtype=dtype) | ||
| q_clone = q.clone() | ||
| k = torch.randn(seq_len, 1, k_dim, dtype=dtype) | ||
| k_clone = k.clone() | ||
| _, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) | ||
| _, q_pe_clone = q_clone.split( | ||
| [qk_nope_head_dim, qk_rope_head_dim], dim=-1 | ||
| ) | ||
| k_pe = k[:, :, k_dim - qk_rope_head_dim :] | ||
| k_pe_clone = k_clone[:, :, k_dim - qk_rope_head_dim :] | ||
|
|
||
| # ref kernel | ||
| q_pe, k_pe = self._forward_ref(positions, q_pe, k_pe, cos_sin_cache) | ||
|
|
||
| # fused rope kernel | ||
| q_pe_clone, k_pe_clone = rotary_position_embedding( | ||
| positions, q_pe_clone, k_pe_clone, cos_sin_cache | ||
| ) | ||
|
|
||
| torch.testing.assert_close(q_pe, q_pe_clone) | ||
| torch.testing.assert_close(k_pe, k_pe_clone) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
| import unittest | ||
|
|
||
| import torch | ||
| from sgl_kernel.common_ops import biased_grouped_topk_cpu as biased_grouped_topk | ||
| from sgl_kernel.common_ops import grouped_topk_cpu as grouped_topk | ||
|
|
||
| from sglang.test.test_utils import CustomTestCase | ||
|
|
||
|
|
||
| # This is used by the Deepseek-V2 model | ||
| class TestGroupedTopK(CustomTestCase): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we import that native impl from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/moe/topk.py |
||
| def _grouped_topk_native( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| gating_output: torch.Tensor, | ||
| topk: int, | ||
| renormalize: bool, | ||
| num_expert_group: int = 0, | ||
| topk_group: int = 0, | ||
| ): | ||
|
|
||
| assert ( | ||
| hidden_states.shape[0] == gating_output.shape[0] | ||
| ), "Number of tokens mismatch" | ||
|
|
||
| scores = torch.softmax(gating_output, dim=-1) | ||
| num_token = scores.shape[0] | ||
| group_scores = ( | ||
| scores.view(num_token, num_expert_group, -1).max(dim=-1).values | ||
| ) # [n, n_group] | ||
| group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ | ||
| 1 | ||
| ] # [n, top_k_group] | ||
|
|
||
| group_mask = torch.zeros_like(group_scores) # [n, n_group] | ||
| group_mask.scatter_(1, group_idx, 1) # [n, n_group] | ||
| score_mask = ( | ||
| group_mask.unsqueeze(-1) | ||
| .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) | ||
| .reshape(num_token, -1) | ||
| ) # [n, e] | ||
|
|
||
| tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] | ||
| topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) | ||
|
|
||
| if renormalize: | ||
| topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) | ||
|
|
||
| return topk_weights.to(torch.float32), topk_ids.to(torch.int32) | ||
|
|
||
| def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype): | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you may set a seed here as topk is unstable sort. |
||
| # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation | ||
| hidden_states = torch.randn(M, 100, dtype=dtype) | ||
| gating_output = torch.randn(M, E, dtype=dtype) * 2 * M | ||
|
|
||
| ref_topk_weights, ref_topk_ids = self._grouped_topk_native( | ||
| hidden_states.float(), | ||
| gating_output.float(), | ||
| topk, | ||
| renormalize, | ||
| G, | ||
| topk_group, | ||
| ) | ||
|
|
||
| # fused version | ||
| topk_weights, topk_ids = grouped_topk( | ||
| hidden_states, gating_output, topk, renormalize, G, topk_group | ||
| ) | ||
|
|
||
| res = torch.zeros(M, E, dtype=torch.float) | ||
| ref = torch.zeros(M, E, dtype=torch.float) | ||
| res.scatter_(1, topk_ids.long(), topk_weights) | ||
| ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) | ||
| torch.testing.assert_close(res, ref) | ||
|
|
||
| def test_grouped_topk(self): | ||
| for renormalize in [True, False]: | ||
| self._run_single_test(123, 8, 2, 2, 1, renormalize, torch.bfloat16) | ||
| self._run_single_test(123, 16, 4, 3, 2, renormalize, torch.bfloat16) | ||
| self._run_single_test(123, 32, 4, 3, 2, renormalize, torch.bfloat16) | ||
| self._run_single_test(1123, 32, 4, 3, 2, renormalize, torch.bfloat16) | ||
| self._run_single_test(123, 64, 1, 6, 1, renormalize, torch.bfloat16) | ||
| self._run_single_test(123, 256, 8, 4, 8, renormalize, torch.bfloat16) | ||
| self._run_single_test(123, 160, 8, 6, 2, renormalize, torch.bfloat16) | ||
|
|
||
|
|
||
| # DeepSeek V2/V3/R1 uses biased_grouped_topk | ||
| class TestBiasedGroupedTopK(CustomTestCase): | ||
| def _biased_grouped_topk( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| gating_output: torch.Tensor, | ||
| correction_bias: torch.Tensor, | ||
| topk: int, | ||
| renormalize: bool, | ||
| num_expert_group: int = 0, | ||
| topk_group: int = 0, | ||
| ): | ||
|
||
| assert ( | ||
| hidden_states.shape[0] == gating_output.shape[0] | ||
| ), "Number of tokens mismatch" | ||
|
|
||
| scores = gating_output.sigmoid() | ||
| num_token = scores.shape[0] | ||
| scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0) | ||
| group_scores = ( | ||
| scores_for_choice.view(num_token, num_expert_group, -1) | ||
| .topk(2, dim=-1)[0] | ||
| .sum(dim=-1) | ||
| ) # [n, n_group] | ||
| group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ | ||
| 1 | ||
| ] # [n, top_k_group] | ||
| group_mask = torch.zeros_like(group_scores) # [n, n_group] | ||
| group_mask.scatter_(1, group_idx, 1) # [n, n_group] | ||
| score_mask = ( | ||
| group_mask.unsqueeze(-1) | ||
| .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) | ||
| .reshape(num_token, -1) | ||
| ) # [n, e] | ||
| tmp_scores = scores_for_choice.masked_fill( | ||
| ~score_mask.bool(), float("-inf") | ||
| ) # [n, e] | ||
| _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) | ||
| topk_weights = scores.gather(1, topk_ids) | ||
|
|
||
| if renormalize: | ||
| topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) | ||
|
|
||
| return topk_weights.to(torch.float32), topk_ids.to(torch.int32) | ||
|
|
||
| def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype): | ||
|
|
||
| # scale gating_output by M, otherwise bfloat16 values collapse to the same value after truncation | ||
| hidden_states = torch.randn(M, 100, dtype=dtype) | ||
| gating_output = torch.randn(M, E, dtype=dtype) * 2 * M | ||
| correction_bias = torch.randn(E, dtype=dtype) | ||
|
|
||
| ref_topk_weights, ref_topk_ids = self._biased_grouped_topk( | ||
| hidden_states.float(), | ||
| gating_output.float(), | ||
| correction_bias.float(), | ||
| topk, | ||
| renormalize, | ||
| G, | ||
| topk_group, | ||
| ) | ||
|
|
||
| # fused version | ||
| topk_weights, topk_ids = biased_grouped_topk( | ||
| hidden_states, | ||
| gating_output, | ||
| correction_bias, | ||
| topk, | ||
| renormalize, | ||
| G, | ||
| topk_group, | ||
| ) | ||
|
|
||
| res = torch.zeros(M, E, dtype=torch.float) | ||
| ref = torch.zeros(M, E, dtype=torch.float) | ||
| res.scatter_(1, topk_ids.long(), topk_weights) | ||
| ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights) | ||
| torch.testing.assert_close(res, ref) | ||
|
|
||
| def test_biased_grouped_topk(self): | ||
| for renormalize in [True, False]: | ||
| self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ref #6404 (comment) @yanbing-j
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhyncs Thanks, has updated according to #6404 (comment).