Skip to content
22 changes: 15 additions & 7 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,15 +885,23 @@ def __init__(
)

self.weight_block_size = None
if self.qkv_proj_with_rope_is_fp8:
assert (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
== self.q_b_proj.quant_method.quant_config.weight_block_size
)
self.weight_block_size = (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
assert getattr(
self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
) == getattr(self.q_b_proj.quant_method, "block_quant", False)
use_block_quant = getattr(
self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
)

if use_block_quant:
assert (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
== self.q_b_proj.quant_method.quant_config.weight_block_size
)
self.weight_block_size = (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
)
Comment on lines +932 to +946
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and avoid redundant calls, store the result of the first getattr call in a variable and reuse it for both the assertion and the if condition. This makes the logic slightly cleaner and more efficient.

Suggested change
assert getattr(
self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
) == getattr(self.q_b_proj.quant_method, "block_quant", False)
use_block_quant = getattr(
self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
)
if use_block_quant:
assert (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
== self.q_b_proj.quant_method.quant_config.weight_block_size
)
self.weight_block_size = (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
)
use_block_quant = getattr(
self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
)
assert use_block_quant == getattr(
self.q_b_proj.quant_method, "block_quant", False
), "block_quant setting must be consistent for fused_qkv_a_proj_with_mqa and q_b_proj"
if use_block_quant:
assert (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
== self.q_b_proj.quant_method.quant_config.weight_block_size
)
self.weight_block_size = (
self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
)


def dispatch_attn_forward_method(
self, forward_batch: ForwardBatch
) -> AttnForwardMethod:
Expand Down
15 changes: 10 additions & 5 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,12 @@ def prepack_weight_if_needed(weight):
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    """Return True if *weight*'s channel dims meet the AMX tiling constraints.

    The prepack path requires OC % TILE_N == 0 (TILE_N = 16) and
    IC % TILE_K == 0 (TILE_K = 32). For a 3-D weight the (OC, IC) pair is
    taken from dims (1, 2); otherwise from dims (0, 1).
    """
    if weight.ndim == 3:
        oc, ic = weight.size(1), weight.size(2)
    else:
        oc, ic = weight.size(0), weight.size(1)
    # TILE_N = 16, TILE_K = 32 per the AMX prepack requirements noted above.
    return oc % 16 == 0 and ic % 32 == 0


def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
Expand All @@ -2445,19 +2450,19 @@ def _process_weight_after_loading(module, weight_names, transpose_dims=None) ->
for i, weight_name in enumerate(weight_names):
weight_tensor = getattr(module, weight_name)

if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])

# We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
if not dim_is_supported(weight_tensor):
logger.warning(
f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
f"but {weight_name} {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
f"{module} won't use intel amx backend."
)
module.use_intel_amx_backend = False
return

if transpose_dims and transpose_dims[i]:
weight_tensor = weight_tensor.transpose(*transpose_dims[i])

packed_weight = torch.nn.Parameter(
prepack_weight_if_needed(weight_tensor),
requires_grad=False,
Expand Down