@@ -23,8 +23,8 @@ def get_gguf_sample_tensors(
     return GGUFReader(sample_file).tensors


-DTYPES = [torch.half]
-# Hidden_size for testing, must match the sample file in HF repo,
+DTYPES = [torch.half, torch.bfloat16, torch.float32]
+# Hidden_size for testing, must match the sample file in HF repo,
 # we have `hidden_size = 256, 1024` for test in HF repo currently.
 HIDDEN_SIZES = [256, 1024]
 NUM_TOKENS = [7, 83, 128, 2048]  # Arbitrary values for testing
@@ -53,7 +53,7 @@ def get_gguf_sample_tensors(


 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
-@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("dtype", [torch.half])
 @pytest.mark.parametrize("quant_type", QUANT_TYPES)
 @torch.inference_mode()
 def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -123,7 +123,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
     ref_output = x @ weight.T

     qweight = torch.tensor(tensor.data, device="cuda")
-    output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
-                                 qweight.shape[0]).to(dtype)
-
-    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
+    output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
+    atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
+    # The test matrix has inputs centered around 0, so the outputs are also
+    # very close to 0; the lower precision of bfloat16 accumulates error
+    # there and can greatly inflate the relative tolerance needed.
+    rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
+    torch.testing.assert_close(output,
+                               ref_output,
+                               atol=atols[dtype],
+                               rtol=rtols[dtype])
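Why the bfloat16 rtol is so loose: with inputs centered around 0, the reference outputs are themselves near 0, so even a modest absolute error produces an enormous relative error. A minimal standalone sketch (illustrative only, not part of this change) that reproduces the effect:

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 256)  # inputs centered around 0
w = torch.randn(512, 256)

ref = x @ w.T  # float32 reference
out = (x.bfloat16() @ w.bfloat16().T).float()

abs_err = (out - ref).abs()
rel_err = abs_err / ref.abs().clamp_min(1e-12)

# Absolute error stays modest, but relative error explodes wherever
# |ref| happens to land near 0.
print(f"max abs err: {abs_err.max().item():.3f}")
print(f"max rel err: {rel_err.max().item():.1f}")
```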