Merged

QAT #2590

Changes from all commits
42 commits
14538c9
Initial QAT implementation
SalmanMohammadi Apr 29, 2025
646ff68
adding config
SalmanMohammadi Apr 30, 2025
0e47cde
running E2E training
SalmanMohammadi Apr 30, 2025
574e3d6
updating config
SalmanMohammadi Apr 30, 2025
72d4f77
adding callback for fake quant delay
SalmanMohammadi Apr 30, 2025
e45ed55
Merge branch 'main' into qat
SalmanMohammadi May 1, 2025
00d46ae
updating conf
SalmanMohammadi May 1, 2025
99a7aee
correctly delaying quantization
SalmanMohammadi May 1, 2025
38e5b75
adding post-train quantize, bugfixes
SalmanMohammadi May 2, 2025
f1eed54
revamping quantization apis
SalmanMohammadi May 2, 2025
86291c5
WIP PTQ support
SalmanMohammadi May 2, 2025
98cb0ee
Merge branch 'main' into qat
SalmanMohammadi May 5, 2025
9d6c066
adding tests
SalmanMohammadi May 6, 2025
16d1ca7
pushing for diff
SalmanMohammadi May 7, 2025
4c3cfd1
WIP tearing everything apart and putting it back together
SalmanMohammadi May 8, 2025
d8361b7
more tests
SalmanMohammadi May 9, 2025
ba52f1a
more testing, rounding out quantize cli
SalmanMohammadi May 12, 2025
983316e
things looking good
SalmanMohammadi May 13, 2025
a102c45
working state
SalmanMohammadi May 14, 2025
92512ac
fixing docs
SalmanMohammadi May 14, 2025
47cbc27
more docs
SalmanMohammadi May 14, 2025
e73f8d5
tidying
SalmanMohammadi May 14, 2025
8452510
reverting change
SalmanMohammadi May 14, 2025
f16f99b
more docs
SalmanMohammadi May 14, 2025
adb3aed
more comments
SalmanMohammadi May 16, 2025
3e8f6f3
more comments
SalmanMohammadi May 17, 2025
dcc82a3
updating cli
SalmanMohammadi May 19, 2025
0433b0c
updating conf
SalmanMohammadi May 27, 2025
002a637
merging
SalmanMohammadi May 27, 2025
2ed234b
linting
SalmanMohammadi May 27, 2025
2ca43c6
fixing docs
SalmanMohammadi May 27, 2025
4cd9dd5
Merge branch 'main' into qat
SalmanMohammadi May 27, 2025
d41d459
fixing import
SalmanMohammadi May 27, 2025
9cb48ac
silly billy me
SalmanMohammadi May 27, 2025
2f13f42
fixing import
SalmanMohammadi May 27, 2025
df2fd43
dan suggest good
SalmanMohammadi May 27, 2025
bcc72d9
fixing test
SalmanMohammadi May 27, 2025
e0bf629
Update examples/llama-3/3b-qat.yaml
SalmanMohammadi May 28, 2025
7f548b1
Update docs/quantize.qmd
SalmanMohammadi May 28, 2025
2079592
comments
SalmanMohammadi May 28, 2025
86f564c
CI
SalmanMohammadi May 28, 2025
529cdb8
Merge branch 'main' into qat
SalmanMohammadi May 28, 2025
6 changes: 5 additions & 1 deletion _quarto.yml
@@ -43,6 +43,7 @@ quartodoc:
- cli.vllm_serve
- cli.cloud.base
- cli.cloud.modal_
- cli.quantize
- title: Trainers
desc: Training implementations
contents:
@@ -147,6 +148,7 @@ quartodoc:
- utils.optimizers.adopt
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
desc: Pydantic data models for Axolotl config
contents:
@@ -196,7 +198,7 @@ quartodoc:
- utils.callbacks.lisa
- utils.callbacks.mlflow_
- utils.callbacks.comet_

- utils.callbacks.qat
website:
title: "Axolotl"
description: "We make fine-tuning accessible, scalable, and fun"
@@ -256,6 +258,8 @@ website:
- docs/lr_groups.qmd
- docs/lora_optims.qmd
- docs/dataset_loading.qmd
- docs/qat.qmd
- docs/quantize.qmd

- section: "Core Concepts"
contents:
10 changes: 10 additions & 0 deletions docs/cli.qmd
@@ -209,6 +209,16 @@ axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir

This would be necessary to use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.

### quantize

Quantizes a model using the quantization configuration specified in your YAML file.

```bash
axolotl quantize config.yml
```

See [Quantization](./quantize.qmd) for more details.


## Legacy CLI Usage

14 changes: 14 additions & 0 deletions docs/config.qmd
@@ -65,6 +65,20 @@ bnb_config_kwargs:
bnb_4bit_quant_type: nf4
bnb_4bit_use_double_quant: true

# quantization aware training
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of training steps to wait before enabling fake quantization

# post-training quantization
quantization:
weight_dtype: # Optional[str] = "int8". Quantization type to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Quantization type to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.


# Whether you are training a 4-bit GPTQ quantized model
gptq: true
32 changes: 32 additions & 0 deletions docs/qat.qmd
@@ -0,0 +1,32 @@
---
title: "Quantization Aware Training (QAT)"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

## Overview

[Quantization Aware Training](https://pytorch.org/blog/introduction-to-quantization-on-pytorch/#quantization-aware-training) (QAT) is a technique for improving the accuracy of quantized models
by applying "fake" quantization to the model's weights (and, optionally, activations) during training. This fake
quantization lets the model adjust to the noise introduced by quantization, so that when the model is eventually
quantized, the accuracy loss is minimized. We use the quantization techniques implemented in [torchao](https://github.com/pytorch/ao) to provide
support for QAT and post-training quantization (PTQ) in axolotl.
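To build intuition for what "fake" quantization means, here is a minimal sketch of per-group symmetric int8 quantize-dequantize in plain PyTorch. It only illustrates the idea; it is not axolotl's or torchao's implementation, and the function name and tensor shapes are made up for the example.

```python
# Illustrative only: per-group symmetric int8 "fake" quantization (quantize -> dequantize),
# the kind of noise QAT exposes the model to during training.
import torch

def fake_quantize_int8(weight: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Round-trip `weight` through an int8 grid per group, returning a float tensor."""
    groups = weight.reshape(-1, group_size)                   # split rows into groups of `group_size`
    scale = groups.abs().amax(dim=1, keepdim=True) / 127.0    # one symmetric scale per group
    scale = scale.clamp(min=1e-8)                             # avoid division by zero for all-zero groups
    quantized = torch.clamp(torch.round(groups / scale), -128, 127)
    return (quantized * scale).reshape(weight.shape)          # dequantize back to float

weights = torch.randn(64, 128)
noisy = fake_quantize_int8(weights, group_size=32)
print((weights - noisy).abs().max())  # the quantization error the model learns to tolerate
```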

We recommend reviewing the excellent QAT tutorial in the [torchtune library](https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#quantizing-the-qat-model),
and the QAT documentation in the [torchao library](https://github.com/pytorch/ao/tree/main/torchao/quantization/qat), for more details.

## Configuring QAT in Axolotl

To enable QAT in axolotl, add the following to your configuration file:

```yaml
qat:
activation_dtype: # Optional[str] = "int8". Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"
weight_dtype: # Optional[str] = "int8". Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
fake_quant_after_n_steps: # Optional[int] = None. The number of training steps to wait before enabling fake quantization
```

Once you have finished training, you must quantize your model using the same quantization configuration you used during training. You can use the [`quantize` command](./quantize.qmd) to do this.
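For context, the prepare → train → convert lifecycle behind this workflow looks roughly like the following when driven through torchao directly. This is a hedged sketch: the quantizer class and arguments follow torchao's QAT quantizer as documented in the torchtune tutorial linked above, the toy model is made up for the example, and axolotl's internal wiring may differ.

```python
# Rough sketch of the torchao QAT lifecycle that axolotl drives from the `qat:` config.
# Class/argument names follow torchao's documented QAT quantizer; details may vary by version.
import torch
from torch import nn
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = quantizer.prepare(model)   # swap nn.Linear for fake-quantized equivalents

# ... fine-tune as usual: weights stay in high precision but "see" quantization noise ...

model = quantizer.convert(model)   # replace fake quantization with real quantized layers
```

Reusing the same `weight_dtype`, `activation_dtype`, and `group_size` at convert time as during training is exactly what sharing the config file between training and the `quantize` command guarantees.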
53 changes: 53 additions & 0 deletions docs/quantize.qmd
@@ -0,0 +1,53 @@
---
title: "Quantization with torchao"
back-to-top-navigation: true
toc: true
toc-expand: 2
toc-depth: 4
---

Quantization is a technique for lowering the memory footprint of your model, potentially at the cost of accuracy or model performance. We support quantizing your model using the [torchao](https://github.com/pytorch/ao) library, for both post-training quantization (PTQ) and quantization-aware training (QAT).


::: {.callout-note}

We do not currently support quantization techniques such as GGUF, GPTQ, or EXL2.

:::

## Configuring Quantization in Axolotl

Quantization is configured using the `quantization` key in your configuration file.

```yaml
base_model: # The path to the model to quantize.
quantization:
weight_dtype: # Optional[str] = "int8". Quantization type to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8
activation_dtype: # Optional[str] = "int8". Quantization type to use for activation quantization. Valid options are "int4" and "int8"
group_size: # Optional[int] = 32. The number of elements in each group for per-group fake quantization
quantize_embedding: # Optional[bool] = False. Whether to quantize the embedding layer.

output_dir: # The path to the output directory.
```

Once quantization is complete, your quantized model will be saved in the `{output_dir}/quantized` directory.
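For intuition, the `quantization:` block corresponds roughly to a torchao call along these lines. This is a hedged sketch using torchao's public `quantize_` API with a weight-only int8 config, not axolotl's exact code path; the toy model is made up for the example.

```python
# Rough sketch of weight-only int8 PTQ with torchao directly (not axolotl's exact code path).
import torch
from torch import nn
from torchao.quantization import quantize_, int8_weight_only

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
quantize_(model, int8_weight_only())  # replaces Linear weights with int8 tensor subclasses, in place
```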

You may also use the `quantize` command to quantize a model which has been trained with [QAT](./qat.qmd) by reusing the QAT configuration file you used to train the model:

```yaml
# qat.yml
qat:
activation_dtype: int8
weight_dtype: int8
group_size: 256
quantize_embedding: true

output_dir: # The path to the output directory used during training where the final checkpoint has been saved.
```

```bash
axolotl quantize qat.yml
```

This ensures that the model is quantized with the same configuration that was used to train it.
79 changes: 79 additions & 0 deletions examples/llama-3/3b-qat-fsdp2.yaml
@@ -0,0 +1,79 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
- path: yahma/alpaca-cleaned
type: alpaca

output_dir: ./outputs/qat_out/

sample_packing: true
pad_to_sequence_len: true
sequence_len: 512

flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs

qat:
activation_dtype: int8
weight_dtype: int4
group_size: 32

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused

cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap

fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true

special_tokens:
pad_token: <|end_of_text|>
78 changes: 78 additions & 0 deletions examples/qwen3/8b-qat-fsdp2.yml
@@ -0,0 +1,78 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

datasets:
- path: tatsu-lab/alpaca
type: alpaca

output_dir: ./outputs/qat_out/

sequence_len: 2048
sample_packing: true
flex_attention: true
pad_to_sequence_len: true

flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs

qat:
activation_dtype: int8
weight_dtype: int4
group_size: 256
fake_quant_after_n_steps: 1000

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 2
max_steps: 2000
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_steps: 10
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap

fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
fsdp_reshard_after_forward: true
fsdp_activation_checkpointing: true

special_tokens:
2 changes: 1 addition & 1 deletion requirements.txt
@@ -63,7 +63,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

torchao==0.9.0
torchao==0.10.0
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
12 changes: 12 additions & 0 deletions src/axolotl/cli/args.py
@@ -90,6 +90,18 @@ class VllmServeCliArgs:
)


@dataclass
class QuantizeCliArgs:
Collaborator review comment:
Since these are part of the pydantic model config schema, I think we don't need to duplicate here.
As an aside, we can probably get rid of these CLI arg classes in favor of using all pydantic model fields?

"""Dataclass with CLI arguments for `axolotl quantize` command."""

base_model: Optional[str] = field(default=None)
weight_dtype: Optional[str] = field(default=None)
activation_dtype: Optional[str] = field(default=None)
quantize_embedding: Optional[bool] = field(default=None)
group_size: Optional[int] = field(default=None)
output_dir: Optional[str] = field(default=None)


@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
11 changes: 11 additions & 0 deletions src/axolotl/cli/main.py
@@ -17,6 +17,7 @@
from axolotl.cli.args import (
EvaluateCliArgs,
PreprocessCliArgs,
QuantizeCliArgs,
TrainerCliArgs,
VllmServeCliArgs,
)
@@ -333,6 +334,16 @@ def vllm_serve(config: str, **cli_args: VllmServeCliArgs):
do_vllm_serve(config, cli_args)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(QuantizeCliArgs)
@filter_none_kwargs
def quantize(config: str, **cli_args: QuantizeCliArgs):
"""Quantize model weights using the quantization configuration in `config`."""
from axolotl.cli.quantize import do_quantize

do_quantize(config, cli_args)


@cli.command()
@click.argument("model", type=click.Path(exists=True, path_type=str))
@click.argument("output", type=click.Path(exists=False, path_type=str))