diff --git a/README.md b/README.md
index 71fb25fa24..336bacd669 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ In practice these features alongside int4 weight only quantization allow us to
 * Post-training quantization can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization Aware Training (QAT) to overcome this limitation. In collaboration with Torchtune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering **96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext** for Llama3 compared to post-training quantization (PTQ). And we've provided a full recipe [here](https://pytorch.org/blog/quantization-aware-training/)
 
 ```python
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 qat_quantizer = Int8DynActInt4WeightQATQuantizer()
diff --git a/test/quantization/test_mixed_precision.py b/test/prototype/test_mixed_precision.py
similarity index 95%
rename from test/quantization/test_mixed_precision.py
rename to test/prototype/test_mixed_precision.py
index 8afd022d3c..bfcd7bed2b 100644
--- a/test/quantization/test_mixed_precision.py
+++ b/test/prototype/test_mixed_precision.py
@@ -4,7 +4,7 @@
 import torch.nn as nn
 
 from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
 from torchao.quantization.utils import compute_error
-from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
+from torchao.prototype.quantization.mixed_precision.scripts.naive_intNwo import intN_weight_only
 
 _CUDA_IS_AVAILABLE = torch.cuda.is_available()
diff --git a/test/quantization/test_qat.py b/test/prototype/test_qat.py
similarity index 97%
rename from test/quantization/test_qat.py
rename to test/prototype/test_qat.py
index b00d0ad40b..1bf5c3b7b8 100644
--- a/test/quantization/test_qat.py
+++ b/test/prototype/test_qat.py
@@ -22,17 +22,17 @@
     PerRow,
     PerToken,
 )
-from torchao.quantization.prototype.qat.api import (
+from torchao.prototype.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
 )
-from torchao.quantization.prototype.qat.fake_quantizer import (
+from torchao.prototype.quantization.qat.fake_quantizer import (
     FakeQuantizer,
 )
-from torchao.quantization.prototype.qat.linear import (
+from torchao.prototype.quantization.qat.linear import (
     FakeQuantizedLinear,
 )
-from torchao.quantization.prototype.qat.utils import (
+from torchao.prototype.quantization.qat.utils import (
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
@@ -172,7 +172,7 @@ def _set_ptq_weight(
            Int8DynActInt4WeightLinear,
            WeightOnlyInt4Linear,
        )
-        from torchao.quantization.prototype.qat.linear import (
+        from torchao.prototype.quantization.qat.linear import (
            Int8DynActInt4WeightQATLinear,
            Int4WeightOnlyQATLinear,
        )
@@ -204,7 +204,7 @@ def _set_ptq_weight(
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int8DynActInt4WeightQATLinear
+        from torchao.prototype.quantization.qat.linear import Int8DynActInt4WeightQATLinear
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         group_size = 128
@@ -229,7 +229,7 @@ def test_qat_8da4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
 
         group_size = 16
@@ -263,7 +263,7 @@ def test_qat_8da4w_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
         with torch.device("meta"):
             m = M()
@@ -278,7 +278,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.prototype.quantization.qat import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
             enable_8da4w_fake_quant,
@@ -337,7 +337,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         """
         Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
         """
-        from torchao.quantization.prototype.qat import (
+        from torchao.prototype.quantization.qat import (
             Int8DynActInt4WeightQATQuantizer,
             disable_8da4w_fake_quant,
         )
@@ -419,7 +419,7 @@ def _test_qat_quantized_gradients(self, quantizer):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+        from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)
 
@@ -509,7 +509,7 @@ def test_qat_4w_primitives(self):
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_linear(self):
-        from torchao.quantization.prototype.qat.linear import Int4WeightOnlyQATLinear
+        from torchao.prototype.quantization.qat.linear import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
 
         group_size = 128
@@ -536,14 +536,14 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_quantizer_gradients(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.prototype.quantization.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
 
         group_size = 32
@@ -621,7 +621,7 @@ def test_composable_qat_quantizer(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_4w_embedding(self):
-        from torchao.quantization.prototype.qat import Int4WeightOnlyEmbeddingQATQuantizer
+        from torchao.prototype.quantization.qat import Int4WeightOnlyEmbeddingQATQuantizer
         model = M2()
         x = model.example_inputs()
         out = model(*x)
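(Editorial note, not part of the patch: the hunks above all track the same module move, so a minimal sketch of the QAT flow under the new path may help reviewers sanity-check the imports. The toy model, shapes, and training step below are illustrative only; the `prepare`/`convert` pairing is assumed from the two-step quantizer API used in the blog recipe linked in the README hunk.)

```python
import torch
import torch.nn as nn

# New import path introduced by this move
from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model; linear in_features must be divisible by groupsize
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))

# Step 1: swap nn.Linear for fake-quantized linears
# (int8 dynamic activations, int4 grouped weights)
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
model = qat_quantizer.prepare(model)

# ... fine-tune as usual; gradients flow through the fake-quantize ops ...
loss = model(torch.randn(8, 256)).sum()
loss.backward()

# Step 2: after training, replace fake-quantize ops with real quantized ops
model = qat_quantizer.convert(model)
```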
diff --git a/torchao/quantization/prototype/__init__.py b/torchao/prototype/quantization/__init__.py
similarity index 100%
rename from torchao/quantization/prototype/__init__.py
rename to torchao/prototype/quantization/__init__.py
diff --git a/torchao/quantization/prototype/mixed_precision/README.md b/torchao/prototype/quantization/mixed_precision/README.md
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/README.md
rename to torchao/prototype/quantization/mixed_precision/README.md
diff --git a/torchao/quantization/prototype/mixed_precision/__init__.py b/torchao/prototype/quantization/mixed_precision/__init__.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/__init__.py
rename to torchao/prototype/quantization/mixed_precision/__init__.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py
rename to torchao/prototype/quantization/mixed_precision/scripts/BO_acc_modelsize.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py
rename to torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json b/torchao/prototype/quantization/mixed_precision/scripts/Llama3-8B_initial_samples.json
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_initial_samples.json
rename to torchao/prototype/quantization/mixed_precision/scripts/Llama3-8B_initial_samples.json
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json b/torchao/prototype/quantization/mixed_precision/scripts/Llama3-8B_parameters.json
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/Llama3-8B_parameters.json
rename to torchao/prototype/quantization/mixed_precision/scripts/Llama3-8B_parameters.json
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json b/torchao/prototype/quantization/mixed_precision/scripts/Mistral-7B_initial_samples.json
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_initial_samples.json
rename to torchao/prototype/quantization/mixed_precision/scripts/Mistral-7B_initial_samples.json
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json b/torchao/prototype/quantization/mixed_precision/scripts/Mistral-7B_parameters.json
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/Mistral-7B_parameters.json
rename to torchao/prototype/quantization/mixed_precision/scripts/Mistral-7B_parameters.json
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/__init__.py b/torchao/prototype/quantization/mixed_precision/scripts/__init__.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/__init__.py
rename to torchao/prototype/quantization/mixed_precision/scripts/__init__.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/fit.py b/torchao/prototype/quantization/mixed_precision/scripts/fit.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/fit.py
rename to torchao/prototype/quantization/mixed_precision/scripts/fit.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py b/torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py
rename to torchao/prototype/quantization/mixed_precision/scripts/hessian_grad.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py b/torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py
rename to torchao/prototype/quantization/mixed_precision/scripts/hessian_vhp.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py b/torchao/prototype/quantization/mixed_precision/scripts/mp_quant_eval.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py
rename to torchao/prototype/quantization/mixed_precision/scripts/mp_quant_eval.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py
rename to torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py
diff --git a/torchao/quantization/prototype/mixed_precision/scripts/utils.py b/torchao/prototype/quantization/mixed_precision/scripts/utils.py
similarity index 100%
rename from torchao/quantization/prototype/mixed_precision/scripts/utils.py
rename to torchao/prototype/quantization/mixed_precision/scripts/utils.py
diff --git a/torchao/quantization/prototype/qat/README.md b/torchao/prototype/quantization/qat/README.md
similarity index 98%
rename from torchao/quantization/prototype/qat/README.md
rename to torchao/prototype/quantization/qat/README.md
index 2869322297..8ef5d14fdb 100644
--- a/torchao/quantization/prototype/qat/README.md
+++ b/torchao/prototype/quantization/qat/README.md
@@ -41,7 +41,7 @@ For example, on a single GPU:
 ```python
 import torch
 from torchtune.models.llama3 import llama3
-from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
+from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
 
 # Smaller version of llama3 to fit in a single GPU
 model = llama3(
diff --git a/torchao/quantization/prototype/qat/__init__.py b/torchao/prototype/quantization/qat/__init__.py
similarity index 100%
rename from torchao/quantization/prototype/qat/__init__.py
rename to torchao/prototype/quantization/qat/__init__.py
diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/prototype/quantization/qat/_module_swap_api.py
similarity index 100%
rename from torchao/quantization/prototype/qat/_module_swap_api.py
rename to torchao/prototype/quantization/qat/_module_swap_api.py
diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/prototype/quantization/qat/affine_fake_quantized_tensor.py
similarity index 100%
rename from torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
rename to torchao/prototype/quantization/qat/affine_fake_quantized_tensor.py
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/prototype/quantization/qat/api.py
similarity index 100%
rename from torchao/quantization/prototype/qat/api.py
rename to torchao/prototype/quantization/qat/api.py
diff --git a/torchao/quantization/prototype/qat/embedding.py b/torchao/prototype/quantization/qat/embedding.py
similarity index 100%
rename from torchao/quantization/prototype/qat/embedding.py
rename to torchao/prototype/quantization/qat/embedding.py
diff --git a/torchao/quantization/prototype/qat/fake_quantizer.py b/torchao/prototype/quantization/qat/fake_quantizer.py
similarity index 100%
rename from torchao/quantization/prototype/qat/fake_quantizer.py
rename to torchao/prototype/quantization/qat/fake_quantizer.py
diff --git a/torchao/quantization/prototype/qat/images/qat_diagram.png b/torchao/prototype/quantization/qat/images/qat_diagram.png
similarity index 100%
rename from torchao/quantization/prototype/qat/images/qat_diagram.png
rename to torchao/prototype/quantization/qat/images/qat_diagram.png
diff --git a/torchao/quantization/prototype/qat/linear.py b/torchao/prototype/quantization/qat/linear.py
similarity index 100%
rename from torchao/quantization/prototype/qat/linear.py
rename to torchao/prototype/quantization/qat/linear.py
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/prototype/quantization/qat/utils.py
similarity index 98%
rename from torchao/quantization/prototype/qat/utils.py
rename to torchao/prototype/quantization/qat/utils.py
index 8f2dd9d13f..a25524a210 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/prototype/quantization/qat/utils.py
@@ -46,7 +46,7 @@ def forward(
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )
 
@@ -88,7 +88,7 @@ def forward(
         input: torch.Tensor,
     ) -> torch.Tensor:
         # avoid circular dependencies
-        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+        from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
             AffineFakeQuantizedTensor,
         )
         assert isinstance(input, AffineFakeQuantizedTensor)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 91803fe3f7..111ef308bb 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -220,7 +220,7 @@ def _replace_with_custom_fn_if_matches_filter(
 
 def _is_linear(mod, *args):
     # avoid circular dependencies
-    from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+    from torchao.prototype.quantization.qat.affine_fake_quantized_tensor import (
         AffineFakeQuantizedTensor,
     )
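(Editorial note, not part of the patch: as shown, this diff contains only renames and import-path updates, with no compatibility alias for the old `torchao.quantization.prototype` location. Downstream code that must span releases on both sides of the move could use a hypothetical guard like the sketch below.)

```python
try:
    # layout after this move
    from torchao.prototype.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # layout before this move
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
```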