@@ -64,32 +64,47 @@ Between these two steps, training can proceed exactly as before.
 Applying QAT to Llama3 models
 -----------------------------
 
-We can easily apply the above QAT transformations to Llama3 in torchtune for fine-tuning:
+We can easily apply the above QAT transformations to Llama3 for fine-tuning,
+leveraging the APIs in torchao as follows:
 
 .. code-block:: python
 
-    from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
+    import copy
+    import torch
+    from torchao.quantization import quantize_
+    from torchao.quantization.qat import (
+        FakeQuantizeConfig,
+        IntXQuantizationAwareTrainingConfig,
+    )
     from torchtune.models.llama3 import llama3_8b
 
     model = llama3_8b()
+    original_model = copy.deepcopy(model)
+
+    # Config for int8 dynamic asymmetric per token activations +
+    # int4 symmetric per group weights, only for linear layers
+    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+    qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
 
-    # Quantizer for int8 dynamic per token activations +
-    # int4 grouped per channel weights, only for linear layers
-    quantizer = Int8DynActInt4WeightQATQuantizer()
+    # Prepare the model for quantization-aware fine-tuning.
+    #
+    # This step inserts "fake quantize" ops that simulate
+    # quantization numerics during fine-tuning without
+    # actually casting the activations/weights to lower-bit
+    # dtypes like in "real" quantization.
+    quantize_(model, qat_config)
 
-    # Insert "fake quantize" operations into linear layers.
-    # These operations simulate quantization numerics during
-    # fine-tuning without performing any dtype casting
-    prepared_model = quantizer.prepare(model)
+    prepared_model = model
 
-If we print the model we’ll see that all linear layers have been swapped with
-:code:`Int8DynActInt4WeightQATLinear`, which simulates the numerics of int8
-dynamic per token activations + int4 grouped per channel weights. Now the model
-is ready for fine-tuning.
+The model is now ready for QAT fine-tuning! If we print the model we’ll see that
+all linear layers have been swapped with :code:`FakeQuantizedLinear`, which simulates
+the numerics of int8 dynamic asymmetric per token activations + int4 symmetric
+per group weights:
 
 .. code-block:: bash
 
-    >>> print(model.layers[0].attn)
+    >>> original_model.layers[0].attn
     MultiHeadAttention(
       (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
       (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
@@ -98,37 +113,71 @@ is ready for fine-tuning.
       (pos_embeddings): RotaryPositionalEmbeddings()
     )
 
-    >>> print(prepared_model.layers[0].attn)
+.. code-block:: bash
+
+    >>> prepared_model.layers[0].attn
     MultiHeadAttention(
-      (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
-      (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-      (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-      (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
+      (q_proj): FakeQuantizedLinear(
+        in_features=4096, out_features=4096, bias=False
+        (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+        (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      )
+      (k_proj): FakeQuantizedLinear(
+        in_features=4096, out_features=1024, bias=False
+        (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+        (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      )
+      (v_proj): FakeQuantizedLinear(
+        in_features=4096, out_features=1024, bias=False
+        (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+        (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      )
+      (output_proj): FakeQuantizedLinear(
+        in_features=4096, out_features=4096, bias=False
+        (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+        (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      )
       (pos_embeddings): RotaryPositionalEmbeddings()
     )
 
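+As a quick sanity check, we can confirm that preparation only simulates
+quantization numerics. The following is a sketch, not part of the original
+tutorial flow, and assumes the full model fits in memory on your device:
+
+.. code-block:: python
+
+    # Fake quantize rounds values onto the int8/int4 grids but never casts
+    # tensors, so all dtypes are unchanged after preparation
+    tokens = torch.randint(0, 128_000, (1, 16))
+    with torch.no_grad():
+        logits = prepared_model(tokens)
+    assert logits.dtype == prepared_model.layers[0].attn.q_proj.weight.dtype
+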
-After fine-tuning, we can convert the model to get an actual quantized model.
-If we print the converted model, we’ll see that the QAT linears have been
-swapped with `Int8DynActInt4WeightLinear <https://github.com/pytorch/ao/blob/428084356ace4ea94c22a3a9b3d74cff8ee41db3/torchao/quantization/prototype/qat.py#L38>`_, which are the quantized versions
-of the linear layers. This quantized model can then be saved to checkpoint and
-used for inference or generation.
+After fine-tuning, we can convert the model to get an actual quantized model:
 
 .. code-block:: python
 
+    from torchao.quantization.qat import (
+        FromIntXQuantizationAwareTrainingConfig,
+    )
+    from torchao.quantization import (
+        Int8DynamicActivationInt4WeightConfig,
+    )
+
     # Fine-tune as before
     train_loop(prepared_model)
 
-    # Convert fake quantize to actual quantize operations
-    converted_model = quantizer.convert(prepared_model)
+    # Convert the fake quantized model into an actual quantized model
+    #
+    # First, we swap `FakeQuantizedLinear` back to `torch.nn.Linear`
+    # while keeping the QAT fine-tuned weights. Then, we perform standard
+    # post-training quantization (PTQ), which inserts quantized activation
+    # and weight tensor subclasses
+    quantize_(prepared_model, FromIntXQuantizationAwareTrainingConfig())
+    quantize_(prepared_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+
+    converted_model = prepared_model
+
+The model is now fully quantized to int8 and int4 and ready for inference
+or generation. If we print the model now, we will see that the linear layers
+have been swapped back to :code:`torch.nn.Linear`, but now with quantized
+activation and weight tensors:
 
 .. code-block:: bash
 
-    >>> print(converted_model.layers[0].attn)
+    >>> converted_model.layers[0].attn
     MultiHeadAttention(
-      (q_proj): Int8DynActInt4WeightLinear()
-      (k_proj): Int8DynActInt4WeightLinear()
-      (v_proj): Int8DynActInt4WeightLinear()
-      (output_proj): Int8DynActInt4WeightLinear()
+      (q_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+      (k_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+      (v_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+      (output_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
       (pos_embeddings): RotaryPositionalEmbeddings()
     )
 
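+To see how much accuracy QAT recovers, one possible baseline, shown here as
+a sketch rather than part of the tutorial flow, is to apply the same
+post-training quantization to the unmodified copy we saved at the start and
+compare the two models in your evaluation harness:
+
+.. code-block:: python
+
+    # PTQ baseline: quantize the original weights without QAT fine-tuning.
+    # Evaluating `converted_model` (QAT) against `original_model` (PTQ only)
+    # isolates the benefit of fine-tuning with fake quantization.
+    quantize_(original_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+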
@@ -150,23 +199,21 @@ modifications accordingly:
 
 .. code-block:: yaml
 
-    # Dataset
     dataset:
       _component_: torchtune.datasets.text_completion_dataset
       source: allenai/c4
-      max_seq_len: 8192
       column: text
       name: en
       split: train
-    seed: null
-    shuffle: True
 
     ...
 
     epochs: 1
     max_steps_per_epoch: 2000
     fake_quant_after_n_steps: 1000
-    memory_efficient_fsdp_wrap: False
+
+By default, this uses :code:`torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer`,
+which applies the same fake quantization configuration as the example above.
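+
+If you drive QAT from your own training loop instead of the recipe, the
+effect of :code:`fake_quant_after_n_steps` can be approximated by delaying
+the prepare step. This is a sketch, where :code:`dataloader`,
+:code:`train_step`, and :code:`N` are placeholders for your own setup:
+
+.. code-block:: python
+
+    # Train in the original precision for the first N steps, then insert
+    # the fake quantize ops and continue fine-tuning
+    for step, batch in enumerate(dataloader):
+        if step == N:
+            quantize_(model, qat_config)
+            # Caveat: module swapping replaces parameters, so you may need
+            # to rebuild the optimizer's parameter groups at this point
+        train_step(model, batch)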
 
 Empirically, we observed that disabling fake quantization for the first N steps
 led to better results, presumably because doing so allows the weights to stabilize
@@ -213,15 +260,13 @@ copy and make the following modifications to the quantization config:
 
 .. code-block:: yaml
 
-    # Model arguments
     model:
       _component_: torchtune.models.llama3.llama3_8b
 
     checkpointer:
       _component_: torchtune.training.FullModelMetaCheckpointer
       checkpoint_dir: <your QAT checkpoint dir>
-      checkpoint_files: [meta_model_0.pt]
-      recipe_checkpoint: null
+      checkpoint_files: [ft-model-00001-of-00001.bin]
       output_dir: <your QAT checkpoint dir>
       model_type: LLAMA3
 
@@ -259,25 +304,19 @@ integrated in torchtune. First, copy the evaluation config and make the followin
 
 .. code-block:: yaml
 
-    # Model arguments
     model:
       _component_: torchtune.models.llama3.llama3_8b
 
     checkpointer:
       _component_: torchtune.training.FullModelTorchTuneCheckpointer
       checkpoint_dir: <your quantized model checkpoint dir>
-      checkpoint_files: [meta_model_0-8da4w.pt]
-      recipe_checkpoint: null
+      checkpoint_files: [ft-model-00001-of-00001-8da4w.bin]
       output_dir: <your quantized model checkpoint dir>
       model_type: LLAMA3
 
     ...
 
-    # EleutherAI specific eval args
     tasks: ["hellaswag", "wikitext"]
-    limit: null
-    max_seq_length: 8192
-    batch_size: 8
 
     quantizer:
       _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer