Skip to content

Commit b2a189d

Browse files
strgrb and Zhang Kaihong authored
use sglang_per_token_group_quant_fp8 from sgl-kernel instead of triton kernel (#5473)
Co-authored-by: Zhang Kaihong <[email protected]>
1 parent f28d829 commit b2a189d

2 files changed

Lines changed: 25 additions & 6 deletions

File tree

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
def sglang_per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
):
    """Quantize `x` to FP8 per token-group using the sgl-kernel CUDA kernel.

    Each contiguous group of `group_size` elements along the last dimension
    of `x` shares one float32 scale.

    Args:
        x: Input tensor; its last dimension must be divisible by `group_size`
            and the tensor must be contiguous.
        group_size: Number of elements along the last dimension that share
            one quantization scale.
        eps: Lower bound applied to the per-group max to avoid division by
            zero when computing scales (consumed by the kernel).
        column_major_scales: If True, allocate the scale tensor so that it is
            column-major over the last two dimensions (scales-per-group
            leading), as expected by some block-scale GEMM backends.
        scale_tma_aligned: Only meaningful with `column_major_scales`; pads
            the row dimension of the scale buffer up to a multiple of 4
            (i.e. 4 * sizeof(float) = 16 bytes — presumably for TMA
            alignment; confirm against the consuming kernel).

    Returns:
        Tuple of (x_q, x_s): the FP8-quantized tensor (same shape as `x`)
        and the float32 per-group scales.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
    if column_major_scales:
        if scale_tma_aligned:
            # Pad the row count up to a multiple of 4 so each column of the
            # scale buffer is aligned to 4 * sizeof(float) = 16 bytes.
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            # Allocate with the group axis leading, then permute and slice
            # back to the logical (rows, groups) view over the padded storage.
            x_s = torch.empty(
                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)[: x.shape[-2], :]
        else:
            # Column-major over the last two dims: storage is
            # (groups, *leading) but the returned view is (*leading, groups).
            x_s = torch.empty(
                (x.shape[-1] // group_size,) + x.shape[:-1],
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)
    else:
        # Row-major scales: one scale per group appended after the leading dims.
        x_s = torch.empty(
            x.shape[:-1] + (x.shape[-1] // group_size,),
            device=x.device,
            dtype=torch.float32,
        )
    # sgl-kernel CUDA op fills x_q and x_s in place.
    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s

python/sglang/srt/layers/quantization/fp8_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def apply_w8a8_block_fp8_linear(
141141
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
142142
else:
143143
if _enable_jit_deepgemm:
144-
q_input, x_scale = per_token_group_quant_fp8(
144+
q_input, x_scale = sglang_per_token_group_quant_fp8(
145145
input_2d,
146146
block_size[1],
147147
column_major_scales=True,

0 commit comments

Comments
 (0)