[FEAT] Add custom CUDA tinygemm unpacker
#415
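This PR adds two custom CUDA ops for tinygemm's int4 tensor-core tiled weight layout: `torchao.ops.unpack_tensor_core_tiled_layout` and `torchao.ops.dequantize_tensor_core_tiled_layout`. A minimal usage sketch, inferred from the tests in this diff (the shapes and qparams below are illustrative, not prescribed by the PR):

    import torch
    from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
    import torchao.ops

    N, K, innerKTiles, group_size = 4096, 4096, 8, 128  # illustrative sizes

    # int4 values stored as int32, packed by the existing aten op
    q = torch.randint(0, 16, (N, K), dtype=torch.int, device="cuda")
    packed = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles)

    # New op 1: recover the int4 values from the packed layout
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, innerKTiles)
    assert torch.allclose(unpacked, q)

    # New op 2: dequantize to bfloat16 directly from the packed layout
    scales = torch.randn(N, K // group_size, dtype=torch.bfloat16, device="cuda")
    zeros = torch.randn_like(scales)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
    dq = torchao.ops.dequantize_tensor_core_tiled_layout(
        packed, scales_and_zeros, group_size, innerKTiles
    )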
@@ -1,3 +1,4 @@
import itertools
import torch
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch.testing._internal.optests import opcheck

@@ -6,6 +7,14 @@
import unittest
from parameterized import parameterized
import pytest
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
    unpack_tinygemm_scales_and_zeros
)
import torchao.quantization

try:
    import torchao.ops
@@ -55,6 +64,162 @@ def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK):
        relative_error = error / results_fp16.abs()
        assert relative_error.mean() < 1e-2


## Tests for `unpack_tensor_core_tiled_layout`
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
    (4096, 4096),
    # Llama 2 GEMM shapes
    (4096, 11008),
    (11008, 4096),
    # Llama 3 GEMM shapes
    (4096, 14336),
    (14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, innerKTiles):
    N, K = shape
    assert K % (innerKTiles * kTileSizeK) == 0 and N % kTileSizeN == 0
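    # (tinygemm tiles N in units of kTileSizeN=8 and K in units of
    # innerKTiles * kTileSizeK, so both dims must be tile-aligned.)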
    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, innerKTiles)
    assert torch.allclose(t, unpacked)


# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, innerKTiles):
    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
        "test_aot_dispatch_dynamic",
    ]
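    # opcheck runs torch.library correctness checks on the custom op: schema
    # matching, autograd registration, FakeTensor propagation, and AOT dispatch
    # (the dynamic-shape variant is the one flagged in the TODO above).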
    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, innerKTiles)

    opcheck(
        torch.ops.torchao.unpack_tensor_core_tiled_layout,
        (packed_w, innerKTiles),
        test_utils=test_utils,
    )

Contributor: which pytorch version are you using? it seems this …

Author: torch …

def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    n, k = q.shape
    assert q.dtype == torch.int

    n_groups = k // group_size
    assert scales.shape[0] == n and scales.shape[1] == n_groups
    assert scales.shape == zeros.shape

    midpoint = 2 ** (nbits - 1)

    # Convert from u4 -> s4 and upcast to bfloat16
    q = q.sub(midpoint).to(dtype)

    # Dequantize
    q = q.reshape(-1, group_size)
    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

    return dq.reshape(n, k)
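# A toy sanity check of dequant_ref (hypothetical values, not part of the test
# suite): one row, a single group of size 2, midpoint = 2**3 = 8 for 4 bits.
#   q = torch.tensor([[3, 12]], dtype=torch.int)
#   scales = torch.tensor([[0.5]]); zeros = torch.tensor([[1.0]])
#   dequant_ref(q, scales, zeros, group_size=2, dtype=torch.float32)
#   # (3 - 8) * 0.5 + 1.0 = -1.5; (12 - 8) * 0.5 + 1.0 = 3.0
#   # -> tensor([[-1.5, 3.0]])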

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness(shape, innerKTiles, group_size):
    n, k = shape
    dtype = torch.bfloat16

    # tinygemm params
    nTileSize = 8
    kTileSize = 16
    nTiles = n // nTileSize
    kTiles = k // (innerKTiles * kTileSize)
    numThreads = 32

    device = "cuda"

    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

    # Quantize
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Pack to tensor core layout
    packed = torch.ops.aten._convert_weight_to_int4pack(q, innerKTiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
    q_groups = k // group_size
    assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])
    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        q, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()
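    # (With A = I, _weight_int4pack_mm computes I @ dequant(W).T == dequant(W).T,
    # so the transpose above recovers the dequantized weight through tinygemm's
    # own dequant path, giving a kernel-faithful reference.)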
    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, innerKTiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix.
    # Since the `dequantize_int4` kernel relies on the same underlying numerical conversions,
    # it shows the same differences when compared to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives the same results as identity-matrix-based dequant
    assert diff_op_id == 0

    # Test that the `dequant` kernel gives the same numerical diffs as
    # `groupwise_affine_dequantize` when compared against the identity matrix
    assert diff_op_ao == diff_ao_id

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(IS_FBCODE, reason="Skipping the test in fbcode since we don't have TARGET file for kernels")
@pytest.mark.parametrize("shape, innerKTiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, innerKTiles, group_size):
    n, k = shape
    device = "cuda"

    q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
    packed_w = torch._convert_weight_to_int4pack(q, innerKTiles)
    q_groups = k // group_size
    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
    zeros = torch.randn_like(scales)
    scales_and_zeros = torchao.quantization.utils.pack_tinygemm_scales_and_zeros(scales, zeros)

    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
        "test_aot_dispatch_dynamic",
    ]
    opcheck(
        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
        (packed_w, scales_and_zeros, group_size, innerKTiles),
        test_utils=test_utils,
    )


if __name__ == "__main__":
    unittest.main()