QAT #2590
Conversation
Actionable comments posted: 0
♻️ Duplicate comments (1)
tests/e2e/test_quantization.py (1)
196-206: Duplicate issue: Limited scope of Linear layer checking.
Same issue as mentioned previously - this only checks direct children for Linear layers.
Consider applying the same recursive approach suggested for lines 150-164 to ensure all Linear layers in the model hierarchy are properly quantized.
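As a hedged aside, torch's built-in `Module.modules()` iterator gives the same full-hierarchy coverage without hand-rolled recursion (a sketch, not the test's actual helper):

```python
import torch

def iter_linear_layers(model: torch.nn.Module):
    # model.modules() walks the entire module tree, so nested Linear layers
    # (e.g. inside decoder blocks) are not missed the way model.children() can miss them.
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            yield module
```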
🧹 Nitpick comments (2)
tests/e2e/test_quantization.py (2)
40-51: Consider the necessity of recreating the embedding layer.
The fixture recreates the embedding layer instead of using the original pretrained one. While this may be intentional for testing purposes, it changes the model structure from the pretrained version.
If the embedding layer recreation is not essential for the tests, consider using the original pretrained embedding:
```diff
-    with torch.device(dummy_model.device):
-        dummy_model.model.embed_tokens = torch.nn.Embedding(
-            dummy_model.model.embed_tokens.weight.shape[0],
-            dummy_model.model.embed_tokens.weight.shape[1],
-            dtype=dummy_model.model.embed_tokens.weight.dtype,
-        )
```
Alternatively, add a comment explaining why this recreation is necessary for the tests.
150-164: Limited scope of Linear layer checking may miss nested modules.
The test only checks direct children of the model for Linear layers, which may miss nested Linear layers in deeper parts of the model hierarchy.
Consider using a recursive approach to find all Linear layers:
```diff
-    for child in list(model.children()):
-        if isinstance(child, torch.nn.Linear):
+    def check_linear_layers(module):
+        for child in module.children():
+            if isinstance(child, torch.nn.Linear):
+                assert isinstance(child, FakeQuantizedLinear)
+                assert hasattr(child, "weight_fake_quantizer")
+                assert child.weight_fake_quantizer.config.dtype == weight_dtype.value
+                assert child.weight_fake_quantizer.config.group_size == group_size
+                if activation_dtype:
+                    assert hasattr(child, "activation_fake_quantizer")
+                    assert (
+                        child.activation_fake_quantizer.config.dtype
+                        == activation_dtype.value
+                    )
+                else:
+                    assert child.activation_fake_quantizer is None
+            else:
+                check_linear_layers(child)
+
+    check_linear_layers(model)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/axolotl/utils/quantization.py (1 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
- tests/e2e/test_quantization.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/utils/schemas/enums.py
- src/axolotl/utils/quantization.py
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: pre-commit
🔇 Additional comments (4)
tests/e2e/test_quantization.py (4)
88-90: Good error case coverage for invalid quantization configurations.
The test cases correctly expect ValueError exceptions for int8 and int4 quantization without group_size, which is a requirement for these quantization schemes.
219-309: Comprehensive testing of QAT callback functionality.
The tests effectively validate both scenarios: delayed fake quantization activation after N steps and immediate activation when no delay is specified. The assertions properly check the enabled/disabled state of quantizers.
316-351: Thorough testing of QAT to PTQ conversion.
The test validates that FakeQuantized modules are properly replaced with standard modules and that weights remain as nn.Parameter objects after conversion, which is the expected behavior for PTQ.
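A rough, non-authoritative sketch of the post-conversion invariant this test enforces; the `FakeQuantizedLinear` import path shown is torchao's QAT module and may differ across torchao versions:

```python
import torch
from torchao.quantization.qat.linear import FakeQuantizedLinear  # path may vary by torchao version

def assert_qat_layers_converted(model: torch.nn.Module) -> None:
    # After QAT -> PTQ conversion, no fake-quant wrappers should remain,
    # and Linear weights should still be ordinary nn.Parameter objects.
    for module in model.modules():
        assert not isinstance(module, FakeQuantizedLinear)
        if isinstance(module, torch.nn.Linear):
            assert isinstance(module.weight, torch.nn.Parameter)
```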
54-79: Well-structured parametrized test data.
The test cases provide comprehensive coverage of different quantization configurations including various dtypes (uint4, int4, int8), activation dtypes, and group sizes. The expected types and parameters are correctly specified for each configuration.
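For illustration, a minimal sketch of this kind of table-driven parametrization; the values mirror the dtypes mentioned above, but the test body and parameter names are placeholders, not the actual test:

```python
import pytest

@pytest.mark.parametrize(
    "weight_dtype, activation_dtype, group_size",
    [
        ("uint4", None, 32),
        ("int4", "int8", 32),
        ("int8", "int8", 64),
    ],
)
def test_quantization_config(weight_dtype, activation_dtype, group_size):
    # Placeholder body: the real test builds a model, applies the QAT config,
    # and checks the resulting fake-quantized layers.
    assert group_size > 0
```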
NanoCode012 left a comment:
Awesome stuff, just some minor doc comments
| """ | ||
| print_axolotl_text_art() | ||
|
|
||
| cfg = load_cfg(config) |
Do you need to pass cli_args through load_cfg to overwrite?
```python
    quantize_embedding = (
        cli_args.get("quantize_embedding") or quantize_cfg.quantize_embedding
    )
    output_dir = cli_args.get("output_dir") or cfg.output_dir
```
This should not be needed if `cli_args` is passed into `load_cfg`.
```yaml
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
```
Would this need to be sharded?
ValueError: FSDP2 only supports SHARDED_STATE_DICT for now. Please set fsdp_state_dict_type to SHARDED_STATE_DICT.
Same for the other config.
I've patched in full state dict support in this PR
That's great! Any advice when one would be used over the other?
I guess you might want to use sharded state dicts for very large models. Sharded state dicts use each rank to separately save a portion of the total model. This means we don't need to gather on rank 0, which may be slow or require significant amounts of CPU RAM. However, you would need to manually merge the sharded checkpoints afterwards - so distributed checkpoints may be better suited to checkpointing during training, when you need to be able to resume easily.
Full state dicts will gather on rank 0 and save the full model like normal. It's more convenient as you don't need to merge the weights.
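To make the trade-off concrete, a hedged sketch of the corresponding config key (the key name and values are taken from the example configs and error message quoted earlier in this thread):

```yaml
# Gather the full model on rank 0 at save time: one ready-to-use checkpoint,
# at the cost of a slower save and more CPU RAM on rank 0.
fsdp_state_dict_type: FULL_STATE_DICT

# Alternatively, each rank saves only its own shard: faster and lighter on rank 0,
# but the shards must be merged before the checkpoint can be used on its own.
# fsdp_state_dict_type: SHARDED_STATE_DICT
```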
Co-authored-by: NanoCode012 <[email protected]>
Actionable comments posted: 2
🔭 Outside diff range comments (1)
examples/qwen3/8b-qat-fsdp2.yml (1)
78-79: ⚠️ Potential issue: Initialize `special_tokens` or remove the field
An empty mapping will break parsing. Set it to an empty list or provide actual tokens. Example fix:
```diff
- special_tokens:
+ special_tokens: []
```
🧹 Nitpick comments (6)
tests/e2e/multigpu/solo/test_grpo.py (1)
120-123: Good bug fix for proper file cleanup.
The change correctly replaces `shutil.rmtree` (which is for directories) with `os.remove` (for files) when cleaning up `/tmp/vllm.log`. The try-except block appropriately handles the case where the file might not exist.
Consider using `contextlib.suppress` for cleaner code as suggested by static analysis:
```diff
- try:
-     os.remove("/tmp/vllm.log")
- except FileNotFoundError:
-     pass
+ from contextlib import suppress
+ with suppress(FileNotFoundError):
+     os.remove("/tmp/vllm.log")
```
This provides the same functionality with more concise and idiomatic Python code.
🧰 Tools
🪛 Ruff (0.11.9)
120-123: Use `contextlib.suppress(FileNotFoundError)` instead of try-except-pass. Replace with `contextlib.suppress(FileNotFoundError)` (SIM105)
examples/qwen3/8b-qat-fsdp2.yml (1)
45-54: Training hyperparameters are well-tuned
Batch sizes, optimizer, scheduler, learning rate, and mixed-precision flags (bf16/tf32) look appropriate for QAT. Consider adding a fixed `seed` parameter for reproducibility:
```diff
  learning_rate: 2e-5
+ seed: 42
  bf16: true
```
examples/llama-3/3b-qat-fsdp2.yaml (4)
9-17: Consider nesting plugin-specific flags for clarity
The `liger_*` options are top-level. To improve readability and maintainability, you could move these under a `plugin_config` section or within the `plugins` block, e.g.:
```yaml
plugins:
  - name: axolotl.integrations.liger.LigerPlugin
    config:
      rope: true
      rms_norm: true
      ...
```
38-43: Populate or remove empty wandb fields
The WandB entries (`project`, `entity`, `watch`, etc.) are blank. Either fill these with your project details or omit them to avoid runtime errors in the logging callback.
61-62: Weight decay and warmup settings
`warmup_steps: 10` and `weight_decay: 0.0` are clear. If you need regularization, consider a non-zero `weight_decay`.
78-80: Quote special token to ensure YAML parsing
The `pad_token` value contains YAML-special characters and should be quoted:
```diff
- pad_token: <|end_of_text|>
+ pad_token: "<|end_of_text|>"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- docs/quantize.qmd (1 hunks)
- examples/llama-3/3b-qat-fsdp2.yaml (1 hunks)
- examples/qwen3/8b-qat-fsdp2.yml (1 hunks)
- src/axolotl/cli/quantize.py (1 hunks)
- src/axolotl/utils/distributed.py (1 hunks)
- tests/e2e/multigpu/solo/test_grpo.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/distributed.py
🚧 Files skipped from review as they are similar to previous changes (2)
- src/axolotl/cli/quantize.py
- docs/quantize.qmd
🧰 Additional context used
🪛 Ruff (0.11.9)
tests/e2e/multigpu/solo/test_grpo.py
120-123: Use contextlib.suppress(FileNotFoundError) instead of try-except-pass
Replace with contextlib.suppress(FileNotFoundError)
(SIM105)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2, true)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
🔇 Additional comments (18)
examples/qwen3/8b-qat-fsdp2.yml (9)
1-8: Base configuration and loading flags are correctly set
The `base_model` points to the intended HF checkpoint, and disabling 8-bit/4-bit loading upfront is appropriate for QAT workflows.
9-17: Liger plugin integration looks good
All necessary Liger optimizations (rope, RMSNorm, GLU, layer norm, fused X-ent) are enabled. This aligns with the new plugin schema.
18-21: Dataset specification is valid
Using the Alpaca dataset from tatsu-lab for fine-tuning matches recommendations and the schema's `type: alpaca`.
22-32: Output, sequence, and attention configurations are correct
`output_dir`, `sequence_len`, `sample_packing`, `flex_attention`, and compile kwargs are consistent with other example configs.
33-38: QAT parameters are properly defined
Activation/weight dtypes, group size, and fake-quant step threshold align with intended quantization-aware training defaults.
39-44: Validate or remove empty WandB configuration entries
Empty `wandb_project`, `wandb_entity`, etc., may trigger validation errors or unintended behavior. Remove unused keys or populate them with valid strings. Apply one of the following diffs as appropriate:
```diff
- wandb_project:
- wandb_entity:
- wandb_watch:
- wandb_name:
- wandb_log_model:
# If you intend to use WandB, replace blank values:
+ wandb_project: your_project_name
+ wandb_entity: your_entity_name
+ wandb_watch: all
+ wandb_name: qwen3-8b-qat
+ wandb_log_model: true
```
57-60: Evaluation and save frequencies are reasonable
Running one eval and one save per epoch will give timely insights without overloading storage.
61-63: Warmup and weight decay settings look good
A short warmup and zero weight decay is standard for fine-tuning with QAT.
64-77: ❓ Verification inconclusive
Verify `Qwen3DecoderLayer` matches the model's layer class
The auto-wrap policy targets `Qwen3DecoderLayer`. Please confirm that this identifier exactly matches the transformer layer class in the Qwen3-8B implementation to avoid FSDP wrapping errors. You can grep for the correct class name in your local checkout of the model code, for example:
🏁 Script executed:
```shell
rg "class .*DecoderLayer" -A2 path/to/qwen3
```
Length of output: 99
🏁 Script executed:
```shell
#!/bin/bash
# Search for a direct definition of Qwen3DecoderLayer
rg --no-heading --color never "class .*Qwen3DecoderLayer" -A2 .
# If not found, look for any DecoderLayer subclasses
rg --no-heading --color never "class .*DecoderLayer" -A2 .
# Locate any directories or files related to 'qwen3' in the repo
find . -type d | grep -i qwen3 || echo "No qwen3 directory found"
```
Length of output: 1311
Verify `Qwen3DecoderLayer` matches the transformer layer class
I wasn't able to locate any `Qwen3DecoderLayer` definition in this repo. Please confirm that this identifier exactly matches the class name used by the Qwen3-8B model (typically in the Hugging Face transformers code). For example, you can inspect the Qwen3 implementation with:
```shell
# Find where Qwen3DecoderLayer is defined in your transformers install
grep -R "class Qwen3DecoderLayer" \
  "$(python3 -c 'import transformers; print(transformers.__path__[0])')/models/qwen3"
```
Or check the `_no_split_modules` setting in `modeling_qwen3.py` to see the exact layer class name used for auto-wrapping.
examples/llama-3/3b-qat-fsdp2.yaml (9)
5-7: Correctly disabling quantized loading for QAT
Disabling both `load_in_8bit` and `load_in_4bit` ensures the model is loaded in full precision before applying fake quantization, which is essential for QAT.
18-21: Dataset configuration looks good
Pointing to `yahma/alpaca-cleaned` with `type: alpaca` aligns with the tutorial recommendations.
22-27: Output and sampling settings are clear
`output_dir` and sample packing options (`sample_packing`, `pad_to_sequence_len`, `sequence_len`) are correctly set for efficient data processing.
28-32: Flex attention configuration is valid
Enabling `flex_attention` with the provided compile arguments is appropriate for performance.
44-53: Verify learning rate scheduler parameters
Setting `cosine_constant_lr_ratio: 0` and `cosine_min_lr_ratio: 1.0` results in no LR decay (min LR = base LR). Confirm this matches your intent; often `min_lr_ratio` < 1.0 is desired for a decay schedule.
58-60: Evaluation and save frequencies are explicit
`evals_per_epoch` and `saves_per_epoch` set to `1` is a reasonable starting point for quick feedback loops.
63-77: FSDP v2 configuration is comprehensive
The combination of `full_shard`, `auto_wrap`, activation checkpointing, offload, and transformer wrap policies aligns with best practices for large-model training.
1-4: ✅ Verification successful
Verify the base_model identifier
The model ID `meta-llama/Llama-3.2-3B` may not match an existing Hugging Face repo. Confirm the exact repository name (e.g., `meta-llama/Llama-3-2-3B`) to avoid model loading errors. Run:
🏁 Script executed:
```shell
gh api "repos/meta-llama/Llama-3.2-3B" >/dev/null && echo "Exists" || echo "Missing"
```
Length of output: 86
🏁 Script executed:
```shell
#!/bin/bash
# Install Hugging Face Hub SDK to list models under the meta-llama namespace
pip install huggingface-hub >/dev/null 2>&1
python - <<'EOF'
from huggingface_hub import list_models

# List all model IDs in the meta-llama namespace
for model in list_models(author="meta-llama", use_auth_token=False):
    print(model.modelId)
EOF
```
Length of output: 2431
No change needed: `base_model` identifier is valid
The model `meta-llama/Llama-3.2-3B` exists in the Hugging Face "meta-llama" namespace, so no update is required.
• examples/llama-3/3b-qat-fsdp2.yaml (line 1): `base_model: meta-llama/Llama-3.2-3B` [exists on HF Hub]
33-37: ✅ Verification successful
Validate QAT schema coverage
The `qat` section defines dtypes and `group_size`, but I don't see a parameter controlling when fake quantization starts (e.g., `fake_quant_start_step`). Verify against the schema to ensure all required fields are present and defaults align with your intended workflow.
🏁 Script executed:
```shell
#!/bin/bash
# Check QAT config schema for required keys
ast-grep --pattern 'struct QuantizationAwareTrainingConfig {
  $$$
  fake_quant_start_$_ ;
  $$$
}' src/axolotl/utils/schemas/quantization.py
```
Length of output: 137
🏁 Script executed:
```shell
#!/bin/bash
# Locate the QuantizationAwareTrainingConfig definition
echo "Locating QuantizationAwareTrainingConfig..."
rg -l "QuantizationAwareTrainingConfig" -g "*.py"

# If found, display its first 200 lines to inspect defined fields
SCHEMA_FILE=$(rg -l "QuantizationAwareTrainingConfig" -g "*.py" | head -n1)
if [[ -n "$SCHEMA_FILE" ]]; then
  echo -e "\n--- Schema content (first 200 lines) ---"
  sed -n '1,200p' "$SCHEMA_FILE"
else
  echo "No schema file found for QuantizationAwareTrainingConfig."
fi

# Search globally for any fake_quant-related parameters
echo -e "\nSearching for fake_quant parameters..."
rg -n "fake_quant" -g "*.py"
```
Length of output: 10520
🏁 Script executed:
```shell
#!/bin/bash
# Display the QuantizationAwareTrainingConfig schema to verify fake_quant_after_n_steps field
echo "Displaying src/axolotl/utils/schemas/quantization.py (lines 1-100):"
sed -n '1,100p' src/axolotl/utils/schemas/quantization.py
```
Length of output: 2137
Fake quant start step exists in schema
The `QATConfig` model already includes a `fake_quant_after_n_steps: int | None = None` field (default `None` ⇒ QAT begins immediately). If you want to delay when fake-quant kicks in, add it under your `qat:` block:
```yaml
qat:
  activation_dtype: int8
  weight_dtype: int4
  group_size: 32
  fake_quant_after_n_steps: <step>
```
Otherwise, omitting it is valid and QAT will start from step 0.
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/config.py (1)
1488-1525: Comprehensive QAT validation logic with good error handling.
The validator correctly checks all QAT incompatibilities and version requirements. The torch version extraction with fallback is well-implemented.
Consider combining the nested if statements for better readability:
```diff
- if (
-     data.get("fsdp")
-     and data.get("fsdp_config")
-     and str(data["fsdp_config"].get("fsdp_version")) == "2"
- ):
-     if version.parse(torch_version) < version.parse("2.7.0"):
-         raise ValueError(
-             "FSDP2 and QAT are not supported on torch version < 2.7.0"
-         )
+ if (
+     data.get("fsdp")
+     and data.get("fsdp_config")
+     and str(data["fsdp_config"].get("fsdp_version")) == "2"
+     and version.parse(torch_version) < version.parse("2.7.0")
+ ):
+     raise ValueError(
+         "FSDP2 and QAT are not supported on torch version < 2.7.0"
+     )
```
🧰 Tools
🪛 Ruff (0.11.9)
1512-1517: Use a single `if` statement instead of nested `if` statements. Combine `if` statements using `and` (SIM102)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- _quarto.yml (4 hunks)
- docs/config.qmd (1 hunks)
- src/axolotl/utils/schemas/config.py (4 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/axolotl/utils/schemas/enums.py
🚧 Files skipped from review as they are similar to previous changes (2)
- _quarto.yml
- docs/config.qmd
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/schemas/quantization.py (2)
- PTQConfig (41-64)
- QATConfig (12-38)
🪛 Ruff (0.11.9)
src/axolotl/utils/schemas/config.py
1512-1517: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.5.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.5.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.5.1, 2, true)
- GitHub Check: test-axolotl-multigpu (124, 12.4.1, 3.11, 2.6.0, vllm, 2, true)
🔇 Additional comments (3)
src/axolotl/utils/schemas/config.py (3)
47-47: LGTM!
The import is correctly added to support the new QAT and PTQ configuration schemas.
95-96: LGTM!
The new configuration fields for QAT and PTQ are properly defined as optional fields with appropriate types and defaults.
132-132: LGTM!
Good defensive programming to handle cases where `os.cpu_count()` returns `None`. The fallback to 1 and cap at 32 are reasonable choices.
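A minimal sketch of the defensive pattern described above (the variable name is illustrative, not the actual code in the PR):

```python
import os

# os.cpu_count() can return None in restricted environments,
# so fall back to 1 worker and cap the count at 32.
dataset_processes = min(32, os.cpu_count() or 1)
```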
This PR adds support for training with QAT, and also supports applying quantization techniques using torchao through a separate `quantize` CLI.
Evals:

Training run:

Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests
Chores