Skip to content

Commit 3a21d83

Browse files
committed
add tests
Signed-off-by: SzymonOzog <[email protected]>
1 parent 8d57355 commit 3a21d83

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

tests/kernels/test_gguf.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def get_gguf_sample_tensors(
2323
return GGUFReader(sample_file).tensors
2424

2525

26-
DTYPES = [torch.half]
27-
# Hidden_size for testing, must match the sample file in HF repo,
26+
DTYPES = [torch.half, torch.bfloat16, torch.float32
27+
] # Hidden_size for testing, must match the sample file in HF repo,
2828
# the HF repo currently provides sample files for `hidden_size = 256` and `1024`.
2929
HIDDEN_SIZES = [256, 1024]
3030
NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing
@@ -53,7 +53,7 @@ def get_gguf_sample_tensors(
5353

5454

5555
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
56-
@pytest.mark.parametrize("dtype", DTYPES)
56+
@pytest.mark.parametrize("dtype", [torch.half])
5757
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
5858
@torch.inference_mode()
5959
def test_dequantize(hidden_size: int, dtype: torch.dtype,
@@ -123,7 +123,13 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
123123
ref_output = x @ weight.T
124124

125125
qweight = torch.tensor(tensor.data, device="cuda")
126-
output = ops.ggml_mul_mat_a8(qweight, x, quant_type,
127-
qweight.shape[0]).to(dtype)
128-
129-
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
126+
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
127+
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
128+
# the test matrix has inputs centered around 0, and the rounding error from
129+
# lower-precision bfloat16 accumulates; because the outputs are also very
130+
# close to 0, this can greatly inflate the relative error (hence the rtol)
131+
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
132+
torch.testing.assert_close(output,
133+
ref_output,
134+
atol=atols[dtype],
135+
rtol=rtols[dtype])

0 commit comments

Comments
 (0)