diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 82cc623f9364..9ad97f2fc419 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -671,9 +671,9 @@ def forward( q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # [b, s, embed_dim] --> [b * s, head, head_size] - q = q.reshape(bsz * s, head, -1).contiguous() - k = k.reshape(bsz * s, kv_head, -1).contiguous() - v = v.reshape(bsz * s, kv_head, -1).contiguous() + q = q.reshape(bsz * s, head, -1) + k = k.reshape(bsz * s, kv_head, -1) + v = v.reshape(bsz * s, kv_head, -1) else: # [b, s, embed_dim] --> [s, b, embed_dim] x = rearrange(x, "b s ... -> s b ...") @@ -692,7 +692,7 @@ def forward( # [s, b, head, head_size] --> [b, s, head, head_size] q, k, v = [ - rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) + rearrange(x, "s b ... -> b s ...") for x in (q, k, v) ] if position_embeddings is not None: diff --git a/sgl-kernel/csrc/cpu/rope.cpp b/sgl-kernel/csrc/cpu/rope.cpp index ae92f3ce897c..1a4049b50b3d 100644 --- a/sgl-kernel/csrc/cpu/rope.cpp +++ b/sgl-kernel/csrc/cpu/rope.cpp @@ -622,8 +622,8 @@ std::tuple rotary_embedding_cpu( std::tuple apply_rotary_pos_emb_cpu(at::Tensor& query, at::Tensor& key, at::Tensor& cos, at::Tensor& sin) { RECORD_FUNCTION("sgl-kernel::apply_rotary_pos_emb_cpu", std::vector({query, key})); - CHECK_INPUT(query); - CHECK_INPUT(key); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(query); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(key); CHECK_INPUT(cos); CHECK_INPUT(sin); CHECK_DIM(3, query); diff --git a/test/srt/cpu/test_rope.py b/test/srt/cpu/test_rope.py index 02787cfe077f..b43d42528aae 100644 --- a/test/srt/cpu/test_rope.py +++ b/test/srt/cpu/test_rope.py @@ -253,15 +253,15 @@ def test_apply_rotary_pos_emb(self): num_tokens = 1024 num_heads = 8 head_size = 72 - query = torch.randn(num_tokens, num_heads, head_size).to(torch.bfloat16) - key = torch.randn(num_tokens, num_heads, head_size).to(torch.bfloat16) + qkv = torch.randn(num_tokens, num_heads * head_size * 3).to(torch.bfloat16) + query, key, _ = qkv.split([num_heads * head_size, num_heads * head_size, num_heads * head_size], dim=-1) + query = query.view(num_tokens, num_heads, head_size) + key = key.view(num_tokens, num_heads, head_size) cos = torch.rand(num_tokens, head_size).to(torch.float32) sin = torch.rand(num_tokens, head_size).to(torch.float32) - query_clone = query.clone() - key_clone = key.clone() q_out_ref, k_out_ref = apply_rotary_pos_emb_native(query, key, cos, sin) q_out_sgl, k_out_sgl = torch.ops.sgl_kernel.apply_rotary_pos_emb_cpu( - query_clone, key_clone, cos, sin + query, key, cos, sin ) torch.testing.assert_close(q_out_ref, q_out_sgl, atol=1e-2, rtol=1e-2) torch.testing.assert_close(k_out_ref, k_out_sgl, atol=1e-2, rtol=1e-2)