
Commit 3beb410

Update QAT tutorial (#2396)
1 parent a65d23c commit 3beb410

2 files changed: +89, -48 lines

docs/source/tutorials/qat_finetune.rst

Lines changed: 85 additions & 46 deletions
@@ -64,32 +64,47 @@ Between these two steps, training can proceed exactly as before.

 Applying QAT to Llama3 models
 -----------------------------

-We can easily apply the above QAT transformations to Llama3 in torchtune for fine-tuning:
+We can easily apply the above QAT transformations to Llama3 for fine-tuning,
+leveraging the APIs in torchao as follows:

 .. code-block:: python

-  from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer
+  import copy
+  import torch
+  from torchao.quantization import quantize_
+  from torchao.quantization.qat import (
+      FakeQuantizeConfig,
+      IntXQuantizationAwareTrainingConfig,
+  )
   from torchtune.models.llama3 import llama3_8b

   model = llama3_8b()
+  original_model = copy.deepcopy(model)
+
+  # Config for int8 dynamic asymmetric per token activations +
+  # int4 symmetric per group weights, only for linear layers
+  activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+  weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
+  qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)

-  # Quantizer for int8 dynamic per token activations +
-  # int4 grouped per channel weights, only for linear layers
-  quantizer = Int8DynActInt4WeightQATQuantizer()
+  # Prepare the model for quantization-aware fine-tuning.
+  #
+  # This step inserts "fake quantize" ops that simulate
+  # quantization numerics during fine-tuning without
+  # actually casting the activations/weights to lower-bit
+  # dtypes like in "real" quantization.
+  quantize_(model, qat_config)

-  # Insert "fake quantize" operations into linear layers.
-  # These operations simulate quantization numerics during
-  # fine-tuning without performing any dtype casting
-  prepared_model = quantizer.prepare(model)
+  prepared_model = model

-If we print the model we’ll see that all linear layers have been swapped with
-:code:`Int8DynActInt4WeightQATLinear`, which simulates the numerics of int8
-dynamic per token activations + int4 grouped per channel weights. Now the model
-is ready for fine-tuning.
+The model is now ready for QAT fine-tuning! If we print the model we’ll see that
+all linear layers have been swapped with :code:`FakeQuantizedLinear`, which simulates
+the numerics of int8 dynamic asymmetric per token activations + int4 symmetric
+per group weights:

 .. code-block:: bash

-  >>> print(model.layers[0].attn)
+  >>> original_model.layers[0].attn
   MultiHeadAttention(
     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
@@ -98,37 +113,71 @@ is ready for fine-tuning.
     (pos_embeddings): RotaryPositionalEmbeddings()
   )

-  >>> print(prepared_model.layers[0].attn)
+.. code-block:: bash
+
+  >>> prepared_model.layers[0].attn
   MultiHeadAttention(
-    (q_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
-    (k_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-    (v_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=1024, bias=False)
-    (output_proj): Int8DynActInt4WeightQATLinear(in_features=4096, out_features=4096, bias=False)
+    (q_proj): FakeQuantizedLinear(
+      in_features=4096, out_features=4096, bias=False
+      (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+    )
+    (k_proj): FakeQuantizedLinear(
+      in_features=4096, out_features=1024, bias=False
+      (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+    )
+    (v_proj): FakeQuantizedLinear(
+      in_features=4096, out_features=1024, bias=False
+      (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+    )
+    (output_proj): FakeQuantizedLinear(
+      in_features=4096, out_features=4096, bias=False
+      (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+      (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
+    )
     (pos_embeddings): RotaryPositionalEmbeddings()
   )

-After fine-tuning, we can convert the model to get an actual quantized model.
-If we print the converted model, we’ll see that the QAT linears have been
-swapped with `Int8DynActInt4WeightLinear <https://github.com/pytorch/ao/blob/428084356ace4ea94c22a3a9b3d74cff8ee41db3/torchao/quantization/prototype/qat.py#L38>`_, which are the quantized versions
-of the linear layers. This quantized model can then be saved to checkpoint and
-used for inference or generation.
+After fine-tuning, we can convert the model to get an actual quantized model:

 .. code-block:: python

+  from torchao.quantization.qat import (
+      FromIntXQuantizationAwareTrainingConfig,
+  )
+  from torchao.quantization import (
+      Int8DynamicActivationInt4WeightConfig,
+  )
+
   # Fine-tune as before
   train_loop(prepared_model)

-  # Convert fake quantize to actual quantize operations
-  converted_model = quantizer.convert(prepared_model)
+  # Convert the fake quantized model into an actual quantized model
+  #
+  # First, we swap `FakeQuantizedLinear` back to `torch.nn.Linear`
+  # while keeping the QAT fine-tuned weights. Then, we perform standard
+  # post-training quantization (PTQ), which inserts quantized activation
+  # and weight tensor subclasses
+  quantize_(prepared_model, FromIntXQuantizationAwareTrainingConfig())
+  quantize_(prepared_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
+
+  converted_model = prepared_model
+
+The model is now fully quantized to int8 and int4 and ready for inference
+or generation. If we print the model now, we will see the linear layers
+are now swapped back to :code:`torch.nn.Linear`, but with quantized tensor
+activations and weights:

 .. code-block:: bash

-  >>> print(converted_model.layers[0].attn)
+  >>> converted_model.layers[0].attn
   MultiHeadAttention(
-    (q_proj): Int8DynActInt4WeightLinear()
-    (k_proj): Int8DynActInt4WeightLinear()
-    (v_proj): Int8DynActInt4WeightLinear()
-    (output_proj): Int8DynActInt4WeightLinear()
+    (q_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+    (k_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+    (v_proj): Linear(in_features=4096, out_features=1024, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([1024, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
+    (output_proj): Linear(in_features=4096, out_features=4096, weight=LinearActivationQuantizedTensor(activation=<function _int8_asymm_per_token_quant at 0x7f801ce08790>, weight=AffineQuantizedTensor(shape=torch.Size([4096, 4096]), block_size=(1, 32), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=-8, quant_max=7)))
     (pos_embeddings): RotaryPositionalEmbeddings()
   )
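In the snippet above, :code:`train_loop` is a stand-in for an ordinary fine-tuning loop. As a minimal, hypothetical sketch (random stand-in batches and arbitrary hyperparameters, not the actual torchtune recipe), the prepared model trains like any other PyTorch model, since the fake-quantize ops are applied transparently during the forward pass:

.. code-block:: python

  import torch
  import torch.nn.functional as F

  def train_loop(model, num_steps=10, vocab_size=128_256):
      # Hypothetical loop: a real recipe would use a tokenized dataset,
      # a learning-rate schedule, gradient clipping, checkpointing, etc.
      optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
      model.train()
      for _ in range(num_steps):
          tokens = torch.randint(0, vocab_size, (2, 128))  # stand-in batch
          logits = model(tokens)                           # [batch, seq, vocab]
          # Standard next-token prediction loss against the shifted tokens
          loss = F.cross_entropy(
              logits[:, :-1].reshape(-1, logits.size(-1)),
              tokens[:, 1:].reshape(-1),
          )
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()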
@@ -150,23 +199,21 @@ modifications accordingly:

 .. code-block:: yaml

-  # Dataset
   dataset:
     _component_: torchtune.datasets.text_completion_dataset
     source: allenai/c4
-    max_seq_len: 8192
     column: text
     name: en
     split: train
-  seed: null
-  shuffle: True

   ...

   epochs: 1
   max_steps_per_epoch: 2000
   fake_quant_after_n_steps: 1000
-  memory_efficient_fsdp_wrap: False
+
+By default, this uses the :code:`torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer`,
+which uses the same fake quantization configurations as the example above.

 Empirically, we observed that disabling fake quantization for the first N steps
 led to better results, presumably because doing so allows the weights to stabilize
@@ -213,15 +260,13 @@ copy and make the following modifications to the quantization config:

 .. code-block:: yaml

-  # Model arguments
   model:
     _component_: torchtune.models.llama3.llama3_8b

   checkpointer:
     _component_: torchtune.training.FullModelMetaCheckpointer
     checkpoint_dir: <your QAT checkpoint dir>
-    checkpoint_files: [meta_model_0.pt]
-    recipe_checkpoint: null
+    checkpoint_files: [ft-model-00001-of-00001.bin]
     output_dir: <your QAT checkpoint dir>
     model_type: LLAMA3
@@ -259,25 +304,19 @@ integrated in torchtune. First, copy the evaluation config and make the followin

 .. code-block:: yaml

-  # Model arguments
   model:
     _component_: torchtune.models.llama3.llama3_8b

   checkpointer:
     _component_: torchtune.training.FullModelTorchTuneCheckpointer
     checkpoint_dir: <your quantized model checkpoint dir>
-    checkpoint_files: [meta_model_0-8da4w.pt]
-    recipe_checkpoint: null
+    checkpoint_files: [ft-model-00001-of-00001-8da4w.bin]
     output_dir: <your quantized model checkpoint dir>
     model_type: LLAMA3

   ...

-  # EleutherAI specific eval args
   tasks: ["hellaswag", "wikitext"]
-  limit: null
-  max_seq_length: 8192
-  batch_size: 8

   quantizer:
     _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
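For the evaluation config above to work, the quantized checkpoint named in :code:`checkpoint_files` needs to exist on disk. A minimal sketch (hypothetical path, and assuming the :code:`converted_model` produced by the conversion snippet earlier) of writing it out with plain :code:`torch.save`:

.. code-block:: python

  import torch

  # Hypothetical location; the file name mirrors the checkpoint_files entry in
  # the evaluation config above. `converted_model` is the quantized model from
  # the conversion step earlier in this tutorial.
  output_path = "<your quantized model checkpoint dir>/ft-model-00001-of-00001-8da4w.bin"
  torch.save(converted_model.state_dict(), output_path)

In practice, the :code:`quantize` recipe below handles this naming and saving for you.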

recipes/quantize.py

Lines changed: 4 additions & 2 deletions
@@ -102,13 +102,15 @@ def quantize(self, cfg: DictConfig):

     def save_checkpoint(self, cfg: DictConfig):
         ckpt_dict = self._model.state_dict()
-        file_name = cfg.checkpointer.checkpoint_files[0].split(".")[0]
+        split = cfg.checkpointer.checkpoint_files[0].split(".")
+        file_name = split[0]
+        suffix = split[-1]

         output_dir = Path(cfg.checkpointer.output_dir)
         output_dir.mkdir(exist_ok=True)
         checkpoint_file = Path.joinpath(
             output_dir, f"{file_name}-{self._quantization_mode}".rstrip("-qat")
-        ).with_suffix(".pt")
+        ).with_suffix(suffix)

         torch.save(ckpt_dict, checkpoint_file)
         logger.info(
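To illustrate the renaming that :code:`save_checkpoint` performs after this change, here is a small, self-contained sketch with hypothetical values: the quantization mode is appended to the original file stem, the trailing :code:`-qat` is stripped, and the original file extension is kept. The sketch re-attaches the dot explicitly, since :code:`Path.with_suffix` expects a suffix beginning with a dot:

.. code-block:: python

  from pathlib import Path

  checkpoint_file = "ft-model-00001-of-00001.bin"  # cfg.checkpointer.checkpoint_files[0]
  quantization_mode = "8da4w-qat"                  # example quantizer mode string

  split = checkpoint_file.split(".")
  file_name, suffix = split[0], split[-1]

  out = Path("<output dir>") / f"{file_name}-{quantization_mode}".rstrip("-qat")
  out = out.with_suffix(f".{suffix}")
  print(out)  # <output dir>/ft-model-00001-of-00001-8da4w.bin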
