Skip to content

Conversation

@SalmanMohammadi
Copy link
Contributor

@SalmanMohammadi SalmanMohammadi commented Aug 26, 2025

The QAT API in torchao has recently been updated pytorch/ao#2629.

Summary by CodeRabbit

  • New Features

    • Unified quantization workflow with expanded dtypes (int4, int8, FP8, NVFP4) and streamlined QAT/PTQ APIs.
    • New CLI flag to push quantized models to the Hub with an auto-suffixed model name.
    • Baseten training template now auto-selects GPU type (H100/B200).
  • Documentation

    • Added note explaining Hub name suffixing for quantized models.
  • Examples

    • New and updated Llama 3 QAT configs (e.g., longer context, flash attention).
  • Behavior Changes

    • Tokens-per-second metric disabled by default.
  • Chores

    • Dependency bump to torchao 0.13.0; optional fbgemm-gpu extra and CI usage.
  • Tests

    • E2E tests migrated to new quantization APIs and dtypes; added environment guards.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 26, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Updates CI matrices to include fbgemm-gpu extras. Adds packaging extra and bumps torchao. Overhauls quantization/QAT API, enums, schemas, and CLI quantize flow with hub push support. Adjusts training defaults and cloud accelerator selection. Expands docs and examples (including NVFP4). Updates and gates e2e tests accordingly.

Changes

Cohort / File(s) Summary
CI workflows
.github/workflows/multi-gpu-e2e.yml, .github/workflows/tests.yml
Set axolotl_extras to fbgemm-gpu for one CUDA 12.8.1/PyTorch 2.8.0 matrix entry; no other matrix changes.
Packaging and dependencies
requirements.txt, setup.py
Bump torchao 0.12.0 -> 0.13.0. Add extras_require entry "fbgemm-gpu": ["fbgemm-gpu-genai>=1.2.0"].
CLI quantization
src/axolotl/cli/quantize.py, src/axolotl/cli/args.py
Add QuantizeCliArgs.hub_model_id. Switch to TorchAO-based get_quantization_config/quantize_model, attach config, load with AutoConfig torch_dtype, support hub push with suffixed model id.
Quantization core and schemas
src/axolotl/utils/quantization.py, src/axolotl/utils/schemas/enums.py, src/axolotl/utils/schemas/quantization.py
Replace TorchIntDType with TorchAOQuantDType; introduce get_quantization_config, quantize_model, convert_qat_model; add quantization_config_to_str; add validate_ao_dtype; NVFP4/Float8 support under version gates; update signatures and behaviors accordingly.
Training and cloud templates
src/axolotl/train.py, src/axolotl/core/training_args_base.py, src/axolotl/cli/cloud/baseten/template/train_sft.py
Change include_tkps default to False. Use convert_qat_model in save path and adjusted logging; import reformat. Cloud template selects accelerator dynamically (h100/b200) with validation.
Docs
docs/quantize.qmd
Add note: when pushing to hub, hub_model_id is suffixed with quantization schema; include example mapping.
Examples
examples/llama-3/3b-qat-fsdp2.yaml, examples/llama-3/3b-qat-fsdp2-nvfp4.yaml
Update QAT example: dataset split, prepared path, disable sample_packing, sequence_len 8192, enable flash attention, adjust fsdp setting, pad token change, add weight_dtype/group_size. Add new NVFP4 QAT config for Llama-3.2-3B with Liger, sequence_len 8192, optimizer/training params.
Tests
tests/e2e/test_quantization.py, tests/e2e/test_qat.py, tests/e2e/utils.py
Migrate tests to new API/enums; add NVFP4 and QAT conversion tests; change QAT weight_dtype to int4 in two tests; add decorators to gate by torch/CUDA capabilities.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • djsaunde
  • winglian
  • NanoCode012
✨ Finishing touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch qat_migration

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@SalmanMohammadi SalmanMohammadi marked this pull request as ready for review August 29, 2025 16:30
@github-actions
Copy link
Contributor

github-actions bot commented Aug 29, 2025

📖 Documentation Preview: https://68c2a4da72a47b10ec2e66d6--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit c4f8e26

@SalmanMohammadi SalmanMohammadi changed the title [WIP] Fix/update QAT Migrate QAT API; fix axolotl quantize for QAT-ed models Aug 29, 2025
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/axolotl/cli/quantize.py (2)

50-63: Prefer cfg.base_model over cfg.output_dir when base_model is not passed.

Falling back to cfg.output_dir can point at a training directory instead of a source model. Use cfg.base_model if available, else cfg.output_dir.

-    model_path = cli_args.get("base_model") or cfg.output_dir
+    model_path = (
+        cli_args.get("base_model")
+        or getattr(cfg, "base_model", None)
+        or cfg.output_dir
+    )

97-102: Remove unsupported kwargs safe_serialization and progressbar.
PreTrainedTokenizerBase.save_pretrained only accepts
(save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None) and will error on unexpected keywords (huggingface.co).

-    tokenizer.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-        save_jinja_files=cfg.tokenizer_save_jinja_files,
-    )
+    tokenizer.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
+    )
tests/e2e/test_quantization.py (1)

38-49: Invalid context manager: use device=... when creating Embedding.

with torch.device(dummy_model.device): is not a valid context manager. Pass device= to the layer instead to keep tensors on the model’s device.

-    with torch.device(dummy_model.device):
-        dummy_model.model.embed_tokens = torch.nn.Embedding(
-            dummy_model.model.embed_tokens.weight.shape[0],
-            dummy_model.model.embed_tokens.weight.shape[1],
-            dtype=dummy_model.model.embed_tokens.weight.dtype,
-        )
+    dummy_model.model.embed_tokens = torch.nn.Embedding(
+        dummy_model.model.embed_tokens.weight.shape[0],
+        dummy_model.model.embed_tokens.weight.shape[1],
+        device=dummy_model.device,
+        dtype=dummy_model.model.embed_tokens.weight.dtype,
+    )
🧹 Nitpick comments (6)
setup.py (1)

165-167: Scope the fbgemm extra to Linux.

Avoid accidental installs on macOS/Windows when users opt into the extra.

-    "fbgemm": [
-        "fbgemm-gpu-genai>=1.2.0",
-    ],
+    "fbgemm": [
+        "fbgemm-gpu-genai>=1.2.0 ; sys_platform == 'linux'",
+    ],
src/axolotl/train.py (1)

240-242: Polish the user-facing log message (grammar).

-            "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
-            " with the same config which you used for training."
+            "QAT note: after fine-tuning with QAT, run `axolotl quantize` using the same training config."
src/axolotl/cli/quantize.py (1)

26-33: Fix docstring typo.

"Quantizes a model's model's weights" → "Quantizes a model's weights".

-    Quantizes a model's model's weights
+    Quantizes a model's weights
src/axolotl/utils/schemas/quantization.py (2)

17-24: Docstrings outdated; include float8.

Descriptions still list only int4/int8. Update to include float8_e4m3fn to match the enum and validators.

-        description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
+        description='Fake quantization layout for activations. Valid options: "int4", "int8", "float8_e4m3fn".',
@@
-        description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
+        description='Fake quantization layout for weights. Valid options: "int4", "int8", "float8_e4m3fn".',

55-63: PTQConfig description mentions deprecated uintX variants.

Remove uintX from the description to reflect the current TorchAOQuantDType.

-        description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
+        description='Fake quantization layout for weights. Valid options: "int4", "int8", "float8_e4m3fn".',
@@
-        description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
+        description='Fake quantization layout for activations. Valid options: "int4", "int8", "float8_e4m3fn".',
src/axolotl/utils/quantization.py (1)

53-64: Neutralize “QAT” wording in PTQ builder errors.

This helper is used for both PTQ and QAT; error strings should not suggest QAT-only.

-            "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT. "
+            "Int4DynamicActivationInt4WeightConfig is not supported by torchao. "
@@
-            "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT. "
+            "Int8DynamicActivationInt8WeightConfig is not supported by torchao. "
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 5b6ec28 and 1cada45.

📒 Files selected for processing (9)
  • requirements-dev.txt (1 hunks)
  • requirements.txt (1 hunks)
  • setup.py (1 hunks)
  • src/axolotl/cli/quantize.py (4 hunks)
  • src/axolotl/train.py (3 hunks)
  • src/axolotl/utils/quantization.py (5 hunks)
  • src/axolotl/utils/schemas/enums.py (1 hunks)
  • src/axolotl/utils/schemas/quantization.py (3 hunks)
  • tests/e2e/test_quantization.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/schemas/quantization.py (1)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-11)
tests/e2e/test_quantization.py (4)
src/axolotl/utils/callbacks/qat.py (1)
  • QATCallback (33-48)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-11)
src/axolotl/utils/schemas/quantization.py (1)
  • QATConfig (12-47)
tests/e2e/utils.py (1)
  • require_torch_2_6_0 (69-78)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.8.0, 2, true)
🔇 Additional comments (5)
src/axolotl/train.py (4)

33-33: Import reorder is fine.


233-233: Rename alignment to convert_qat_model looks correct.


235-238: Ensure we convert the exact model instance that is saved (trainer.model) or confirm aliasing.

If trainer.model wraps/aliases model, this is fine; otherwise convert the trainer’s model to avoid divergence with the saved state.

Suggested change if they’re not the same object:

-        convert_qat_model(
-            model,
+        convert_qat_model(
+            trainer.model,
             quantize_embedding=cfg.qat.quantize_embedding,
         )

336-336: Import consolidation LGTM.

src/axolotl/utils/quantization.py (1)

142-151: Propagate optional group_size through to config builder.

No code change needed if you adopt the signature change above; this is a heads-up to ensure None is accepted for float8 combos.

@SalmanMohammadi SalmanMohammadi changed the title Migrate QAT API; fix axolotl quantize for QAT-ed models [WIP] Migrate QAT API; fix axolotl quantize for QAT-ed models Aug 30, 2025
@SalmanMohammadi SalmanMohammadi marked this pull request as draft August 30, 2025 10:34
@SalmanMohammadi SalmanMohammadi changed the title [WIP] Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4 Sep 8, 2025
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
docs/quantize.qmd (1)

12-16: Fix punctuation and wording in the callout

Minor doc nit: add a space after the comma and simplify the sentence.

-We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
+We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.
src/axolotl/cli/quantize.py (2)

98-103: Invalid kwargs to save_pretrained

progressbar isn’t a supported kwarg for PreTrainedModel.save_pretrained on many releases. Remove it to prevent runtime errors.

-    model.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-    )
+    model.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        safe_serialization=False,
+    )

103-108: Tokenizer.save_pretrained: drop unsupported args

Tokenizer save_pretrained does not accept safe_serialization or progressbar.

-    tokenizer.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-        save_jinja_files=cfg.tokenizer_save_jinja_files,
-    )
+    tokenizer.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
+    )
tests/e2e/test_quantization.py (1)

47-53: Fix invalid device context manager; ensure replacement embedding is created on the right device.

torch.device(...) is not a context manager; this will raise at runtime. Create the new Embedding directly on the original weight’s device.

-    with torch.device(dummy_model.device):
-        dummy_model.model.embed_tokens = torch.nn.Embedding(
-            dummy_model.model.embed_tokens.weight.shape[0],
-            dummy_model.model.embed_tokens.weight.shape[1],
-            dtype=dummy_model.model.embed_tokens.weight.dtype,
-        )
+    embed_w = dummy_model.model.embed_tokens.weight
+    dummy_model.model.embed_tokens = torch.nn.Embedding(
+        embed_w.shape[0],
+        embed_w.shape[1],
+        dtype=embed_w.dtype,
+        device=embed_w.device,
+    )
♻️ Duplicate comments (5)
requirements.txt (1)

67-67: torchao 0.13.0 bump LGTM; resolves prior VCS URL concern.

Good move switching from a VCS ref to the released package and pinning to 0.13.0.

Given CI tests Torch 2.6.0/2.7.1/2.8.0, double-check torchao 0.13.0 compatibility across that matrix:

#!/bin/bash
# Verify torchao 0.13.0 imports with the three torch versions.
set -euo pipefail
for TORCH in 2.6.0 2.7.1 2.8.0; do
  python - <<PY
import subprocess, sys
subprocess.check_call([sys.executable, "-m", "pip", "install", f"torch==${TORCH}", "torchao==0.13.0", "-q"])
__import__("torchao")
print("OK torch", "${TORCH}")
PY
done
src/axolotl/utils/schemas/enums.py (1)

8-12: Make dtype enum string-backed with a resolver; avoid mixing torch.dtypes and strings.

Storing torch.int4/float8 directly can be None on some builds; mixing with "nvfp4" (str) also breaks serialization/validation. Prefer a string-backed enum + resolve().

-class TorchAOQuantDType(Enum):
-    int4 = torch.int4
-    int8 = torch.int8
-    float8_e4m3fn = torch.float8_e4m3fn
-    nvfp4 = "nvfp4"
+class TorchAOQuantDType(str, Enum):
+    int4 = "int4"
+    int8 = "int8"
+    float8_e4m3fn = "float8_e4m3fn"
+    nvfp4 = "nvfp4"
+
+    @classmethod
+    def from_string(cls, s: str) -> "TorchAOQuantDType":
+        try:
+            return cls(s)
+        except ValueError as e:
+            raise ValueError(f"Unknown TorchAOQuantDType: {s!r}") from e
+
+    def resolve(self) -> torch.dtype:
+        dt = getattr(torch, self.value, None)
+        if dt is None:
+            raise RuntimeError(
+                f"Requested quant dtype {self.value} is not supported by this torch build."
+            )
+        return dt

Follow-up: ensure downstream code calls .resolve() right before interacting with torch APIs, and never persists raw torch.dtypes in configs.

src/axolotl/cli/quantize.py (2)

67-73: Normalize hf config torch_dtype before passing to from_pretrained

AutoConfig.torch_dtype can be a string (e.g., "bfloat16"); pass an actual torch.dtype.

+    import torch
@@
-    config = AutoConfig.from_pretrained(model_path)
-    torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
+    hf_config = AutoConfig.from_pretrained(model_path)
+    torch_dtype = getattr(hf_config, "torch_dtype", None)
+    if isinstance(torch_dtype, str):
+        torch_dtype = getattr(torch, torch_dtype, None)

91-96: TorchAoConfig kwargs: drop include_input_output_embeddings

Embedding quantization is already controlled in quantize_model. This kwarg isn’t universally supported and can break on some transformers versions.

-    ao_config = TorchAoConfig(
-        quant_type=quantization_config,
-        include_input_output_embeddings=quantize_embedding,
-    )
+    ao_config = TorchAoConfig(quant_type=quantization_config)
src/axolotl/utils/quantization.py (1)

64-71: Require explicit, positive group_size for int4 weight-only; don’t default to -1.

Tests expect a ValueError when group_size is missing for int4 WO. Defaulting to -1 defers the error and is ambiguous.

-        if weight_dtype == TorchAOQuantDType.int4:
-            from torchao.quantization.quant_api import Int4WeightOnlyConfig
-
-            return Int4WeightOnlyConfig(group_size=group_size or -1, version=2)
+        if weight_dtype == TorchAOQuantDType.int4:
+            from torchao.quantization.quant_api import Int4WeightOnlyConfig
+            if not group_size or group_size <= 0:
+                raise ValueError("group_size must be a positive int for int4 weight-only quantization.")
+            return Int4WeightOnlyConfig(group_size=group_size, version=2)
🧹 Nitpick comments (18)
src/axolotl/core/training_args_base.py (1)

52-57: Confirm intent: include_tkps default flipped to False.

This changes default training metrics output. If intentional, note it in release notes and CLI/docs to avoid surprises in dashboards.

src/axolotl/cli/args.py (1)

118-119: Add help metadata for hub_model_id.

Keep CLI help consistent with other fields and explain the suffixing behavior.

-    hub_model_id: Optional[str] = field(default=None)
+    hub_model_id: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "If set, push quantized artifacts to the Hub under this model ID. "
+                "A suffix indicating the quantization scheme is appended automatically."
+            )
+        },
+    )
tests/e2e/utils.py (3)

93-103: Fix docstring to match the check (2.8.0).

The docstring says 2.7.0 but the predicate enforces 2.8.0.

-    Decorator marking a test that requires torch >= 2.7.0
+    Decorator marking a test that requires torch >= 2.8.0

143-150: Message clarity: target is compute capability (SM) ≥ 10.0.

Skip reason is clear; consider guarding get_device_capability() behind is_available (already done) — good.

-    return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
+    return unittest.skipUnless(is_sm_ge_100, "test requires compute capability (SM) >= 10.0")(test_case)

152-159: Name/message mismatch: function checks SM 8.9, not CUDA version.

Either rename to requires_sm_ge_89 or adjust the message to avoid “cuda>=8.9” ambiguity.

-    return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
+    return unittest.skipUnless(
+        is_cuda_ge_8_9, "test requires compute capability (SM) >= 8.9"
+    )(test_case)
tests/e2e/test_qat.py (2)

43-49: Switch to int4: confirm group_size=8 is valid for your torchao path

Changing qat.weight_dtype to "int4" is fine; please confirm group size 8 is supported for this config to avoid hidden dequant drift vs training. If not, bump to 32 (or the library’s required value).

You may also set a fixed seed in cfg (e.g., "seed": 1) to reduce flakiness across CI runs.


111-116: DPO QAT config mirrors the above—same group_size validation applies

Same note as above re: "int4" + group_size: 8. Ensure this combo is accepted by torchao for DPO.

docs/quantize.qmd (1)

56-61: Clarify suffix mapping and scope

Good callout. Consider adding a brief note that the suffix reflects the quantization schema (e.g., weight dtype and layout) and is only applied when pushing to the Hub, not to local directories.

examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (2)

26-28: Consider explicitly disabling sample packing

To mirror the other example and avoid surprises with long context + FA, explicitly set sample_packing: false.

 sequence_len: 8192
 flash_attention: true
+sample_packing: false

42-51: Batch size at seq_len 8192: validate memory headroom

micro_batch_size: 64 at 8k context is aggressive even on large GPUs. If this example is meant to “just work,” consider a safer default (e.g., 4–8) and annotate that users can scale up.

examples/llama-3/3b-qat-fsdp2.yaml (1)

68-69: Toggle of fsdp_cpu_ram_efficient_loading: double-check tradeoff

You flipped this to false. Confirm this doesn’t regress load-time memory on CPU-constrained runners; if it was to fix perf, consider a short comment explaining why.

src/axolotl/cli/quantize.py (2)

97-97: Ruff RUF010: prefer conversion flag in f-strings

Use !s instead of wrapping in str(...).

-    LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.")
+    LOG.info(f"Saving quantized model to: {(Path(output_dir) / 'quantized')!s}.")

119-119: Ruff RUF010: same nit here

-    LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.")
+    LOG.info(f"Quantized model saved to: {(Path(output_dir) / 'quantized')!s}.")
tests/e2e/test_quantization.py (2)

300-303: Assertion doesn’t verify “quantized weights” as the comment claims.

isinstance(..., nn.Parameter) is always true pre/post, so it doesn’t prove conversion or quantization.

Consider asserting the post-conversion tensor type (e.g., AffineQuantizedTensor/LinearActivationQuantizedTensor or NVFP4 types if applicable), or explicitly verify fake-quant modules were removed and expected PTQ tensor subclasses are present. If the intended behavior after convert_qat_model is “de-fake-quant to standard modules with quantized tensors,” align the assertion accordingly.


180-191: Optional: broaden Linear checks beyond top-level children.

Iterating model.children() only catches top-level modules (often just lm_head). If you intend to validate deeper linears inside the backbone, iterate model.modules() and skip containers.

-for child in list(model.children()):
+for child in model.modules():
     if isinstance(child, torch.nn.Linear):
         ...
src/axolotl/utils/quantization.py (3)

26-42: Avoid bare except; narrow the import guards.

Use specific exceptions for optional imports to satisfy linting and avoid masking unrelated errors.

-    try:
+    try:
         from torchao.prototype.mx_formats import NVFP4InferenceConfig
-
-        quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
-    except:
-        pass
+        quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4"
+    except (ImportError, OSError):
+        pass
@@
-    try:
+    try:
         from torchao.quantization.quant_api import Int4WeightOnlyConfig
-
-        quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
-    except:
-        pass
+        quantization_config_to_str[Int4WeightOnlyConfig] = "int4"
+    except (ImportError, OSError):
+        pass

66-66: Clarify error message; avoid QAT-specific wording in a shared builder.

get_quantization_config is used for PTQ and as a base for QAT. Rephrase to avoid confusion.

-            raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.")
+            raise ValueError("Int8 weight-only quantization is not supported in this flow.")

49-63: Docstring scope nit: mention both PTQ and QAT base configs.

This function also returns base configs wrapped by QAT later; adjust wording for accuracy.

-    This function is used to build a post-training quantization config.
+    Build a quantization config (used directly for PTQ, or as the base for QAT).
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b5d4c7f and d223aeb.

📒 Files selected for processing (18)
  • .github/workflows/multi-gpu-e2e.yml (1 hunks)
  • .github/workflows/tests.yml (1 hunks)
  • docs/quantize.qmd (1 hunks)
  • examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (1 hunks)
  • examples/llama-3/3b-qat-fsdp2.yaml (3 hunks)
  • requirements.txt (1 hunks)
  • setup.py (1 hunks)
  • src/axolotl/cli/args.py (1 hunks)
  • src/axolotl/cli/cloud/baseten/template/train_sft.py (1 hunks)
  • src/axolotl/cli/quantize.py (4 hunks)
  • src/axolotl/core/training_args_base.py (1 hunks)
  • src/axolotl/train.py (3 hunks)
  • src/axolotl/utils/quantization.py (5 hunks)
  • src/axolotl/utils/schemas/enums.py (1 hunks)
  • src/axolotl/utils/schemas/quantization.py (3 hunks)
  • tests/e2e/test_qat.py (2 hunks)
  • tests/e2e/test_quantization.py (7 hunks)
  • tests/e2e/utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
src/axolotl/cli/quantize.py (4)
src/axolotl/loaders/tokenizer.py (1)
  • load_tokenizer (122-308)
src/axolotl/utils/quantization.py (2)
  • get_quantization_config (44-108)
  • quantize_model (111-146)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-12)
src/axolotl/core/trainers/base.py (1)
  • push_to_hub (527-537)
src/axolotl/train.py (4)
src/axolotl/loaders/processor.py (1)
  • load_processor (17-56)
src/axolotl/loaders/model.py (1)
  • ModelLoader (67-889)
src/axolotl/utils/quantization.py (1)
  • convert_qat_model (193-207)
src/axolotl/integrations/llm_compressor/utils.py (1)
  • save_compressed_model (11-40)
src/axolotl/utils/schemas/quantization.py (1)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-12)
tests/e2e/test_quantization.py (5)
src/axolotl/utils/callbacks/qat.py (1)
  • QATCallback (33-48)
src/axolotl/utils/quantization.py (4)
  • get_quantization_config (44-108)
  • prepare_model_for_qat (149-190)
  • convert_qat_model (193-207)
  • quantize_model (111-146)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-12)
src/axolotl/utils/schemas/quantization.py (1)
  • QATConfig (26-53)
tests/e2e/utils.py (3)
  • require_torch_2_8_0 (93-102)
  • requires_sm_ge_100 (143-149)
  • requires_cuda_ge_8_9 (152-158)
src/axolotl/utils/quantization.py (3)
src/axolotl/utils/schemas/quantization.py (1)
  • QATConfig (26-53)
src/axolotl/utils/schemas/enums.py (1)
  • TorchAOQuantDType (8-12)
tests/e2e/test_quantization.py (1)
  • model (41-54)
🪛 Ruff (0.12.2)
src/axolotl/cli/cloud/baseten/template/train_sft.py

53-55: Avoid specifying long messages outside the exception class

(TRY003)

src/axolotl/cli/quantize.py

97-97: Use explicit conversion flag

Replace with conversion flag

(RUF010)


119-119: Use explicit conversion flag

Replace with conversion flag

(RUF010)

src/axolotl/utils/schemas/quantization.py

21-23: Avoid specifying long messages outside the exception class

(TRY003)

tests/e2e/test_quantization.py

138-138: Use of assert detected

(S101)


146-146: Use of assert detected

(S101)


207-207: Use of assert detected

(S101)


208-208: Use of assert detected

(S101)


241-241: Use of assert detected

(S101)


242-242: Use of assert detected

(S101)


243-243: Use of assert detected

(S101)


248-248: Use of assert detected

(S101)


255-255: Use of assert detected

(S101)


256-256: Use of assert detected

(S101)


257-257: Use of assert detected

(S101)


259-259: Use of assert detected

(S101)


261-261: Use of assert detected

(S101)


262-262: Use of assert detected

(S101)


267-267: Use of assert detected

(S101)


288-288: Use of assert detected

(S101)


289-289: Use of assert detected

(S101)


297-297: Use of assert detected

(S101)


298-298: Use of assert detected

(S101)


301-301: Use of assert detected

(S101)


302-302: Use of assert detected

(S101)

src/axolotl/utils/quantization.py

31-31: Do not use bare except

(E722)


31-32: try-except-pass detected, consider logging the exception

(S110)


40-40: Do not use bare except

(E722)


40-41: try-except-pass detected, consider logging the exception

(S110)


66-66: Avoid specifying long messages outside the exception class

(TRY003)


75-77: Avoid specifying long messages outside the exception class

(TRY003)


82-84: Avoid specifying long messages outside the exception class

(TRY003)


104-104: Avoid specifying long messages outside the exception class

(TRY003)


106-108: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (17)
setup.py (1)

165-166: Extras entry looks good; verify runner compatibility.

The new "fbgemm-gpu" extra maps to fbgemm-gpu-genai>=1.2.0. Ensure the GPU runners used in CI (B200/H100) meet the runtime constraints for that wheel.

src/axolotl/cli/cloud/baseten/template/train_sft.py (1)

60-60: LGTM after accelerator normalization.

Using the resolved accelerator in AcceleratorSpec is correct.

tests/e2e/utils.py (1)

16-17: LGTM: import relocation is harmless.

.github/workflows/tests.yml (1)

307-307: No references to AXOLOTL_EXTRAS or install logic found in cicd/e2e_tests.py. Verify if this file delegates to Docker builds (which already handle extras) or manually installs axolotl[$AXOLOTL_EXTRAS]; otherwise, the extras value may never be applied.

.github/workflows/multi-gpu-e2e.yml (1)

47-49: fbgemm-gpu extra: check runner compatibility and Modal wiring

Nice addition. Please verify the Modal image/env actually installs the fbgemm-gpu extra for CUDA 12.8 / PyTorch 2.8.0 and that cicd.multigpu respects AXOLOTL_EXTRAS; otherwise this matrix leg may silently run without the intended kernels.

Optionally add a quick echo for visibility:

       - name: Update env vars
         run: |
@@
           echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
           echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
+      - name: Show selected extras
+        run: echo "Using AXOLOTL_EXTRAS=$AXOLOTL_EXTRAS"
src/axolotl/train.py (3)

233-242: QAT conversion path looks correct

Switch to convert_qat_model(...) and the updated user-facing log is appropriate for the new flow.


335-336: Import style cleanup—LGTM

No functional change; matches surrounding style.


33-33: Re-export of loaders verified
src/axolotl/loaders/init.py already imports ModelLoader, load_processor, and load_tokenizer; no changes needed.

examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (1)

29-33: NVFP4 config: LGTM; group_size=16 constraint captured

Config matches NVFP4 expectations.

examples/llama-3/3b-qat-fsdp2.yaml (3)

27-30: Flash attention + long context: good; packing disabled explicitly

Settings look consistent for 8k context training.


31-35: QAT block alignment

weight_dtype: int4 with group_size: 32 is a sensible default for common 4-bit paths.


76-78: Pad token change: ensure tokenizer contains it or will be added

Switching to <|finetune_right_pad_id|> is fine; verify load_tokenizer adds it and that embeddings are resized before training starts (it usually does).

src/axolotl/cli/quantize.py (1)

110-118: Hub suffixing: good; ensure mapping covers all AO configs

Suffix construction via quantization_config_to_str[type(quantization_config)] looks fine. Please ensure the mapping contains every config your CLI can emit, otherwise you’ll KeyError at push time.

src/axolotl/utils/schemas/quantization.py (2)

50-54: LGTM: single, centralized dtype validator.

Consolidating per-field validators through validate_ao_dtype is clean and reduces duplication.


77-81: LGTM: mirrored validation for PTQConfig.

Consistent behavior across QAT/PTQ schemas.

src/axolotl/utils/quantization.py (2)

149-191: LGTM: QAT prep embeds base config inside torchao.QATConfig and filters Embedding separately.

The signature and embedding path align with tests and the newer API.


193-208: LGTM: simple conversion step using QATConfig(step="convert").

Matches the intended swap-out of fake-quant modules.

@winglian
Copy link
Collaborator

winglian commented Sep 9, 2025

@SalmanMohammadi does this supersede #3119 ?

@andrewor14
Copy link

@SalmanMohammadi does this supersede #3119 ?

Yep just closed #3119

Copy link
Collaborator

@winglian winglian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mostly have a bunch of follow ups on previous comments to make sure they are addressed. Otherwise seems good to me.

Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we get a resolution on this?

and weight_dtype == TorchAOQuantDType.int8
):
raise ValueError(
"Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any updates on this or will it need to be addressed in an ao patch release?

@winglian
Copy link
Collaborator

2.8.0 still failing on:

>       to_quant = torch.split(x.to(torch.float), group_size, dim=-1)
                               ^^^^^^^^^^^^^^^^^
E       torch.AcceleratorError: CUDA error: no kernel image is available for execution on the device
E       CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E       For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E       Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

@SalmanMohammadi
Copy link
Contributor Author

2.8.0 still failing on:

>       to_quant = torch.split(x.to(torch.float), group_size, dim=-1)
                               ^^^^^^^^^^^^^^^^^
E       torch.AcceleratorError: CUDA error: no kernel image is available for execution on the device
E       CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E       For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E       Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I'm able to manually run this on a B200 instance. I'm going to exclude this test for now and come back to it.

root@9a09ddf52b98:/workspace/axolotl# python
Python 3.11.13 (main, Jun  5 2025, 13:12:00) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from axolotl.utils.schemas.enums import TorchAOQuantDType
>>> import torch
>>> weight_dtype = TorchAOQuantDType.int4
>>> act_dtype = TorchAOQuantDType.float8_e4m3fn
>>> from axolotl.utils.quantization import quantize_model
TMA benchmarks will be running without grid constant TMA descriptor.
>>> model = torch.nn.Linear(2048, 2048)
>>> quantize_model(model, weight_dtype, None, act_dtype)

@SalmanMohammadi SalmanMohammadi merged commit 58d67bf into main Sep 12, 2025
24 of 25 checks passed
@SalmanMohammadi SalmanMohammadi deleted the qat_migration branch September 12, 2025 09:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants