Skip to content

Commit d2e507d

Browse files
authored
[Misc] clean up vllm in sgl-kernel test (#5189)
1 parent 61970b0 commit d2e507d

4 files changed

Lines changed: 25 additions & 40 deletions

File tree

sgl-kernel/tests/test_awq_dequant.py

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import pytest
55
import torch
66
from sgl_kernel import awq_dequantize
7-
from vllm import _custom_ops as ops
87

98

109
def reverse_awq_order(t: torch.Tensor):
@@ -58,12 +57,6 @@ def awq_dequantize_torch(
5857
return (iweights - zeros) * scales
5958

6059

61-
def vllm_awq_dequantize(
62-
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
63-
) -> torch.Tensor:
64-
return ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
65-
66-
6760
def sglang_awq_dequantize(
6861
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
6962
) -> torch.Tensor:
@@ -110,21 +103,13 @@ def test_awq_dequant_compare_implementations(
110103
)
111104

112105
# Run both implementations
113-
vllm_out = vllm_awq_dequantize(qweight, scales.to(torch.float16), qzeros)
114106
torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
115107
sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)
116108

117109
# Compare results
118110
torch.testing.assert_close(
119111
torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
120112
)
121-
if not is_bf16_act:
122-
torch.testing.assert_close(
123-
vllm_out.to(torch.float32),
124-
sglang_out.to(torch.float32),
125-
rtol=1e-3,
126-
atol=1e-5,
127-
)
128113

129114

130115
if __name__ == "__main__":

sgl-kernel/tests/test_int8_gemm.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
import pytest
22
import torch
33
from sgl_kernel import int8_scaled_mm
4-
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
54

65

76
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -28,9 +27,7 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
2827
bias = None
2928
o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
3029
o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
31-
o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
3230
torch.testing.assert_close(o, o1)
33-
torch.testing.assert_close(o, o2)
3431
print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")
3532

3633

sgl-kernel/tests/test_per_tensor_quant_fp8.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,21 +4,13 @@
44
import pytest
55
import torch
66
from sgl_kernel import sgl_per_tensor_quant_fp8
7-
from vllm import _custom_ops as ops
87

98
from sglang.srt.utils import is_hip
109

1110
is_hip_ = is_hip()
1211
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
1312

1413

15-
def vllm_scaled_fp8_quant(
16-
input: torch.Tensor,
17-
scale: Optional[torch.Tensor] = None,
18-
) -> Tuple[torch.Tensor, torch.Tensor]:
19-
return ops.scaled_fp8_quant(input, scale)
20-
21-
2214
def sglang_scaled_fp8_quant(
2315
input: torch.Tensor,
2416
scale: Optional[torch.Tensor] = None,
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
3426
return output, scale
3527

3628

29+
def torch_scaled_fp8_quant(tensor, inv_scale):
30+
# The reference implementation that fully aligns to
31+
# the kernel being tested.
32+
finfo = torch.finfo(torch.float8_e4m3fn)
33+
scale = inv_scale.reciprocal()
34+
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
35+
qweight = qweight.to(torch.float8_e4m3fn)
36+
return qweight
37+
38+
3739
@pytest.mark.parametrize(
3840
"num_tokens,hidden_dim",
3941
list(itertools.product([128, 256, 512], [512, 2048, 4096])),
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
4547
device = torch.device("cuda")
4648
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
4749

48-
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
4950
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)
51+
torch_out = torch_scaled_fp8_quant(x, sglang_scale)
5052

51-
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
5253
torch.testing.assert_close(
53-
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
54+
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
5455
)
5556

5657
scale = torch.rand(1, dtype=torch.float32, device=device)
57-
vllm_out, vllm_scale = vllm_scaled_fp8_quant(x, scale)
5858
sglang_out, sglang_scale = sglang_scaled_fp8_quant(x, scale)
59+
torch_out = torch_scaled_fp8_quant(x, scale)
5960

60-
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
6161
torch.testing.assert_close(
62-
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
62+
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
6363
)
6464

6565

sgl-kernel/tests/test_per_token_quant_fp8.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4,18 +4,22 @@
44
import pytest
55
import torch
66
from sgl_kernel import sgl_per_token_quant_fp8
7-
from vllm import _custom_ops as ops
87

98
from sglang.srt.utils import is_hip
109

1110
is_hip_ = is_hip()
1211
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
1312

1413

15-
def vllm_per_token_quant_fp8(
16-
input: torch.Tensor,
17-
) -> Tuple[torch.Tensor, torch.Tensor]:
18-
return ops.scaled_fp8_quant(input, use_per_token_if_dynamic=True)
14+
def torch_per_token_quant_fp8(tensor, inv_scale):
15+
# The reference implementation that fully aligns to
16+
# the kernel being tested.
17+
finfo = torch.finfo(torch.float8_e4m3fn)
18+
inv_scale = inv_scale.view(-1, 1)
19+
scale = inv_scale.reciprocal()
20+
qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
21+
qweight = qweight.to(torch.float8_e4m3fn)
22+
return qweight
1923

2024

2125
def sglang_per_token_quant_fp8(
@@ -41,12 +45,11 @@ def test_per_token_quant_compare_implementations(
4145
device = torch.device("cuda")
4246
x = torch.rand((num_tokens, hidden_dim), dtype=torch.float16, device=device)
4347

44-
vllm_out, vllm_scale = vllm_per_token_quant_fp8(x)
4548
sglang_out, sglang_scale = sglang_per_token_quant_fp8(x)
49+
torch_out = torch_per_token_quant_fp8(x, sglang_scale)
4650

47-
torch.testing.assert_close(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-3)
4851
torch.testing.assert_close(
49-
vllm_out.float(), sglang_out.float(), rtol=1e-3, atol=1e-3
52+
sglang_out.float(), torch_out.float(), rtol=1e-3, atol=1e-3
5053
)
5154

5255

0 commit comments

Comments (0)