
Commit c8649b6

chunyuan-w authored and chenxijun1029 committed

[CPU] refine CPU integration code (sgl-project#7647)

1 parent d67fdef

9 files changed: 141 additions & 116 deletions

python/sglang/srt/layers/amx_utils.py (new file)

Lines changed: 86 additions & 0 deletions

```python
import logging

import torch

from sglang.srt.utils import cpu_has_amx_support

logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
    if weight.device != torch.device("cpu"):
        return weight
    if not cpu_has_amx_support():
        return weight

    return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently the gemm kernel has the below requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    TILE_N = 16
    TILE_K = 32
    ndim = weight.ndim
    OC = weight.size(1) if ndim == 3 else weight.size(0)
    IC = weight.size(2) if ndim == 3 else weight.size(1)
    return OC % TILE_N == 0 and IC % TILE_K == 0


def _amx_process_weight_after_loading(
    module, weight_names, transpose_dims=None
) -> None:
    # Pack weights to get better performance on CPU
    devices = {getattr(module, weight_name).device for weight_name in weight_names}
    assert len(devices) == 1, "Expects all weights to be on the same device"
    device = devices.pop()

    if transpose_dims:
        assert len(weight_names) == len(
            transpose_dims
        ), "len(weight_names) should be equal to len(transpose_dims)"

    for i, weight_name in enumerate(weight_names):
        weight_tensor = getattr(module, weight_name)

        if transpose_dims and transpose_dims[i]:
            weight_tensor = weight_tensor.transpose(*transpose_dims[i])

        # We don't pack the weight or use the Intel AMX backend if any weight of this module has an unsupported dim.
        if not dim_is_supported(weight_tensor):
            logger.warning(
                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
                f"The derived (OC, IC) dimensions must be divisible by (16, 32)."
            )
            module.use_intel_amx_backend = False
            return

        packed_weight = torch.nn.Parameter(
            amx_process_weight_after_loading(weight_tensor),
            requires_grad=False,
        )
        packed_weight.__dict__ = weight_tensor.__dict__
        setattr(module, weight_name, packed_weight)

    module.use_intel_amx_backend = (
        device == torch.device("cpu") and cpu_has_amx_support()
    )

    if (
        module.use_intel_amx_backend
        and hasattr(module, "bias")
        and module.bias is not None
    ):
        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)


class PackWeightMethod:
    def __init__(self, weight_names, transpose_dims=None):
        self.weight_names = weight_names
        self.transpose_dims = transpose_dims

    def process_weights_after_loading(self, module) -> None:
        _amx_process_weight_after_loading(
            module, self.weight_names, self.transpose_dims
        )
```
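The new module gives the CPU layers one place to pack weights into the AMX-friendly layout and to record whether the fast path is usable. Below is a minimal usage sketch, not part of the commit: the `ToyLinear` module and its shapes are hypothetical, and it assumes sglang plus sgl-kernel are installed on the target machine; only `_amx_process_weight_after_loading` and `dim_is_supported` come from the file above.

```python
import torch

from sglang.srt.layers.amx_utils import (
    _amx_process_weight_after_loading,
    dim_is_supported,
)


class ToyLinear(torch.nn.Module):
    def __init__(self, in_features=64, out_features=32):
        super().__init__()
        # (OC, IC) = (32, 64): 32 % TILE_N == 0 and 64 % TILE_K == 0,
        # so dim_is_supported() accepts this weight.
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features), requires_grad=False
        )
        self.bias = torch.nn.Parameter(torch.randn(out_features), requires_grad=False)


layer = ToyLinear()
assert dim_is_supported(layer.weight)
_amx_process_weight_after_loading(layer, ["weight"])

# True only on a CPU with AMX support; when True, layer.weight has been
# repacked for torch.ops.sgl_kernel.weight_packed_linear and layer.bias
# has been upcast to fp32.
print(layer.use_intel_amx_backend)
```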

python/sglang/srt/layers/linear.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -17,6 +17,7 @@
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ def create_weights(

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])

     def apply(
         self,
@@ -184,7 +185,7 @@ def apply(
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
```
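Note that `use_intel_amx_backend` is imported from `sglang.srt.utils`, but its definition is not shown in this page's diffs. Judging from the `getattr(layer, "use_intel_amx_backend", False)` expression it replaces at every call site, it is presumably a one-line wrapper along these lines (a sketch, not the confirmed implementation):

```python
def use_intel_amx_backend(layer) -> bool:
    # Presumed equivalent of the inline checks this commit removes: the
    # attribute is set by _amx_process_weight_after_loading and defaults
    # to False when the weights were never packed.
    return getattr(layer, "use_intel_amx_backend", False)
```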

python/sglang/srt/layers/logits_processor.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -42,7 +42,7 @@
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend

 logger = logging.getLogger(__name__)

@@ -442,7 +442,7 @@ def _get_logits(
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
```

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -12,19 +12,20 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 if torch.cuda.is_available():
@@ -129,7 +130,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

         return

@@ -264,10 +265,7 @@ def forward_cpu(
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
```
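The MoE path packs two expert weights at once; these are 3-D tensors, so `dim_is_supported` reads OC and IC from dims 1 and 2. A hedged sketch of what `_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])` does to a toy MoE module (the `ToyMoE` module and its shapes are hypothetical):

```python
import torch

from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading


class ToyMoE(torch.nn.Module):
    def __init__(self, num_experts=8, hidden=64, intermediate=128):
        super().__init__()
        # 3-D expert weights: (num_experts, OC, IC) with OC % 16 == 0
        # and IC % 32 == 0, as the gemm kernel requires.
        self.w13_weight = torch.nn.Parameter(
            torch.randn(num_experts, 2 * intermediate, hidden), requires_grad=False
        )
        self.w2_weight = torch.nn.Parameter(
            torch.randn(num_experts, hidden, intermediate), requires_grad=False
        )


layer = ToyMoE()
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
# Both weights are now packed (or left as-is off-CPU/off-AMX); the flag
# below is what use_intel_amx_backend(layer) checks in forward_cpu.
print(layer.use_intel_amx_backend)
```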

python/sglang/srt/layers/quantization/fp8.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -27,6 +27,7 @@ def dummy_func(*args, **kwargs):


 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -64,7 +65,6 @@ def dummy_func(*args, **kwargs):
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
@@ -74,6 +74,7 @@ def dummy_func(*args, **kwargs):
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -335,7 +336,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             assert (
                 _is_cpu_amx_available
             ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
             return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -433,7 +434,7 @@ def apply(
         )

         if self.block_quant:
-            if getattr(layer, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(layer):
                 return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                     x,
                     layer.weight,
@@ -769,7 +770,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             assert (
                 _is_cpu_amx_available
             ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

             return

@@ -996,7 +997,7 @@ def apply(
             routed_scaling_factor=routed_scaling_factor,
         )

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
```

python/sglang/srt/layers/quantization/w8a8_int8.py

Lines changed: 6 additions & 5 deletions
```diff
@@ -4,6 +4,7 @@
 from torch.nn.parameter import Parameter

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -12,11 +13,11 @@
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_cuda = is_cuda()
@@ -84,7 +85,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
             return

         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
@@ -127,7 +128,7 @@ def apply(
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
                 x,
                 layer.weight,
@@ -235,7 +236,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
             return

         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
@@ -284,7 +285,7 @@ def apply(
             routed_scaling_factor=routed_scaling_factor,
         )

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
```

python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -13,19 +13,15 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
```
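With `PackWeightMethod` now living in `amx_utils`, a module can opt into weight packing by assigning it as its quant method, so that `process_weights_after_loading` runs after checkpoint load. A hypothetical sketch of that wiring (the `ToyHead` module is invented for illustration; only `PackWeightMethod` comes from the commit):

```python
import torch

from sglang.srt.layers.amx_utils import PackWeightMethod


class ToyHead(torch.nn.Module):
    def __init__(self, vocab_size=128, hidden=64):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.randn(vocab_size, hidden), requires_grad=False
        )
        # Pack this module's "weight" after loading; PackWeightMethod simply
        # forwards to _amx_process_weight_after_loading.
        self.quant_method = PackWeightMethod(weight_names=["weight"])


head = ToyHead()
head.quant_method.process_weights_after_loading(head)
print(head.use_intel_amx_backend)  # True only on an AMX-capable CPU
```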
