Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions examples/qat_nvfp4/Gemma3-12B_baseline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/qat_out_gemma_baseline/ # separate dir so the baseline run does not overwrite the _qat run's outputs

sequence_len: 8096 # NOTE(review): likely a typo for 8192 (not a power of two) — confirm intended context length
sample_packing: true
flash_attention: true

# NOTE(review): this "_baseline" config still enables QAT — if the baseline is
# meant to be a non-QAT reference run for comparison, delete this block.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1
save_strategy: "no"

# evals_per_epoch: 1
# saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD # NOTE(review): FSDP1-era key; with fsdp_version 2 resharding is set by reshard_after_forward — confirm this is not ignored
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
72 changes: 72 additions & 0 deletions examples/qat_nvfp4/Gemma3-12B_qat.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
base_model: google/gemma-3-12b-it
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca

output_dir: ./outputs/qat_out_gemma/

sequence_len: 8096 # NOTE(review): likely a typo for 8192 (not a power of two) — confirm intended context length
sample_packing: true
flash_attention: true

qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 4e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1

evals_per_epoch: 1
saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD # NOTE(review): FSDP1-era key; with fsdp_version 2 resharding is set by reshard_after_forward — confirm this is not ignored
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
73 changes: 73 additions & 0 deletions examples/qat_nvfp4/Math-Gemma3-12B_baseline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_gemma_baseline/ # separate dir so the baseline run does not overwrite the _qat run's outputs

sequence_len: 4096
sample_packing: true
flash_attention: true

# NOTE(review): this "_baseline" config still enables QAT — if the baseline is
# meant to be a non-QAT reference run for comparison, delete this block.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1
save_strategy: "no"

# evals_per_epoch: 1
# saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD # NOTE(review): FSDP1-era key; with fsdp_version 2 resharding is set by reshard_after_forward — confirm this is not ignored
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
73 changes: 73 additions & 0 deletions examples/qat_nvfp4/Math-Gemma3-12B_qat.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
base_model: google/gemma-3-12b-it
# Math finetuning configuration for Gemma3-12B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_gemma/

sequence_len: 4096
sample_packing: true
flash_attention: true

qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 8

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 3e-5

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1
save_strategy: "no" # NOTE(review): this QAT run never saves the trained model, unlike Gemma3-12B_qat.yml which saves per epoch — confirm intended

# evals_per_epoch: 1
# saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD # NOTE(review): FSDP1-era key; with fsdp_version 2 resharding is set by reshard_after_forward — confirm this is not ignored
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
74 changes: 74 additions & 0 deletions examples/qat_nvfp4/Math-Gemma3-27B_baseline.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
base_model: google/gemma-3-27b-it
# Math finetuning configuration for Gemma3-27B
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

plugins:
- axolotl.integrations.liger.LigerPlugin

liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
seed: 42
chat_template: gemma3
datasets:
  - path: AI-MO/NuminaMath-CoT
    type: chat_template

output_dir: ./outputs/qat_out_math_gemma27/

sequence_len: 4096
sample_packing: true
flash_attention: true

# NOTE(review): this "_baseline" config still enables QAT — if the baseline is
# meant to be a non-QAT reference run for comparison, delete this block.
qat:
  activation_dtype: nvfp4
  weight_dtype: nvfp4
  group_size: 16 # only group_size of 16 is supported with nvfp4

wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 16 # NOTE(review): twice the 12B math config's micro_batch_size despite the larger model — verify this fits GPU memory

num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 5e-6
eta_min: 7e-7 # NOTE(review): "eta_min" does not appear to be a recognized top-level axolotl option; use lr_scheduler_kwargs or cosine_min_lr_ratio to set the cosine floor

bf16: true
tf32: true

resume_from_checkpoint:
logging_steps: 1
save_strategy: "no"

# evals_per_epoch: 1
# saves_per_epoch: 1

warmup_ratio: 0.1
weight_decay: 0.0
fsdp_version: 2

fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Gemma3DecoderLayer
  state_dict_type: FULL_STATE_DICT
  sharding_strategy: FULL_SHARD # NOTE(review): FSDP1-era key; with fsdp_version 2 resharding is set by reshard_after_forward — confirm this is not ignored
  reshard_after_forward: true
  activation_checkpointing: true

special_tokens:

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
Loading