
Update sgl-kernel UTs for activation/topk/norm/rope kernels #6452

Merged
zhyncs merged 13 commits into sgl-project:main from yanbing-j:yanbing/pointwise_ops
May 23, 2025

Conversation

@yanbing-j
Contributor

Motivation

This PR is a follow-up to #2807, adding UTs for activation/topk/norm/rope kernels.

Modifications

Checklist

Collaborator

@mingfeima mingfeima left a comment


Merge test_grouped_topk.py and test_biased_grouped_topk.py into one file, named test_topk.py. We only focus on DS V2 and V3 right now; later on we will need to add more types of topk kernels.
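A minimal sketch of the merged test_topk.py layout the comment asks for, using the two class names that appear in the final test run at the bottom of this thread; the test bodies are placeholders, not the PR's actual kernel checks.

```python
import unittest

class TestGroupedTopK(unittest.TestCase):
    """Grouped topk path used by the DeepSeek-V2 model."""

    def test_grouped_topk(self):
        # The real test would call the grouped_topk kernel and
        # compare against a reference implementation.
        pass


class TestBiasedGroupedTopK(unittest.TestCase):
    """Biased grouped topk path used by the DeepSeek-V3 model."""

    def test_biased_grouped_topk(self):
        # The real test would call the biased variant with a
        # correction_bias tensor.
        pass
```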

def _forward_ref(self, positions, query, key, cos_sin_cache, offsets=None):
    self.rotary_dim = 64
    self.head_size = 64
    self.is_neox_style = False
Collaborator


do we support is_neox in the C++ kernels?
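For context on what is_neox changes, here is a pure-Python analogue of the two rotation styles (mirroring the _rotate_neox/_rotate_gptj helpers quoted later in this thread; plain lists stand in for tensors, illustration only).

```python
# NeoX style rotates across the two halves of the vector;
# GPT-J style rotates across interleaved even/odd pairs.
def rotate_neox(x):
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1

def rotate_gptj(x):
    out = []
    for a, b in zip(x[::2], x[1::2]):
        out.extend([-b, a])
    return out

print(rotate_neox([1, 2, 3, 4]))  # [-3, -4, 1, 2]
print(rotate_gptj([1, 2, 3, 4]))  # [-2, 1, -4, 3]
```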

@mingfeima added the sgl-kernel, intel, and cpu (cpu backend performance optimization) labels May 21, 2025
Comment on lines +23 to +25
def test_activation(self):
    self._run_single_test([128, 22016], torch.bfloat16, "cpu")
    self._run_single_test([129, 22016], torch.float16, "cpu")
Collaborator


This test case won't take much time; suggest using itertools.product to cover more combinations.
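A sketch of that suggestion: enumerate the shape/dtype grid with itertools.product instead of hand-listing each case. The grid values below are illustrative stand-ins, not the PR's final parameters.

```python
import itertools

shapes = [[128, 22016], [129, 22016]]
dtypes = ["bfloat16", "float16"]  # stand-ins for torch dtypes

# Every (shape, dtype) pair from the 2x2 grid.
cases = list(itertools.product(shapes, dtypes))
for shape, dtype in cases:
    pass  # the real test would call self._run_single_test(shape, dtype)

print(len(cases))  # 4
```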

Comment on lines +33 to +37
def _run_single_test(self, shape, dtype, device="cuda"):
    x = torch.randn(shape, dtype=dtype).to(device=device)
    hidden_size = x.size(-1)
    weight = torch.randn(hidden_size, dtype=dtype).to(device=device)
Collaborator


Suggested change
-    def _run_single_test(self, shape, dtype, device="cuda"):
-        x = torch.randn(shape, dtype=dtype).to(device=device)
-        hidden_size = x.size(-1)
-        weight = torch.randn(hidden_size, dtype=dtype).to(device=device)
+    def _run_single_test(self, shape, dtype):
+        x = torch.randn(shape, dtype=dtype)
+        hidden_size = x.size(-1)
+        weight = torch.randn(hidden_size, dtype=dtype)

Comment on lines +46 to +47
# TEST: fused_add_rmsnorm
# flashinfer writes x and residual inplaced
Collaborator


Suggested change
-    # TEST: fused_add_rmsnorm
-    # flashinfer writes x and residual inplaced

Comment on lines +50 to +51
residual = torch.randn(shape, dtype=dtype).to(device=device)
ref_residual = residual.clone()
Collaborator


Suggested change
-    residual = torch.randn(shape, dtype=dtype).to(device=device)
-    ref_residual = residual.clone()
+    residual = torch.randn(shape, dtype=dtype)
+    ref_residual = residual.clone()

Comment on lines +62 to +65
def test_norm(self):
    self._run_single_test([4096, 4096], torch.bfloat16, "cpu")
    self._run_single_test([1024, 4096], torch.bfloat16, "cpu")
    self._run_single_test([1024, 4096 + 13], torch.float16, "cpu")
Collaborator


use itertools.product and remove 'cpu'

Comment on lines +24 to +26
self.rotary_dim = 64
self.head_size = 64
self.is_neox_style = False
Collaborator


Make these input parameters: rotary_dim, head_size, and is_neox_style.



# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
Comment on lines +90 to +99
def _biased_grouped_topk(
    self,
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
Comment on lines +14 to +38
def _forward_ref(
    self,
    positions,
    query,
    key,
    cos_sin_cache,
    rotary_dim,
    head_size,
    is_neox_style,
    offsets=None,
):
    query_rot = query[..., :rotary_dim]
    key_rot = key[..., :rotary_dim]
    if rotary_dim < head_size:
        query_pass = query[..., rotary_dim:]
        key_pass = key[..., rotary_dim:]

    cos_sin = cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions
    ]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if is_neox_style:
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
# This is used by the Deepseek-V2 model
class TestGroupedTopK(CustomTestCase):
    def _run_single_test(self, M, E, G, topk, topk_group, renormalize, dtype):

Collaborator


You may set a seed here, since topk is an unstable sort.
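To illustrate why: without a fixed seed the random inputs (and hence any tie-breaking in topk's unstable ordering) change between runs. Shown with the stdlib RNG; the torch test would call torch.manual_seed(...) to the same effect (exact placement is up to the author).

```python
import random

# Seeded draws are reproducible across runs, which keeps a test
# with nondeterministic tie-breaking stable.
def draw(seed, n=4):
    rng = random.Random(seed)
    return [rng.random() for _ in range(n)]

assert draw(0) == draw(0)  # same seed -> identical inputs
assert draw(0) != draw(1)  # different seed -> different inputs
```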

Comment on lines +151 to +163


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
Collaborator


remove this.

@yanbing-j yanbing-j requested a review from mingfeima May 23, 2025 05:54
@mingfeima mingfeima marked this pull request as ready for review May 23, 2025 06:21

import torch
import torch.nn.functional as F
from sgl_kernel.common_ops import silu_and_mul_cpu as silu_and_mul
Collaborator


Please change this


Contributor Author

@yanbing-j yanbing-j May 23, 2025


@zhyncs Thanks, this has been updated according to #6404 (comment).

@zhyncs
Collaborator

zhyncs commented May 23, 2025

setup_cpu.py has been removed in #6404

@yanbing-j yanbing-j requested a review from zhyncs May 23, 2025 08:23
@zhyncs zhyncs merged commit d818966 into sgl-project:main May 23, 2025
@zhyncs
Collaborator

zhyncs commented May 23, 2025

sgl-kernel git:(main) cp pyproject_cpu.toml pyproject.toml
uv build --wheel -Cbuild-dir=build . --color=always --no-build-isolation
pip3 install dist/sgl_kernel-0.1.4-cp310-cp310-linux_x86_64.whl --force-reinstall
pytest -v -srf ../test/srt/cpu/test_*
Building wheel...
*** scikit-build-core 0.11.3 using CMake 3.31.1 (wheel)
*** Configuring CMake...
loading initial cache file build/CMakeInit.txt
-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/x86_64-linux-gnu-gcc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/x86_64-linux-gnu-g++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Python: /usr/bin/python3 (found version "3.10.12") found components: Interpreter Development.Module
-- /usr/local/lib/python3.10/dist-packages/torch/share/cmake
-- Found CUDA: /usr/local/cuda (found version "12.4")
-- The CUDA compiler identification is NVIDIA 12.4.131 with host compiler GNU 11.4.0
-- Detecting CUDA compiler ABI info
-- Detecting CUDA compiler ABI info - done
-- Check for working CUDA compiler: /usr/local/cuda/bin/nvcc - skipped
-- Detecting CUDA compile features
-- Detecting CUDA compile features - done
-- Found CUDAToolkit: /usr/local/cuda/include (found version "12.4.131")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- PyTorch: CUDA detected: 12.4
-- PyTorch: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- PyTorch: CUDA toolkit directory: /usr/local/cuda
-- PyTorch: Header version is: 12.4
-- Found Python: /usr/bin/python3 (found version "3.10.12") found components: Interpreter
CMake Warning at /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/public/cuda.cmake:140 (message):
  Failed to compute shorthash for libnvrtc.so
Call Stack (most recent call first):
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include)
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package)
  CMakeLists.txt:19 (find_package)


CMake Warning (dev) at /usr/local/share/cmake-3.31/Modules/FindPackageHandleStandardArgs.cmake:441 (message):
  The package name passed to `find_package_handle_standard_args` (nvtx3) does
  not match the name of the calling package (Caffe2).  This can lead to
  problems in calling code that expects `find_package` result variables
  (e.g., `_FOUND`) to follow a certain pattern.
Call Stack (most recent call first):
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/public/cuda.cmake:178 (find_package_handle_standard_args)
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include)
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package)
  CMakeLists.txt:19 (find_package)
This warning is for project developers.  Use -Wno-dev to suppress it.

-- Could NOT find nvtx3 (missing: nvtx3_dir)
-- USE_CUDNN is set to 0. Compiling without cuDNN support
-- USE_CUSPARSELT is set to 0. Compiling without cuSPARSELt support
CMake Warning at /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/public/cuda.cmake:184 (message):
  Cannot find NVTX3, find old NVTX instead
Call Stack (most recent call first):
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Caffe2/Caffe2Config.cmake:86 (include)
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:68 (find_package)
  CMakeLists.txt:19 (find_package)


-- USE_CUDSS is set to 0. Compiling without cuDSS support
-- USE_CUFILE is set to 0. Compiling without cuFile support
-- Autodetected CUDA architecture(s):  9.0 9.0 9.0 9.0 9.0 9.0 9.0 9.0
-- Added CUDA NVCC flags for: -gencode;arch=compute_90,code=sm_90
CMake Warning at /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:22 (message):
  static library kineto_LIBRARY-NOTFOUND not found.
Call Stack (most recent call first):
  /usr/local/lib/python3.10/dist-packages/torch/share/cmake/Torch/TorchConfig.cmake:121 (append_torchlib_if_found)
  CMakeLists.txt:19 (find_package)


-- Found Torch: /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch.so
CMake Warning (dev) at /usr/local/share/cmake-3.31/Modules/FindPython/Support.cmake:4178 (cmake_parse_arguments):
  The USE_SABI keyword was followed by an empty string or no value at all.
  Policy CMP0174 is not set, so cmake_parse_arguments() will unset the
  PYTHON_ADD_LIBRARY_USE_SABI variable rather than setting it to an empty
  string.
Call Stack (most recent call first):
  /usr/local/share/cmake-3.31/Modules/FindPython.cmake:692 (__Python_add_library)
  CMakeLists.txt:50 (Python_add_library)
This warning is for project developers.  Use -Wno-dev to suppress it.

-- Configuring done (22.9s)
-- Generating done (0.0s)
-- Build files have been written to: /sgl-workspace/tmp/sglang/sgl-kernel/build
*** Building project with Ninja...
[1/18] Building CXX object CMakeFiles/common_ops.dir/activation.cpp.o
[2/18] Building CXX object CMakeFiles/common_ops.dir/bmm.cpp.o
[3/18] Building CXX object CMakeFiles/common_ops.dir/gemm_int8.cpp.o
[4/18] Building CXX object CMakeFiles/common_ops.dir/rope.cpp.o
[5/18] Building CXX object CMakeFiles/common_ops.dir/moe_fp8.cpp.o
[6/18] Building CXX object CMakeFiles/common_ops.dir/norm.cpp.o
[7/18] Building CXX object CMakeFiles/common_ops.dir/moe_int8.cpp.o
[8/18] Building CXX object CMakeFiles/common_ops.dir/gemm.cpp.o
[9/18] Building CXX object CMakeFiles/common_ops.dir/qkv_proj.cpp.o
[10/18] Building CXX object CMakeFiles/common_ops.dir/gemm_fp8.cpp.o
[11/18] Building CXX object CMakeFiles/common_ops.dir/extend.cpp.o
[12/18] Building CXX object CMakeFiles/common_ops.dir/moe.cpp.o
[13/18] Building CXX object CMakeFiles/common_ops.dir/shm.cpp.o
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp: In function ‘void shm_initialize(int, int, char*, char*)’:
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp:412:41: warning: ‘%d’ directive output may be truncated writing between 1 and 11 bytes into a region of size between 0 and 999 [-Wformat-truncation=]
  412 |   snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank);
      |                                         ^~
In file included from /usr/include/stdio.h:894,
                 from /usr/include/c++/11/cstdio:42,
                 from /usr/include/c++/11/ext/string_conversions.h:43,
                 from /usr/include/c++/11/bits/basic_string.h:6608,
                 from /usr/include/c++/11/string:55,
                 from /usr/include/c++/11/bits/locale_classes.h:40,
                 from /usr/include/c++/11/bits/ios_base.h:41,
                 from /usr/include/c++/11/ios:42,
                 from /usr/include/c++/11/ostream:38,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/c10/core/DeviceType.h:13,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/c10/core/Device.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/core/TensorBody.h:11,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/core/Tensor.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/Tensor.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/variable.h:6,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/autograd.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.h:1,
                 from /sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp:1:
/usr/include/x86_64-linux-gnu/bits/stdio2.h:71:35: note: ‘__builtin___snprintf_chk’ output between 3 and 1012 bytes into a destination of size 1000
   71 |   return __builtin___snprintf_chk (__s, __n, __USE_FORTIFY_LEVEL - 1,
      |          ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   72 |                                    __glibc_objsize (__s), __fmt,
      |                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   73 |                                    __va_arg_pack ());
      |                                    ~~~~~~~~~~~~~~~~~
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp:428:45: warning: ‘%d’ directive output may be truncated writing between 1 and 10 bytes into a region of size between 0 and 999 [-Wformat-truncation=]
  428 |       snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i);
      |                                             ^~
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp:428:41: note: directive argument in the range [0, 2147483646]
  428 |       snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i);
      |                                         ^~~~~~~
In file included from /usr/include/stdio.h:894,
                 from /usr/include/c++/11/cstdio:42,
                 from /usr/include/c++/11/ext/string_conversions.h:43,
                 from /usr/include/c++/11/bits/basic_string.h:6608,
                 from /usr/include/c++/11/string:55,
                 from /usr/include/c++/11/bits/locale_classes.h:40,
                 from /usr/include/c++/11/bits/ios_base.h:41,
                 from /usr/include/c++/11/ios:42,
                 from /usr/include/c++/11/ostream:38,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/c10/core/DeviceType.h:13,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/c10/core/Device.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/core/TensorBody.h:11,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/core/Tensor.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/ATen/Tensor.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/variable.h:6,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/autograd.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
                 from /usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.h:1,
                 from /sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/shm.cpp:1:
/usr/include/x86_64-linux-gnu/bits/stdio2.h:71:35: note: ‘__builtin___snprintf_chk’ output between 3 and 1011 bytes into a destination of size 1000
   71 |   return __builtin___snprintf_chk (__s, __n, __USE_FORTIFY_LEVEL - 1,
      |          ~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   72 |                                    __glibc_objsize (__s), __fmt,
      |                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   73 |                                    __va_arg_pack ());
      |                                    ~~~~~~~~~~~~~~~~~
[14/18] Building CXX object CMakeFiles/common_ops.dir/topk.cpp.o
[15/18] Building CXX object CMakeFiles/common_ops.dir/interface.cpp.o
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/interface.cpp: In function ‘void initialize(int64_t, int64_t)’:
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/interface.cpp:38:19: warning: ISO C++ forbids converting a string constant to ‘char*’ [-Wwrite-strings]
   38 |     addr_string = "";
      |                   ^~
/sgl-workspace/tmp/sglang/sgl-kernel/csrc/cpu/interface.cpp:42:19: warning: ISO C++ forbids converting a string constant to ‘char*’ [-Wwrite-strings]
   42 |     port_string = "";
      |                   ^~
[16/18] Building CXX object CMakeFiles/common_ops.dir/decode.cpp.o
[17/18] Building CXX object CMakeFiles/common_ops.dir/torch_extension_cpu.cpp.o
[18/18] Linking CXX shared module common_ops.cpython-310-x86_64-linux-gnu.so
*** Installing project into wheel...
-- Install configuration: "Release"
-- Installing: /tmp/tmp_kntgk32/wheel/platlib/sgl_kernel/common_ops.cpython-310-x86_64-linux-gnu.so
-- Set non-toolchain portion of runtime path of "/tmp/tmp_kntgk32/wheel/platlib/sgl_kernel/common_ops.cpython-310-x86_64-linux-gnu.so" to ""
*** Making wheel...
*** Created sgl_kernel-0.1.4-cp310-cp310-linux_x86_64.whl
Successfully built dist/sgl_kernel-0.1.4-cp310-cp310-linux_x86_64.whl
Processing ./dist/sgl_kernel-0.1.4-cp310-cp310-linux_x86_64.whl
Installing collected packages: sgl-kernel
  Attempting uninstall: sgl-kernel
    Found existing installation: sgl-kernel 0.1.2.post1
    Uninstalling sgl-kernel-0.1.2.post1:
      Successfully uninstalled sgl-kernel-0.1.2.post1
Successfully installed sgl-kernel-0.1.4
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
=================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.12, pytest-8.3.5, pluggy-1.6.0 -- /usr/bin/python3
cachedir: .pytest_cache
rootdir: /sgl-workspace/tmp/sglang
plugins: anyio-4.9.0, typeguard-4.4.2
collected 19 items

../test/srt/cpu/test_activation.py::TestActivation::test_activation PASSED
../test/srt/cpu/test_decode.py::TestDecodeAttention::test_grouped_decode_attention PASSED
../test/srt/cpu/test_extend.py::TestExtendAttention::test_extend_attention PASSED
../test/srt/cpu/test_gemm.py::TestGemm::test_bf16_gemm PASSED
../test/srt/cpu/test_gemm.py::TestGemm::test_fp8_gemm PASSED
../test/srt/cpu/test_gemm.py::TestGemm::test_int8_gemm PASSED
../test/srt/cpu/test_moe.py::TestFusedExperts::test_bf16_moe PASSED
../test/srt/cpu/test_moe.py::TestFusedExperts::test_fp8_moe PASSED
../test/srt/cpu/test_moe.py::TestFusedExperts::test_int8_moe PASSED
../test/srt/cpu/test_norm.py::TestNorm::test_norm PASSED
../test/srt/cpu/test_qkv_proj_with_rope.py::TestQKVProjWithROPE::test_bf16_qkv_proj_with_rope PASSED
../test/srt/cpu/test_qkv_proj_with_rope.py::TestQKVProjWithROPE::test_fp8_qkv_proj_with_rope PASSED
../test/srt/cpu/test_qkv_proj_with_rope.py::TestQKVProjWithROPE::test_int8_qkv_proj_with_rope PASSED
../test/srt/cpu/test_rope.py::TestROPE::test_deepseek_v2_rope PASSED
../test/srt/cpu/test_shared_expert.py::TestSharedExpert::test_bf16_shared_expert PASSED
../test/srt/cpu/test_shared_expert.py::TestSharedExpert::test_fp8_shared_expert PASSED
../test/srt/cpu/test_shared_expert.py::TestSharedExpert::test_int8_shared_expert PASSED
../test/srt/cpu/test_topk.py::TestGroupedTopK::test_grouped_topk PASSED
../test/srt/cpu/test_topk.py::TestBiasedGroupedTopK::test_biased_grouped_topk PASSED

=================================================================================== 19 passed in 55.36s ===================================================================================
➜  sgl-kernel git:(main) ✗ lscpu
Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          52 bits physical, 57 bits virtual
  Byte Order:             Little Endian
CPU(s):                   224
  On-line CPU(s) list:    0-223
Vendor ID:                GenuineIntel
  Model name:             Intel(R) Xeon(R) Platinum 8480C
    CPU family:           6
    Model:                143
    Thread(s) per core:   2
    Core(s) per socket:   56
    Socket(s):            2
    Stepping:             8
    CPU max MHz:          3800.0000
    CPU min MHz:          800.0000
    BogoMIPS:             4000.00
    Flags:                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc
                           art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma
                           cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 c
                          dp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi
                          2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cq
                          m_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req av
                          x512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movd
                          ir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization features:
  Virtualization:         VT-x
Caches (sum of all):
  L1d:                    5.3 MiB (112 instances)
  L1i:                    3.5 MiB (112 instances)
  L2:                     224 MiB (112 instances)
  L3:                     210 MiB (2 instances)
NUMA:
  NUMA node(s):           2
  NUMA node0 CPU(s):      0-55,112-167
  NUMA node1 CPU(s):      56-111,168-223
Vulnerabilities:
  Gather data sampling:   Not affected
  Itlb multihit:          Not affected
  L1tf:                   Not affected
  Mds:                    Not affected
  Meltdown:               Not affected
  Mmio stale data:        Not affected
  Reg file data sampling: Not affected
  Retbleed:               Not affected
  Spec rstack overflow:   Not affected
  Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
  Srbds:                  Not affected
  Tsx async abort:        Not affected

Layssy pushed a commit to Layssy/sglang-iaas that referenced this pull request Jun 9, 2025
xwu-intel pushed a commit to xwu-intel/sglang that referenced this pull request Jun 17, 2025
