Skip to content

Commit 5fca214

Browse files
QAT (#2590)
QAT and quantization w/torchao
1 parent 20fda75 commit 5fca214

26 files changed

Lines changed: 1372 additions & 13 deletions

_quarto.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ quartodoc:
4343
- cli.vllm_serve
4444
- cli.cloud.base
4545
- cli.cloud.modal_
46+
- cli.quantize
4647
- title: Trainers
4748
desc: Training implementations
4849
contents:
@@ -147,6 +148,7 @@ quartodoc:
147148
- utils.optimizers.adopt
148149
- utils.data.pretraining
149150
- utils.data.sft
151+
- utils.quantization
150152
- title: Schemas
151153
desc: Pydantic data models for Axolotl config
152154
contents:
@@ -196,7 +198,7 @@ quartodoc:
196198
- utils.callbacks.lisa
197199
- utils.callbacks.mlflow_
198200
- utils.callbacks.comet_
199-
201+
- utils.callbacks.qat
200202
website:
201203
title: "Axolotl"
202204
description: "We make fine-tuning accessible, scalable, and fun"
@@ -256,6 +258,8 @@ website:
256258
- docs/lr_groups.qmd
257259
- docs/lora_optims.qmd
258260
- docs/dataset_loading.qmd
261+
- docs/qat.qmd
262+
- docs/quantize.qmd
259263

260264
- section: "Core Concepts"
261265
contents:

docs/cli.qmd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
209209

210210
This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
211211

212+
### quantize
213+
214+
Quantizes a model using the quantization configuration specified in your YAML file.
215+
216+
```bash
217+
axolotl quantize config.yml
218+
```
219+
220+
See [Quantization](./quantize.qmd) for more details.
221+
212222

213223
## Legacy CLI Usage
214224

docs/config.qmd

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ bnb_config_kwargs:
6565
bnb_4bit_quant_type: nf4
6666
bnb_4bit_use_double_quant: true
6767

68+
# quantization aware training
69+
qat:
70+
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
71+
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
72+
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
73+
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
74+
75+
# post-training quantization
76+
quantization:
77+
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
78+
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
79+
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
80+
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
81+
6882

6983
# Whether you are training a 4-bit GPTQ quantized model
7084
gptq: true

docs/qat.qmd

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
---
2+
title: "Quantization Aware Training (QAT)"
3+
back-to-top-navigation: true
4+
toc: true
5+
toc-expand: 2
6+
toc-depth: 4
7+
---
8+
9+
## Overview
10+
11+
[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of models which are quantized
12+
by applying "fake" quantizations to the model's weights (and optionally, activations) during training. This fake
13+
quantization allows for the model to adjust for noise introduced by the quantization, so when the model is eventually
14+
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
15+
support for QAT and post-training quantization (PTQ) in axolotl.
16+
17+
We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
18+
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.
19+
20+
## Configuring QAT in Axolotl
21+
22+
To enable QAT in axolotl, add the following to your configuration file:
23+
24+
```yaml
25+
qat:
26+
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
27+
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
28+
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
29+
fake_quant_after_n_steps: # Optional[int] = None. The number of steps to apply fake quantization after
30+
```
31+
32+
Once you have finished training, you must quantize your model by using the same quantization configuration which you used to train the model with. You can use the [`quantize` command](./quantize.md) to do this.

docs/quantize.qmd

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
---
2+
title: "Quantization with torchao"
3+
back-to-top-navigation: true
4+
toc: true
5+
toc-expand: 2
6+
toc-depth: 4
7+
---
8+
9+
Quantization is a technique to lower the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library. Quantization is supported for both post-training quantization (PTQ) and quantization-aware training (QAT).
10+
11+
12+
::: {.callout-note}
13+
14+
We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
15+
16+
:::
17+
18+
## Configuring Quantization in Axolotl
19+
20+
Quantization is configured using the `quantization` key in your configuration file.
21+
22+
```yaml
23+
base_model: # The path to the model to quantize.
24+
quantization:
25+
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
26+
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
27+
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
28+
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.
29+
30+
output_dir: # The path to the output directory.
31+
```
32+
33+
Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
34+
35+
You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.md) - you can do this by using the existing QAT configuration file which
36+
you used to train the model:
37+
38+
```yaml
39+
# qat.yml
40+
qat:
41+
activation_dtype: int8
42+
weight_dtype: int8
43+
group_size: 256
44+
quantize_embedding: true
45+
46+
output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
47+
```
48+
49+
```bash
50+
axolotl quantize qat.yml
51+
```
52+
53+
This ensures that an identical quantization configuration is used to quantize the model as was used to train it.

examples/llama-3/3b-qat-fsdp2.yaml

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
base_model: meta-llama/Llama-3.2-3B
2+
# Automatically upload checkpoint and final model to HF
3+
# hub_model_id: username/custom_model_name
4+
5+
load_in_8bit: false
6+
load_in_4bit: false
7+
strict: false
8+
9+
plugins:
10+
- axolotl.integrations.liger.LigerPlugin
11+
12+
liger_rope: true
13+
liger_rms_norm: true
14+
liger_glu_activation: true
15+
liger_layer_norm: true
16+
liger_fused_linear_cross_entropy: true
17+
18+
datasets:
19+
- path: yahma/alpaca-cleaned
20+
type: alpaca
21+
22+
output_dir: ./outputs/qat_out/
23+
24+
sample_packing: true
25+
pad_to_sequence_len: true
26+
sequence_len: 512
27+
28+
flex_attention: true
29+
flex_attn_compile_kwargs:
30+
dynamic: false
31+
mode: max-autotune-no-cudagraphs
32+
33+
qat:
34+
activation_dtype: int8
35+
weight_dtype: int4
36+
group_size: 32
37+
38+
wandb_project:
39+
wandb_entity:
40+
wandb_watch:
41+
wandb_name:
42+
wandb_log_model:
43+
44+
gradient_accumulation_steps: 1
45+
micro_batch_size: 16
46+
num_epochs: 1
47+
optimizer: adamw_torch_fused
48+
49+
cosine_constant_lr_ratio: 0
50+
cosine_min_lr_ratio: 1.0
51+
learning_rate: 2e-5
52+
save_only_model: true
53+
bf16: true
54+
55+
resume_from_checkpoint:
56+
logging_steps: 1
57+
58+
evals_per_epoch: 1
59+
saves_per_epoch: 1
60+
61+
warmup_steps: 10
62+
weight_decay: 0.0
63+
fsdp:
64+
- full_shard
65+
- auto_wrap
66+
67+
fsdp_config:
68+
fsdp_version: 2
69+
fsdp_offload_params: false
70+
fsdp_cpu_ram_efficient_loading: true
71+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
72+
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
73+
fsdp_state_dict_type: FULL_STATE_DICT
74+
fsdp_sharding_strategy: FULL_SHARD
75+
fsdp_reshard_after_forward: true
76+
fsdp_activation_checkpointing: true
77+
78+
special_tokens:
79+
pad_token: <|end_of_text|>

examples/qwen3/8b-qat-fsdp2.yml

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
base_model: Qwen/Qwen3-8B
2+
# Automatically upload checkpoint and final model to HF
3+
# hub_model_id: username/custom_model_name
4+
5+
load_in_8bit: false
6+
load_in_4bit: false
7+
strict: false
8+
9+
plugins:
10+
- axolotl.integrations.liger.LigerPlugin
11+
12+
liger_rope: true
13+
liger_rms_norm: true
14+
liger_glu_activation: true
15+
liger_layer_norm: true
16+
liger_fused_linear_cross_entropy: true
17+
18+
datasets:
19+
- path: tatsu-lab/alpaca
20+
type: alpaca
21+
22+
output_dir: ./outputs/qat_out/
23+
24+
sequence_len: 2048
25+
sample_packing: true
26+
flex_attention: true
27+
pad_to_sequence_len: true
28+
29+
flex_attn_compile_kwargs:
30+
dynamic: false
31+
mode: max-autotune-no-cudagraphs
32+
33+
qat:
34+
activation_dtype: int8
35+
weight_dtype: int4
36+
group_size: 256
37+
fake_quant_after_n_steps: 1000
38+
39+
wandb_project:
40+
wandb_entity:
41+
wandb_watch:
42+
wandb_name:
43+
wandb_log_model:
44+
45+
gradient_accumulation_steps: 1
46+
micro_batch_size: 2
47+
max_steps: 2000
48+
optimizer: adamw_torch_fused
49+
lr_scheduler: cosine
50+
learning_rate: 2e-5
51+
52+
bf16: true
53+
tf32: true
54+
55+
resume_from_checkpoint:
56+
logging_steps: 1
57+
58+
evals_per_epoch: 1
59+
saves_per_epoch: 1
60+
61+
warmup_steps: 10
62+
weight_decay: 0.0
63+
fsdp:
64+
- full_shard
65+
- auto_wrap
66+
67+
fsdp_config:
68+
fsdp_version: 2
69+
fsdp_offload_params: false
70+
fsdp_cpu_ram_efficient_loading: true
71+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
72+
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
73+
fsdp_state_dict_type: FULL_STATE_DICT
74+
fsdp_sharding_strategy: FULL_SHARD
75+
fsdp_reshard_after_forward: true
76+
fsdp_activation_checkpointing: true
77+
78+
special_tokens:

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ langdetect==1.0.9
6363
immutabledict==4.2.0
6464
antlr4-python3-runtime==4.13.2
6565

66-
torchao==0.9.0
66+
torchao==0.10.0
6767
schedulefree==1.4.1
6868

6969
axolotl-contribs-lgpl==0.0.6

src/axolotl/cli/args.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ class VllmServeCliArgs:
9090
)
9191

9292

93+
@dataclass
94+
class QuantizeCliArgs:
95+
"""Dataclass with CLI arguments for `axolotl quantize` command."""
96+
97+
base_model: Optional[str] = field(default=None)
98+
weight_dtype: Optional[str] = field(default=None)
99+
activation_dtype: Optional[str] = field(default=None)
100+
quantize_embedding: Optional[bool] = field(default=None)
101+
group_size: Optional[int] = field(default=None)
102+
output_dir: Optional[str] = field(default=None)
103+
104+
93105
@dataclass
94106
class EvaluateCliArgs:
95107
"""Dataclass with CLI arguments for `axolotl evaluate` command."""

src/axolotl/cli/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from axolotl.cli.args import (
1818
EvaluateCliArgs,
1919
PreprocessCliArgs,
20+
QuantizeCliArgs,
2021
TrainerCliArgs,
2122
VllmServeCliArgs,
2223
)
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
333334
do_vllm_serve(config, cli_args)
334335

335336

337+
@cli.command()
338+
@click.argument("config", type=click.Path(exists=True, path_type=str))
339+
@add_options_from_dataclass(QuantizeCliArgs)
340+
@filter_none_kwargs
341+
def quantize(config: str, **cli_args: QuantizeCliArgs):
342+
from axolotl.cli.quantize import do_quantize
343+
344+
do_quantize(config, cli_args)
345+
346+
336347
@cli.command()
337348
@click.argument("model", type=click.Path(exists=True, path_type=str))
338349
@click.argument("output", type=click.Path(exists=False, path_type=str))

0 commit comments

Comments
 (0)