Merged

79 commits
d355954
migrating QAT API
SalmanMohammadi Aug 26, 2025
dd718b8
updating tests
SalmanMohammadi Aug 27, 2025
1003432
updating cli
SalmanMohammadi Aug 27, 2025
9f9ef8c
updating cli
SalmanMohammadi Aug 27, 2025
75ed197
Merge branch 'main' into qat_migration
SalmanMohammadi Aug 27, 2025
6166dbf
adding quant config
SalmanMohammadi Aug 29, 2025
ed9ef69
updating APIs
SalmanMohammadi Aug 29, 2025
c12d131
Merge branch 'main' into qat_migration
SalmanMohammadi Aug 29, 2025
563c200
linting
SalmanMohammadi Aug 29, 2025
de9b10f
Merge branch 'qat_migration' of github.com:axolotl-ai-cloud/axolotl i…
SalmanMohammadi Aug 29, 2025
218aa40
fixing tests
SalmanMohammadi Aug 29, 2025
ddeba5a
updating ptqconfig
SalmanMohammadi Aug 29, 2025
17b6051
updating quantization.py
SalmanMohammadi Aug 29, 2025
8b3e550
linting
SalmanMohammadi Aug 29, 2025
450b92f
bump ao
SalmanMohammadi Aug 29, 2025
d4f5f5a
bump ao
SalmanMohammadi Aug 29, 2025
1cada45
bump ao
SalmanMohammadi Aug 29, 2025
b218c7c
bump ao
SalmanMohammadi Aug 29, 2025
4345ae6
bump ao
SalmanMohammadi Aug 29, 2025
72f7820
fix language
SalmanMohammadi Aug 29, 2025
f4b7c26
comments
SalmanMohammadi Sep 1, 2025
38ac691
comments
SalmanMohammadi Sep 1, 2025
900d4a1
adding nvfp4
SalmanMohammadi Sep 1, 2025
5b1f478
updating tests
SalmanMohammadi Sep 1, 2025
80fb7da
fix dtype
SalmanMohammadi Sep 1, 2025
47eb791
fix dtype
SalmanMohammadi Sep 1, 2025
b0ccde1
updating tests
SalmanMohammadi Sep 2, 2025
84f6889
fixing accelerator
SalmanMohammadi Sep 2, 2025
c3f4048
fixing config
SalmanMohammadi Sep 2, 2025
029734b
adding support for push to hub in quantize
SalmanMohammadi Sep 2, 2025
78668b5
linting
SalmanMohammadi Sep 2, 2025
5c095c3
Merge branch 'main' into qat_migration
SalmanMohammadi Sep 2, 2025
e2f5dd5
updating nvfp4 config
SalmanMohammadi Sep 2, 2025
f6ec879
Merge branch 'qat_migration' of github.com:axolotl-ai-cloud/axolotl i…
SalmanMohammadi Sep 2, 2025
5bc768a
disable safetensors for push to hub
SalmanMohammadi Sep 3, 2025
154315f
force config on push to hub
SalmanMohammadi Sep 3, 2025
a5ecc05
log
SalmanMohammadi Sep 3, 2025
ae7d876
cli hub_model_id
SalmanMohammadi Sep 3, 2025
225e1d8
cli hub_model_id
SalmanMohammadi Sep 3, 2025
fac195d
adding quant strs
SalmanMohammadi Sep 3, 2025
5a49579
adding quant strs
SalmanMohammadi Sep 3, 2025
33dc44c
adding quant strs
SalmanMohammadi Sep 3, 2025
9932b4f
fix quant_type kwarg
SalmanMohammadi Sep 3, 2025
4a07a17
tkps
SalmanMohammadi Sep 3, 2025
ca1a0b7
updating conf
SalmanMohammadi Sep 3, 2025
94554a1
linting
SalmanMohammadi Sep 3, 2025
a1a3d14
Merge branch 'main' of github.com:axolotl-ai-cloud/axolotl into qat_m…
SalmanMohammadi Sep 3, 2025
d8a8c75
dont need to specify model config
SalmanMohammadi Sep 3, 2025
a0ff954
comments
SalmanMohammadi Sep 4, 2025
9a288c7
adding more aliases
SalmanMohammadi Sep 4, 2025
0ae60e1
fixing fbgemm import [skip-e2e]
SalmanMohammadi Sep 5, 2025
bfea773
updating gpu runner step
SalmanMohammadi Sep 5, 2025
c7bb62d
updating install command
SalmanMohammadi Sep 5, 2025
8da2c6b
disable default include_tkps
SalmanMohammadi Sep 5, 2025
766677c
linting
SalmanMohammadi Sep 5, 2025
c17ca49
trying extras
SalmanMohammadi Sep 5, 2025
3945aaa
2.8 only
SalmanMohammadi Sep 5, 2025
4392627
guard int4weightonly import
SalmanMohammadi Sep 8, 2025
8558ec9
guard int4weightonly import
SalmanMohammadi Sep 8, 2025
1ffac1f
only import on 2.8
SalmanMohammadi Sep 8, 2025
e62e637
stray comma
SalmanMohammadi Sep 8, 2025
6a8ed51
only attempt install on 2.8
SalmanMohammadi Sep 8, 2025
54bbc30
Merge branch 'main' into qat_migration
SalmanMohammadi Sep 8, 2025
43f7eb1
fix tests
SalmanMohammadi Sep 8, 2025
d223aeb
fix test case
SalmanMohammadi Sep 8, 2025
13b51f9
fixing tests for b200s
SalmanMohammadi Sep 8, 2025
2311bc5
comments
SalmanMohammadi Sep 9, 2025
ef53534
fixing tests
SalmanMohammadi Sep 9, 2025
835d030
Merge branch 'main' into qat_migration
SalmanMohammadi Sep 9, 2025
0e38530
fix test
SalmanMohammadi Sep 9, 2025
3faf7bd
Merge branch 'qat_migration' of github.com:axolotl-ai-cloud/axolotl i…
SalmanMohammadi Sep 9, 2025
f35c2f9
Merge branch 'main' into qat_migration
SalmanMohammadi Sep 10, 2025
0986ff0
fix group size defaults
SalmanMohammadi Sep 10, 2025
c284726
Merge branch 'qat_migration' of github.com:axolotl-ai-cloud/axolotl i…
SalmanMohammadi Sep 10, 2025
320a722
comments
SalmanMohammadi Sep 10, 2025
303439e
tests
SalmanMohammadi Sep 10, 2025
6e455bd
Merge branch 'main' into qat_migration
SalmanMohammadi Sep 11, 2025
23f0895
removing int4fp8 case
SalmanMohammadi Sep 11, 2025
c4f8e26
lint
SalmanMohammadi Sep 11, 2025
73 changes: 73 additions & 0 deletions examples/llama-3/3b-qat-fsdp2-nfvp4.yaml
@@ -0,0 +1,73 @@
base_model: meta-llama/Llama-3.2-3B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
    split: train[:95%]

output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared

sample_packing: true
sequence_len: 8192

flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16
num_epochs: 1
optimizer: adamw_torch_fused

cosine_constant_lr_ratio: 0
cosine_min_lr_ratio: 1.0
learning_rate: 2e-5
save_only_model: true
bf16: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap

fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true

special_tokens:
  pad_token: <|finetune_right_pad_id|>

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
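For reviewers unfamiliar with QAT: the `qat:` block above inserts fake-quantization ops during training so the model learns to tolerate quantization rounding error before `axolotl quantize` applies real post-training quantization. A minimal pure-Python sketch of per-group symmetric fake quantization — illustrative only; torchao's implementation operates on tensors with straight-through gradient estimators, and nvfp4 uses a 4-bit floating-point grid rather than the integer grid shown here:

```python
def fake_quantize(weights, num_bits=4, group_size=16):
    """Quantize each group of weights to a small symmetric integer grid and
    immediately dequantize, so downstream computation sees the same rounding
    error the quantized model will see at inference time."""
    qmax = 2 ** (num_bits - 1) - 1  # e.g. 7 for 4-bit symmetric
    out = []
    for i in range(0, len(weights), group_size):
        group = weights[i:i + group_size]
        scale = max(abs(w) for w in group) / qmax
        if scale == 0.0:
            out.extend(group)  # all-zero group: nothing to quantize
            continue
        out.extend(round(w / scale) * scale for w in group)
    return out

weights = [0.11, -0.52, 0.98, 0.03] * 4  # one group of 16
quantized = fake_quantize(weights)
```

Every output value lands on a multiple of the group's scale, and the worst-case error per weight is half a quantization step — which is the error QAT trains the model to absorb.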
18 changes: 5 additions & 13 deletions examples/llama-3/3b-qat-fsdp2.yaml
@@ -6,24 +6,16 @@ load_in_8bit: false
load_in_4bit: false
strict: false

-plugins:
-  - axolotl.integrations.liger.LigerPlugin
-
-liger_rope: true
-liger_rms_norm: true
-liger_glu_activation: true
-liger_layer_norm: true
-liger_fused_linear_cross_entropy: true

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
    split: train[:95%]

output_dir: ./outputs/qat_out/
dataset_prepared_path: ./outputs/qat_out/dataset_prepared

sample_packing: true

-sequence_len: 512
+sequence_len: 8192

flex_attention: true
flex_attn_compile_kwargs:
@@ -67,7 +59,7 @@ fsdp:
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
-  fsdp_cpu_ram_efficient_loading: true
+  fsdp_cpu_ram_efficient_loading: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
@@ -76,6 +68,6 @@ fsdp_config:
  fsdp_activation_checkpointing: true

special_tokens:
-  pad_token: <|end_of_text|>
+  pad_token: <|finetune_right_pad_id|>

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
2 changes: 1 addition & 1 deletion requirements.txt
@@ -64,7 +64,7 @@ langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2

-torchao==0.12.0
+torchao @ git+https://github.com/pytorch/torchao.git@13029fb6855bc19ceb8215b6dab204146908464b
schedulefree==1.4.1

axolotl-contribs-lgpl==0.0.6
28 changes: 21 additions & 7 deletions src/axolotl/cli/quantize.py
@@ -5,12 +5,16 @@
from pathlib import Path
from typing import Union

-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, TorchAoConfig, AutoConfig

from axolotl.cli.config import load_cfg
from axolotl.loaders import load_tokenizer
from axolotl.utils.logging import get_logger
-from axolotl.utils.quantization import TorchIntDType, quantize_model_for_ptq
+from axolotl.utils.quantization import (
+    get_quantization_config,
+    quantize_model,
+    TorchAOQuantDType,
+)

LOG = get_logger(__name__)

@@ -43,13 +47,13 @@ def do_quantize(
            "No quantization configuration found. Please specify either qat or quantization in your config file."
        )

-    model_path = cli_args.get("model_path") or cfg.output_dir
+    model_path = cli_args.get("base_model") or cfg.output_dir
    if weight_dtype := cli_args.get("weight_dtype"):
-        weight_dtype = TorchIntDType[weight_dtype]
+        weight_dtype = TorchAOQuantDType.from_string(weight_dtype)
    else:
        weight_dtype = quantize_cfg.weight_dtype
    if activation_dtype := cli_args.get("activation_dtype"):
-        activation_dtype = TorchIntDType[activation_dtype]
+        activation_dtype = TorchAOQuantDType.from_string(activation_dtype)
    else:
        activation_dtype = quantize_cfg.activation_dtype
    group_size = cli_args.get("group_size") or quantize_cfg.group_size
@@ -60,7 +64,11 @@

    LOG.info(f"Loading model from {model_path}...")
    tokenizer = load_tokenizer(cfg)
-    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+    config = AutoConfig.from_pretrained(model_path)
+    torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, device_map="auto", torch_dtype=torch_dtype
+    )

    LOG.info(
        f"Quantizing model with configuration: \n"
@@ -70,10 +78,16 @@
        f"\tquantize_embedding: {quantize_embedding}"
    )

-    quantize_model_for_ptq(
+    quantize_model(
        model, weight_dtype, group_size, activation_dtype, quantize_embedding
    )

+    quantization_config = TorchAoConfig(
+        get_quantization_config(weight_dtype, activation_dtype, group_size),
+        include_input_output_embeddings=quantize_embedding,
+    )
+    model.quantization_config = quantization_config
+
    LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}...")
    model.save_pretrained(
        str(Path(output_dir) / "quantized"),
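The CLI change above swaps dictionary-style enum indexing (`TorchIntDType[weight_dtype]`) for a `from_string` constructor, which pairs with the "adding more aliases" commits: user input can be normalized and mapped through aliases before lookup, with a readable error on bad values. A sketch of that pattern with a hypothetical enum — the member names and aliases here are assumptions for illustration, not axolotl's actual definitions:

```python
from enum import Enum


class QuantDType(Enum):
    """Illustrative stand-in for TorchAOQuantDType."""

    INT4 = "int4"
    INT8 = "int8"
    NVFP4 = "nvfp4"

    @classmethod
    def from_string(cls, value: str) -> "QuantDType":
        # Accept a few user-facing aliases in addition to canonical values,
        # and fail with a descriptive error instead of a bare KeyError.
        aliases = {"int4wo": cls.INT4, "int8wo": cls.INT8, "fp4": cls.NVFP4}
        normalized = value.strip().lower()
        if normalized in aliases:
            return aliases[normalized]
        try:
            return cls(normalized)  # lookup by member value, e.g. "int8"
        except ValueError:
            raise ValueError(
                f"Unknown quantization dtype {value!r}; expected one of "
                f"{[m.value for m in cls] + sorted(aliases)}"
            ) from None
```

Compared with `Enum[name]` indexing, the constructor gives one place to add aliases and keeps CLI error messages actionable.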
19 changes: 6 additions & 13 deletions src/axolotl/train.py
@@ -30,11 +30,7 @@
    fix_untrained_tokens,
)
from axolotl.integrations.base import PluginManager
-from axolotl.loaders import (
-    ModelLoader,
-    load_processor,
-    load_tokenizer,
-)
+from axolotl.loaders import load_processor, load_tokenizer, ModelLoader
from axolotl.utils.ctx_managers.sequence_parallel import SequenceParallelContextManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import cleanup_distributed
@@ -234,16 +230,15 @@ def save_trained_model(

    # handle QAT
    if cfg.qat:
-        from axolotl.utils.quantization import convert_qat_model_for_ptq
+        from axolotl.utils.quantization import convert_qat_model

        LOG.info("Processing QAT model for saving...")
-        convert_qat_model_for_ptq(
+        convert_qat_model(
            model,
            quantize_embedding=cfg.qat.quantize_embedding,
        )
        LOG.info(
-            "QAT modules have been converted for PTQ. Please ensure you quantize "
-            "your model weights with `axolotl quantize`."
+            "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
+            " with the same config which you used for training."
        )
    # Handle ReLoRA early return case
    if cfg.relora:
@@ -337,9 +332,7 @@ def save_trained_model(

    if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
        # TODO: add integration support so this can be implemented completely within the plugin
-        from axolotl.integrations.llm_compressor.utils import (
-            save_compressed_model,
-        )
+        from axolotl.integrations.llm_compressor.utils import save_compressed_model

        save_compressed_model(
            model=model,
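For context on the `convert_qat_model` call above: after training, the fake-quantization wrapper modules are swapped back to plain layers carrying the trained (quantization-aware) weights, so `axolotl quantize` can then apply real quantization. A toy sketch of that prepare-then-convert swap using stand-in classes — not torchao's actual module types or axolotl's real traversal logic:

```python
class FakeQuantLinear:
    """Toy stand-in for a QAT layer that fake-quantizes during forward."""

    def __init__(self, weight):
        self.weight = weight


class Linear:
    """Plain layer the QAT module is swapped back to after training."""

    def __init__(self, weight):
        self.weight = weight


def convert_qat_model(layers):
    """Replace every QAT layer with its plain equivalent in place,
    keeping the trained weights; non-QAT layers pass through untouched."""
    for name, layer in list(layers.items()):
        if isinstance(layer, FakeQuantLinear):
            layers[name] = Linear(layer.weight)
    return layers


# A "model" as a flat name->layer mapping, for illustration only.
model = {"q_proj": FakeQuantLinear([0.1, 0.2]), "norm": object()}
model = convert_qat_model(model)
```

After conversion no QAT wrappers remain, which is why the log message above directs users to run `axolotl quantize` as the separate, final quantization step.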