2222 PerRow ,
2323 PerToken ,
2424)
25- from torchao .quantization . prototype .qat .api import (
25+ from torchao .prototype . quantization .qat .api import (
2626 ComposableQATQuantizer ,
2727 FakeQuantizeConfig ,
2828)
29- from torchao .quantization . prototype .qat .fake_quantizer import (
29+ from torchao .prototype . quantization .qat .fake_quantizer import (
3030 FakeQuantizer ,
3131)
32- from torchao .quantization . prototype .qat .linear import (
32+ from torchao .prototype . quantization .qat .linear import (
3333 FakeQuantizedLinear ,
3434)
35- from torchao .quantization . prototype .qat .utils import (
35+ from torchao .prototype . quantization .qat .utils import (
3636 _choose_qparams_per_token_asymmetric ,
3737 _fake_quantize_per_channel_group ,
3838 _fake_quantize_per_token ,
@@ -172,7 +172,7 @@ def _set_ptq_weight(
172172 Int8DynActInt4WeightLinear ,
173173 WeightOnlyInt4Linear ,
174174 )
175- from torchao .quantization . prototype .qat .linear import (
175+ from torchao .prototype . quantization .qat .linear import (
176176 Int8DynActInt4WeightQATLinear ,
177177 Int4WeightOnlyQATLinear ,
178178 )
@@ -204,7 +204,7 @@ def _set_ptq_weight(
204204
205205 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
206206 def test_qat_8da4w_linear (self ):
207- from torchao .quantization . prototype .qat .linear import Int8DynActInt4WeightQATLinear
207+ from torchao .prototype . quantization .qat .linear import Int8DynActInt4WeightQATLinear
208208 from torchao .quantization .GPTQ import Int8DynActInt4WeightLinear
209209
210210 group_size = 128
@@ -229,7 +229,7 @@ def test_qat_8da4w_linear(self):
229229
230230 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
231231 def test_qat_8da4w_quantizer (self ):
232- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
232+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
233233 from torchao .quantization .GPTQ import Int8DynActInt4WeightQuantizer
234234
235235 group_size = 16
@@ -263,7 +263,7 @@ def test_qat_8da4w_quantizer(self):
263263
264264 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
265265 def test_qat_8da4w_quantizer_meta_weights (self ):
266- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
266+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
267267
268268 with torch .device ("meta" ):
269269 m = M ()
@@ -278,7 +278,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
278278 """
279279 Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
280280 """
281- from torchao .quantization . prototype .qat import (
281+ from torchao .prototype . quantization .qat import (
282282 Int8DynActInt4WeightQATQuantizer ,
283283 disable_8da4w_fake_quant ,
284284 enable_8da4w_fake_quant ,
@@ -337,7 +337,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
337337 """
338338 Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
339339 """
340- from torchao .quantization . prototype .qat import (
340+ from torchao .prototype . quantization .qat import (
341341 Int8DynActInt4WeightQATQuantizer ,
342342 disable_8da4w_fake_quant ,
343343 )
@@ -419,7 +419,7 @@ def _test_qat_quantized_gradients(self, quantizer):
419419
420420 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
421421 def test_qat_8da4w_quantizer_gradients (self ):
422- from torchao .quantization . prototype .qat import Int8DynActInt4WeightQATQuantizer
422+ from torchao .prototype . quantization .qat import Int8DynActInt4WeightQATQuantizer
423423 quantizer = Int8DynActInt4WeightQATQuantizer (groupsize = 16 )
424424 self ._test_qat_quantized_gradients (quantizer )
425425
@@ -509,7 +509,7 @@ def test_qat_4w_primitives(self):
509509 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
510510 @unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
511511 def test_qat_4w_linear (self ):
512- from torchao .quantization . prototype .qat .linear import Int4WeightOnlyQATLinear
512+ from torchao .prototype . quantization .qat .linear import Int4WeightOnlyQATLinear
513513 from torchao .quantization .GPTQ import WeightOnlyInt4Linear
514514
515515 group_size = 128
@@ -536,14 +536,14 @@ def test_qat_4w_linear(self):
536536
537537 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
538538 def test_qat_4w_quantizer_gradients (self ):
539- from torchao .quantization . prototype .qat import Int4WeightOnlyQATQuantizer
539+ from torchao .prototype . quantization .qat import Int4WeightOnlyQATQuantizer
540540 quantizer = Int4WeightOnlyQATQuantizer (groupsize = 32 , inner_k_tiles = 8 )
541541 self ._test_qat_quantized_gradients (quantizer )
542542
543543 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
544544 @unittest .skipIf (not _CUDA_IS_AVAILABLE , "skipping when cuda is not available" )
545545 def test_qat_4w_quantizer (self ):
546- from torchao .quantization . prototype .qat import Int4WeightOnlyQATQuantizer
546+ from torchao .prototype . quantization .qat import Int4WeightOnlyQATQuantizer
547547 from torchao .quantization .GPTQ import Int4WeightOnlyQuantizer
548548
549549 group_size = 32
@@ -621,7 +621,7 @@ def test_composable_qat_quantizer(self):
621621
622622 @unittest .skipIf (not TORCH_VERSION_AT_LEAST_2_4 , "skipping when torch version is 2.4 or lower" )
623623 def test_qat_4w_embedding (self ):
624- from torchao .quantization . prototype .qat import Int4WeightOnlyEmbeddingQATQuantizer
624+ from torchao .prototype . quantization .qat import Int4WeightOnlyEmbeddingQATQuantizer
625625 model = M2 ()
626626 x = model .example_inputs ()
627627 out = model (* x )
0 commit comments