165 changes: 165 additions & 0 deletions test/test_ops.py
@@ -1,3 +1,4 @@
import itertools
import torch
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck
@@ -6,6 +7,14 @@
import unittest
from parameterized import parameterized
import pytest
from torchao.quantization.utils import (
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
unpack_tinygemm_scales_and_zeros
)
import torchao.quantization

try:
import torchao.ops
@@ -55,6 +64,162 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
        relative_error = error / results_fp16.abs()
        assert relative_error.mean() < 1e-2

## Tests for `unpack_tensor_core_tiled_layout`
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
(4096, 4096),
# Llama 2 GEMM shapes
(4096, 11008),
(11008, 4096),
# Llama 3 GEMM shapes
(4096, 14336),
(14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, innerKTiles):
    N, K = shape
    assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0

    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, innerKTiles)
    assert torch.allclose(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles):
    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
        "test_aot_dispatch_dynamic",
    ]
    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)

    opcheck(
        torch.ops.torchao.unpack_tensor_core_tiled_layout,
        (packed_w, innerKTiles),
        test_utils=test_utils,
    )

Comment (Contributor), on the `opcheck` call: which pytorch version are you using? it seems this opcheck has moved to torch.library.opcheck: https://github.com/pytorch/pytorch/pull/124496/files
Reply (Contributor Author): torch 2.5.0.dev20240624+cu121
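
A version-tolerant import following the exchange above might look like the sketch below (an assumption about where `opcheck` is exposed across builds, not something this PR ships):

try:
    from torch.library import opcheck  # public location on newer PyTorch builds
except ImportError:
    from torch.testing._internal.optests import opcheck  # fallback path this file already uses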

def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    n, k = q.shape
    assert q.dtype == torch.int

    n_groups = k // group_size
    assert scales.shape[0] == n and scales.shape[1] == n_groups
    assert scales.shape == zeros.shape

    midpoint = 2 ** (nbits - 1)

    # Convert from u4 -> s4 and upcast to bfloat16
    q = q.sub(midpoint).to(dtype)

    # Dequantize
    q = q.reshape(-1, group_size)
    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

    return dq.reshape(n, k)
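
A minimal usage sketch for `dequant_ref` (hypothetical shapes, chosen only for illustration; not invoked by the tests):

def _dequant_ref_usage_sketch():
    q = torch.randint(0, 16, (2, 8), dtype=torch.int)   # u4 codes stored as int: n=2, k=8
    scales = torch.randn(2, 2, dtype=torch.bfloat16)    # k // group_size = 2 groups per row
    zeros = torch.randn(2, 2, dtype=torch.bfloat16)     # tinygemm-style floating-point offsets
    dq = dequant_ref(q, scales, zeros, group_size=4)
    assert dq.shape == (2, 8) and dq.dtype == torch.bfloat16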


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, group_size):
    n, k = shape
    dtype = torch.bfloat16

    # tinygemm params
    nTileSize = 8
    kTileSize = 16
    nTiles = n // nTileSize
    kTiles = k // (innerKTiles * kTileSize)
    numThreads = 32

    device = "cuda"

    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

    # Quantize
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Pack to tensor core layout
    packed = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
    q_groups = k // group_size
    assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])

    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        q, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
    # (see the sketch after this test for why this works)
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()

    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, innerKTiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix.
    # Since the `dequantize_int4` kernel relies on the same underlying numerical conversions,
    # it shows the same numerical differences when compared to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives the same results as identity-matrix-based dequant
    assert diff_op_id == 0

    # Test that the `dequant` kernel gives the same numerical diffs as `groupwise_affine_dequantize`
    # when compared against the identity-matrix-based dequant
    assert diff_op_ao == diff_ao_id
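
Why the identity-activation trick above recovers the dequantized weight: row i of eye(k) selects column i of the weight, so the matmul returns the dequantized matrix unchanged (up to the kernel's internal conversions), and transposing back yields the weight itself. A minimal fp32 sketch of the idea, not run by the tests:

def _identity_dequant_trick_sketch():
    w = torch.randn(4, 8)           # stand-in for a dequantized weight of shape (n, k)
    out = torch.eye(8) @ w.t()      # identity activation: returns w.t() unchanged
    assert torch.allclose(out.t(), w)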

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size):
    n, k = shape
    device = "cuda"

    q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
    packed_w = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles)
    q_groups = k // group_size
    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
    zeros = torch.randn_like(scales)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
        "test_aot_dispatch_dynamic",
    ]
    opcheck(
        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
        (packed_w, scales_and_zeros, group_size, innerKTiles),
        test_utils=test_utils,
    )

if __name__ == "__main__":
unittest.main()