44import pytest
55import torch
66from sgl_kernel import sgl_per_tensor_quant_fp8
7- from vllm import _custom_ops as ops
87
98from sglang .srt .utils import is_hip
109
1110is_hip_ = is_hip ()
1211fp8_type_ = torch .float8_e4m3fnuz if is_hip_ else torch .float8_e4m3fn
1312
1413
def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` to FP8 through vLLM's custom op.

    Thin wrapper over ``ops.scaled_fp8_quant``.  When ``scale`` is ``None``
    the op derives a per-tensor scale itself; otherwise the supplied scale
    is used.  Returns ``(quantized_tensor, scale)``.
    """
    result = ops.scaled_fp8_quant(input, scale)
    return result
2214def sglang_scaled_fp8_quant (
2315 input : torch .Tensor ,
2416 scale : Optional [torch .Tensor ] = None ,
@@ -34,6 +26,16 @@ def sglang_scaled_fp8_quant(
3426 return output , scale
3527
3628
29+ def torch_scaled_fp8_quant (tensor , inv_scale ):
30+ # The reference implementation that fully aligns to
31+ # the kernel being tested.
32+ finfo = torch .finfo (torch .float8_e4m3fn )
33+ scale = inv_scale .reciprocal ()
34+ qweight = (tensor .to (torch .float32 ) * scale ).clamp (min = finfo .min , max = finfo .max )
35+ qweight = qweight .to (torch .float8_e4m3fn )
36+ return qweight
37+
38+
3739@pytest .mark .parametrize (
3840 "num_tokens,hidden_dim" ,
3941 list (itertools .product ([128 , 256 , 512 ], [512 , 2048 , 4096 ])),
@@ -45,21 +47,19 @@ def test_per_tensor_quant_compare_implementations(
4547 device = torch .device ("cuda" )
4648 x = torch .rand ((num_tokens , hidden_dim ), dtype = torch .float16 , device = device )
4749
48- vllm_out , vllm_scale = vllm_scaled_fp8_quant (x )
4950 sglang_out , sglang_scale = sglang_scaled_fp8_quant (x )
51+ torch_out = torch_scaled_fp8_quant (x , sglang_scale )
5052
51- torch .testing .assert_close (vllm_scale , sglang_scale , rtol = 1e-3 , atol = 1e-3 )
5253 torch .testing .assert_close (
53- vllm_out .float (), sglang_out .float (), rtol = 1e-3 , atol = 1e-3
54+ sglang_out .float (), torch_out .float (), rtol = 1e-3 , atol = 1e-3
5455 )
5556
5657 scale = torch .rand (1 , dtype = torch .float32 , device = device )
57- vllm_out , vllm_scale = vllm_scaled_fp8_quant (x , scale )
5858 sglang_out , sglang_scale = sglang_scaled_fp8_quant (x , scale )
59+ torch_out = torch_scaled_fp8_quant (x , scale )
5960
60- torch .testing .assert_close (vllm_scale , sglang_scale , rtol = 1e-3 , atol = 1e-3 )
6161 torch .testing .assert_close (
62- vllm_out .float (), sglang_out .float (), rtol = 1e-3 , atol = 1e-3
62+ sglang_out .float (), torch_out .float (), rtol = 1e-3 , atol = 1e-3
6363 )
6464
6565