Migrate QAT API; fix axolotl quantize for QAT-ed models; add NVFP4
#3107
Conversation
📝 Walkthrough
Updates CI matrices to include fbgemm-gpu extras. Adds packaging extra and bumps torchao. Overhauls quantization/QAT API, enums, schemas, and CLI quantize flow with hub push support. Adjusts training defaults and cloud accelerator selection. Expands docs and examples (including NVFP4). Updates and gates e2e tests accordingly.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
📖 Documentation Preview: https://68c2a4da72a47b10ec2e66d6--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit c4f8e26)
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/axolotl/cli/quantize.py (2)
50-63: Prefer cfg.base_model over cfg.output_dir when base_model is not passed.
Falling back to cfg.output_dir can point at a training directory instead of a source model. Use cfg.base_model if available, else cfg.output_dir.
-    model_path = cli_args.get("base_model") or cfg.output_dir
+    model_path = (
+        cli_args.get("base_model")
+        or getattr(cfg, "base_model", None)
+        or cfg.output_dir
+    )
97-102: Remove unsupported kwargs `safe_serialization` and `progressbar`.
`PreTrainedTokenizerBase.save_pretrained` only accepts `(save_directory: str, legacy_format: bool = True, filename_prefix: Optional[str] = None)` and will error on unexpected keywords (huggingface.co).
-    tokenizer.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-        save_jinja_files=cfg.tokenizer_save_jinja_files,
-    )
+    tokenizer.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
+    )
tests/e2e/test_quantization.py (1)
38-49: Invalid context manager: use device=... when creating Embedding.
`with torch.device(dummy_model.device):` is not a valid context manager. Pass `device=` to the layer instead to keep tensors on the model's device.
-    with torch.device(dummy_model.device):
-        dummy_model.model.embed_tokens = torch.nn.Embedding(
-            dummy_model.model.embed_tokens.weight.shape[0],
-            dummy_model.model.embed_tokens.weight.shape[1],
-            dtype=dummy_model.model.embed_tokens.weight.dtype,
-        )
+    dummy_model.model.embed_tokens = torch.nn.Embedding(
+        dummy_model.model.embed_tokens.weight.shape[0],
+        dummy_model.model.embed_tokens.weight.shape[1],
+        device=dummy_model.device,
+        dtype=dummy_model.model.embed_tokens.weight.dtype,
+    )
🧹 Nitpick comments (6)
setup.py (1)
165-167: Scope the fbgemm extra to Linux.
Avoid accidental installs on macOS/Windows when users opt into the extra.
-    "fbgemm": [
-        "fbgemm-gpu-genai>=1.2.0",
-    ],
+    "fbgemm": [
+        "fbgemm-gpu-genai>=1.2.0 ; sys_platform == 'linux'",
+    ],
src/axolotl/train.py (1)
240-242: Polish the user-facing log message (grammar).
-    "QAT usage note: please ensure you quantize your model fine-tuned using QAT by running `axolotl quantize`"
-    " with the same config which you used for training."
+    "QAT note: after fine-tuning with QAT, run `axolotl quantize` using the same training config."
src/axolotl/cli/quantize.py (1)
26-33: Fix docstring typo.
"Quantizes a model's model's weights" → "Quantizes a model's weights".
-    Quantizes a model's model's weights
+    Quantizes a model's weights
src/axolotl/utils/schemas/quantization.py (2)
17-24: Docstrings outdated; include float8.
Descriptions still list only int4/int8. Update to include float8_e4m3fn to match the enum and validators.
-    description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
+    description='Fake quantization layout for activations. Valid options: "int4", "int8", "float8_e4m3fn".',
@@
-    description='Fake quantization layout to use for weight quantization. Valid options are "int4" and "int8"',
+    description='Fake quantization layout for weights. Valid options: "int4", "int8", "float8_e4m3fn".',
55-63: PTQConfig description mentions deprecated uintX variants.
Remove uintX from the description to reflect the current TorchAOQuantDType.
-    description="Fake quantization layout to use for weight quantization. Valid options are uintX for X in [1, 2, 3, 4, 5, 6, 7], or int4, or int8",
+    description='Fake quantization layout for weights. Valid options: "int4", "int8", "float8_e4m3fn".',
@@
-    description='Fake quantization layout to use for activation quantization. Valid options are "int4" and "int8"',
+    description='Fake quantization layout for activations. Valid options: "int4", "int8", "float8_e4m3fn".',
src/axolotl/utils/quantization.py (1)
53-64: Neutralize "QAT" wording in PTQ builder errors.
This helper is used for both PTQ and QAT; error strings should not suggest QAT-only.
-    "Int4DynamicActivationInt4WeightConfig is not supported by torchao QAT. "
+    "Int4DynamicActivationInt4WeightConfig is not supported by torchao. "
@@
-    "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT. "
+    "Int8DynamicActivationInt8WeightConfig is not supported by torchao. "
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (9)
- requirements-dev.txt (1 hunks)
- requirements.txt (1 hunks)
- setup.py (1 hunks)
- src/axolotl/cli/quantize.py (4 hunks)
- src/axolotl/train.py (3 hunks)
- src/axolotl/utils/quantization.py (5 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
- src/axolotl/utils/schemas/quantization.py (3 hunks)
- tests/e2e/test_quantization.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/schemas/quantization.py (1)
- src/axolotl/utils/schemas/enums.py (1): TorchAOQuantDType (8-11)
tests/e2e/test_quantization.py (4)
- src/axolotl/utils/callbacks/qat.py (1): QATCallback (33-48)
- src/axolotl/utils/schemas/enums.py (1): TorchAOQuantDType (8-11)
- src/axolotl/utils/schemas/quantization.py (1): QATConfig (12-47)
- tests/e2e/utils.py (1): require_torch_2_6_0 (69-78)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (128, 12.8.1, 3.11, 2.8.0, 2, true)
🔇 Additional comments (5)
src/axolotl/train.py (4)
33-33: Import reorder is fine.
233-233: Rename alignment to convert_qat_model looks correct.
235-238: Ensure we convert the exact model instance that is saved (trainer.model) or confirm aliasing.
If trainer.model wraps/aliases model, this is fine; otherwise convert the trainer's model to avoid divergence with the saved state. Suggested change if they're not the same object:
-    convert_qat_model(
-        model,
+    convert_qat_model(
+        trainer.model,
         quantize_embedding=cfg.qat.quantize_embedding,
     )
336-336: Import consolidation LGTM.
src/axolotl/utils/quantization.py (1)
142-151: Propagate optional group_size through to config builder.
No code change needed if you adopt the signature change above; this is a heads-up to ensure None is accepted for float8 combos.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
docs/quantize.qmd (1)
12-16: Fix punctuation and wording in the callout.
Minor doc nit: add a space after the comma and simplify the sentence.
-We do not currently support quantization techniques such as GGUF/GPTQ,EXL2 at the moment.
+We do not currently support quantization formats such as GGUF, GPTQ, or EXL2.
src/axolotl/cli/quantize.py (2)
98-103: Invalid kwargs to save_pretrained.
`progressbar` isn't a supported kwarg for `PreTrainedModel.save_pretrained` on many releases. Remove it to prevent runtime errors.
-    model.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-    )
+    model.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        safe_serialization=False,
+    )
103-108: Tokenizer.save_pretrained: drop unsupported args.
Tokenizer `save_pretrained` does not accept `safe_serialization` or `progressbar`.
-    tokenizer.save_pretrained(
-        str(Path(output_dir) / "quantized"),
-        safe_serialization=False,
-        progressbar=True,
-        save_jinja_files=cfg.tokenizer_save_jinja_files,
-    )
+    tokenizer.save_pretrained(
+        str(Path(output_dir) / "quantized"),
+        save_jinja_files=cfg.tokenizer_save_jinja_files,
+    )
tests/e2e/test_quantization.py (1)
47-53: Fix invalid device context manager; ensure replacement embedding is created on the right device.
`torch.device(...)` is not a context manager; this will raise at runtime. Create the new `Embedding` directly on the original weight's device.
-    with torch.device(dummy_model.device):
-        dummy_model.model.embed_tokens = torch.nn.Embedding(
-            dummy_model.model.embed_tokens.weight.shape[0],
-            dummy_model.model.embed_tokens.weight.shape[1],
-            dtype=dummy_model.model.embed_tokens.weight.dtype,
-        )
+    embed_w = dummy_model.model.embed_tokens.weight
+    dummy_model.model.embed_tokens = torch.nn.Embedding(
+        embed_w.shape[0],
+        embed_w.shape[1],
+        dtype=embed_w.dtype,
+        device=embed_w.device,
+    )
♻️ Duplicate comments (5)
requirements.txt (1)
67-67: torchao 0.13.0 bump LGTM; resolves prior VCS URL concern.
Good move switching from a VCS ref to the released package and pinning to 0.13.0.
Given CI tests Torch 2.6.0/2.7.1/2.8.0, double-check torchao 0.13.0 compatibility across that matrix:
#!/bin/bash
# Verify torchao 0.13.0 imports with the three torch versions.
set -euo pipefail
for TORCH in 2.6.0 2.7.1 2.8.0; do
  python - <<PY
import subprocess, sys
subprocess.check_call([sys.executable, "-m", "pip", "install", f"torch==${TORCH}", "torchao==0.13.0", "-q"])
__import__("torchao")
print("OK torch", "${TORCH}")
PY
done
src/axolotl/utils/schemas/enums.py (1)
8-12: Make dtype enum string-backed with a resolver; avoid mixing torch.dtypes and strings.
Storing torch.int4/float8 directly can be None on some builds; mixing with "nvfp4" (str) also breaks serialization/validation. Prefer a string-backed enum + resolve().
-class TorchAOQuantDType(Enum):
-    int4 = torch.int4
-    int8 = torch.int8
-    float8_e4m3fn = torch.float8_e4m3fn
-    nvfp4 = "nvfp4"
+class TorchAOQuantDType(str, Enum):
+    int4 = "int4"
+    int8 = "int8"
+    float8_e4m3fn = "float8_e4m3fn"
+    nvfp4 = "nvfp4"
+
+    @classmethod
+    def from_string(cls, s: str) -> "TorchAOQuantDType":
+        try:
+            return cls(s)
+        except ValueError as e:
+            raise ValueError(f"Unknown TorchAOQuantDType: {s!r}") from e
+
+    def resolve(self) -> torch.dtype:
+        dt = getattr(torch, self.value, None)
+        if dt is None:
+            raise RuntimeError(
+                f"Requested quant dtype {self.value} is not supported by this torch build."
+            )
+        return dt
Follow-up: ensure downstream code calls .resolve() right before interacting with torch APIs, and never persists raw torch.dtypes in configs.
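A minimal usage sketch of that pattern (illustrative only: `from_string` and `resolve` are the helpers proposed in the diff above, not existing axolotl API):

```python
import torch

from axolotl.utils.schemas.enums import TorchAOQuantDType

# Configs carry only the string value; the torch.dtype is materialized at the call site.
weight_dtype = TorchAOQuantDType.from_string("float8_e4m3fn")
torch_dtype = weight_dtype.resolve()  # raises if this torch build lacks the dtype
weight = torch.randn(16, 16).to(torch_dtype)
```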
src/axolotl/cli/quantize.py (2)
67-73: Normalize hf config torch_dtype before passing to from_pretrained.
`AutoConfig.torch_dtype` can be a string (e.g., "bfloat16"); pass an actual `torch.dtype`.
+    import torch
@@
-    config = AutoConfig.from_pretrained(model_path)
-    torch_dtype = config.torch_dtype if hasattr(config, "torch_dtype") else None
+    hf_config = AutoConfig.from_pretrained(model_path)
+    torch_dtype = getattr(hf_config, "torch_dtype", None)
+    if isinstance(torch_dtype, str):
+        torch_dtype = getattr(torch, torch_dtype, None)
91-96: TorchAoConfig kwargs: drop include_input_output_embeddings.
Embedding quantization is already controlled in quantize_model. This kwarg isn't universally supported and can break on some transformers versions.
-    ao_config = TorchAoConfig(
-        quant_type=quantization_config,
-        include_input_output_embeddings=quantize_embedding,
-    )
+    ao_config = TorchAoConfig(quant_type=quantization_config)
src/axolotl/utils/quantization.py (1)
64-71: Require explicit, positive group_size for int4 weight-only; don't default to -1.
Tests expect a ValueError when group_size is missing for int4 WO. Defaulting to -1 defers the error and is ambiguous.
-    if weight_dtype == TorchAOQuantDType.int4:
-        from torchao.quantization.quant_api import Int4WeightOnlyConfig
-
-        return Int4WeightOnlyConfig(group_size=group_size or -1, version=2)
+    if weight_dtype == TorchAOQuantDType.int4:
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
+
+        if not group_size or group_size <= 0:
+            raise ValueError("group_size must be a positive int for int4 weight-only quantization.")
+        return Int4WeightOnlyConfig(group_size=group_size, version=2)
🧹 Nitpick comments (18)
src/axolotl/core/training_args_base.py (1)
52-57: Confirm intent: include_tkps default flipped to False.
This changes default training metrics output. If intentional, note it in release notes and CLI/docs to avoid surprises in dashboards.
src/axolotl/cli/args.py (1)
118-119: Add help metadata for hub_model_id.
Keep CLI help consistent with other fields and explain the suffixing behavior.
-    hub_model_id: Optional[str] = field(default=None)
+    hub_model_id: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "If set, push quantized artifacts to the Hub under this model ID. "
+                "A suffix indicating the quantization scheme is appended automatically."
+            )
+        },
+    )
tests/e2e/utils.py (3)
93-103: Fix docstring to match the check (2.8.0).
The docstring says 2.7.0 but the predicate enforces 2.8.0.
-    Decorator marking a test that requires torch >= 2.7.0
+    Decorator marking a test that requires torch >= 2.8.0
143-150: Message clarity: target is compute capability (SM) ≥ 10.0.
Skip reason is clear; consider guarding get_device_capability() behind is_available (already done) — good.
-    return unittest.skipUnless(is_sm_ge_100, "test requires sm>=100")(test_case)
+    return unittest.skipUnless(is_sm_ge_100, "test requires compute capability (SM) >= 10.0")(test_case)
152-159: Name/message mismatch: function checks SM 8.9, not CUDA version.
Either rename to requires_sm_ge_89 or adjust the message to avoid "cuda>=8.9" ambiguity.
-    return unittest.skipUnless(is_cuda_ge_8_9, "test requires cuda>=8.9")(test_case)
+    return unittest.skipUnless(
+        is_cuda_ge_8_9, "test requires compute capability (SM) >= 8.9"
+    )(test_case)
tests/e2e/test_qat.py (2)
43-49: Switch to int4: confirm group_size=8 is valid for your torchao pathChanging
qat.weight_dtypeto"int4"is fine; please confirm group size 8 is supported for this config to avoid hidden dequant drift vs training. If not, bump to 32 (or the library’s required value).You may also set a fixed seed in cfg (e.g.,
"seed": 1) to reduce flakiness across CI runs.
111-116: DPO QAT config mirrors the above—same group_size validation appliesSame note as above re:
"int4"+group_size: 8. Ensure this combo is accepted by torchao for DPO.docs/quantize.qmd (1)
56-61: Clarify suffix mapping and scopeGood callout. Consider adding a brief note that the suffix reflects the quantization schema (e.g., weight dtype and layout) and is only applied when pushing to the Hub, not to local directories.
examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (2)
26-28: Consider explicitly disabling sample packingTo mirror the other example and avoid surprises with long context + FA, explicitly set
sample_packing: false.sequence_len: 8192 flash_attention: true +sample_packing: false
42-51: Batch size at seq_len 8192: validate memory headroom
micro_batch_size: 64at 8k context is aggressive even on large GPUs. If this example is meant to “just work,” consider a safer default (e.g., 4–8) and annotate that users can scale up.examples/llama-3/3b-qat-fsdp2.yaml (1)
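For scale, 64 × 8192 = 524,288 tokens per device per micro-step before gradient accumulation, so activation memory alone is substantial even with flash attention.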
68-69: Toggle of fsdp_cpu_ram_efficient_loading: double-check tradeoffYou flipped this to
false. Confirm this doesn’t regress load-time memory on CPU-constrained runners; if it was to fix perf, consider a short comment explaining why.src/axolotl/cli/quantize.py (2)
97-97: Ruff RUF010: prefer conversion flag in f-stringsUse
!sinstead of wrapping instr(...).- LOG.info(f"Saving quantized model to: {str(Path(output_dir) / 'quantized')}.") + LOG.info(f"Saving quantized model to: {(Path(output_dir) / 'quantized')!s}.")
119-119: Ruff RUF010: same nit here- LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}.") + LOG.info(f"Quantized model saved to: {(Path(output_dir) / 'quantized')!s}.")tests/e2e/test_quantization.py (2)
300-303: Assertion doesn’t verify “quantized weights” as the comment claims.
isinstance(..., nn.Parameter)is always true pre/post, so it doesn’t prove conversion or quantization.Consider asserting the post-conversion tensor type (e.g.,
AffineQuantizedTensor/LinearActivationQuantizedTensoror NVFP4 types if applicable), or explicitly verify fake-quant modules were removed and expected PTQ tensor subclasses are present. If the intended behavior afterconvert_qat_modelis “de-fake-quant to standard modules with quantized tensors,” align the assertion accordingly.
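A sketch of what a stronger check could look like (illustrative only; `model` is the converted test model, and it assumes the converted weights end up as torchao AffineQuantizedTensor subclasses — the NVFP4 path may use a different tensor type):

```python
import torch
from torchao.dtypes import AffineQuantizedTensor

# Illustrative only: after convert_qat_model / quantize_model, at least one Linear
# weight should have been swapped to a quantized tensor subclass.
quantized_linears = [
    module
    for module in model.modules()
    if isinstance(module, torch.nn.Linear)
    and isinstance(module.weight, AffineQuantizedTensor)
]
assert quantized_linears, "no Linear weights were converted to quantized tensors"
```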
180-191: Optional: broaden Linear checks beyond top-level children.Iterating
model.children()only catches top-level modules (often justlm_head). If you intend to validate deeper linears inside the backbone, iteratemodel.modules()and skip containers.-for child in list(model.children()): +for child in model.modules(): if isinstance(child, torch.nn.Linear): ...src/axolotl/utils/quantization.py (3)
26-42: Avoid bare except; narrow the import guards.Use specific exceptions for optional imports to satisfy linting and avoid masking unrelated errors.
- try: + try: from torchao.prototype.mx_formats import NVFP4InferenceConfig - - quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" - except: - pass + quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" + except (ImportError, OSError): + pass @@ - try: + try: from torchao.quantization.quant_api import Int4WeightOnlyConfig - - quantization_config_to_str[Int4WeightOnlyConfig] = "int4" - except: - pass + quantization_config_to_str[Int4WeightOnlyConfig] = "int4" + except (ImportError, OSError): + pass
66-66: Clarify error message; avoid QAT-specific wording in a shared builder.
get_quantization_configis used for PTQ and as a base for QAT. Rephrase to avoid confusion.- raise ValueError("Int8WeightOnlyConfig is not supported by torchao QAT.") + raise ValueError("Int8 weight-only quantization is not supported in this flow.")
49-63: Docstring scope nit: mention both PTQ and QAT base configs.This function also returns base configs wrapped by QAT later; adjust wording for accuracy.
- This function is used to build a post-training quantization config. + Build a quantization config (used directly for PTQ, or as the base for QAT).
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- .github/workflows/multi-gpu-e2e.yml (1 hunks)
- .github/workflows/tests.yml (1 hunks)
- docs/quantize.qmd (1 hunks)
- examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (1 hunks)
- examples/llama-3/3b-qat-fsdp2.yaml (3 hunks)
- requirements.txt (1 hunks)
- setup.py (1 hunks)
- src/axolotl/cli/args.py (1 hunks)
- src/axolotl/cli/cloud/baseten/template/train_sft.py (1 hunks)
- src/axolotl/cli/quantize.py (4 hunks)
- src/axolotl/core/training_args_base.py (1 hunks)
- src/axolotl/train.py (3 hunks)
- src/axolotl/utils/quantization.py (5 hunks)
- src/axolotl/utils/schemas/enums.py (1 hunks)
- src/axolotl/utils/schemas/quantization.py (3 hunks)
- tests/e2e/test_qat.py (2 hunks)
- tests/e2e/test_quantization.py (7 hunks)
- tests/e2e/utils.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
src/axolotl/cli/quantize.py (4)
src/axolotl/loaders/tokenizer.py (1)
load_tokenizer(122-308)src/axolotl/utils/quantization.py (2)
get_quantization_config(44-108)quantize_model(111-146)src/axolotl/utils/schemas/enums.py (1)
TorchAOQuantDType(8-12)src/axolotl/core/trainers/base.py (1)
push_to_hub(527-537)
src/axolotl/train.py (4)
src/axolotl/loaders/processor.py (1)
load_processor(17-56)src/axolotl/loaders/model.py (1)
ModelLoader(67-889)src/axolotl/utils/quantization.py (1)
convert_qat_model(193-207)src/axolotl/integrations/llm_compressor/utils.py (1)
save_compressed_model(11-40)
src/axolotl/utils/schemas/quantization.py (1)
src/axolotl/utils/schemas/enums.py (1)
TorchAOQuantDType(8-12)
tests/e2e/test_quantization.py (5)
src/axolotl/utils/callbacks/qat.py (1)
QATCallback(33-48)src/axolotl/utils/quantization.py (4)
get_quantization_config(44-108)prepare_model_for_qat(149-190)convert_qat_model(193-207)quantize_model(111-146)src/axolotl/utils/schemas/enums.py (1)
TorchAOQuantDType(8-12)src/axolotl/utils/schemas/quantization.py (1)
QATConfig(26-53)tests/e2e/utils.py (3)
require_torch_2_8_0(93-102)requires_sm_ge_100(143-149)requires_cuda_ge_8_9(152-158)
src/axolotl/utils/quantization.py (3)
src/axolotl/utils/schemas/quantization.py (1)
QATConfig(26-53)src/axolotl/utils/schemas/enums.py (1)
TorchAOQuantDType(8-12)tests/e2e/test_quantization.py (1)
model(41-54)
🪛 Ruff (0.12.2)
src/axolotl/cli/cloud/baseten/template/train_sft.py
53-55: Avoid specifying long messages outside the exception class
(TRY003)
src/axolotl/cli/quantize.py
97-97: Use explicit conversion flag
Replace with conversion flag
(RUF010)
119-119: Use explicit conversion flag
Replace with conversion flag
(RUF010)
src/axolotl/utils/schemas/quantization.py
21-23: Avoid specifying long messages outside the exception class
(TRY003)
tests/e2e/test_quantization.py
138-138: Use of assert detected
(S101)
146-146: Use of assert detected
(S101)
207-207: Use of assert detected
(S101)
208-208: Use of assert detected
(S101)
241-241: Use of assert detected
(S101)
242-242: Use of assert detected
(S101)
243-243: Use of assert detected
(S101)
248-248: Use of assert detected
(S101)
255-255: Use of assert detected
(S101)
256-256: Use of assert detected
(S101)
257-257: Use of assert detected
(S101)
259-259: Use of assert detected
(S101)
261-261: Use of assert detected
(S101)
262-262: Use of assert detected
(S101)
267-267: Use of assert detected
(S101)
288-288: Use of assert detected
(S101)
289-289: Use of assert detected
(S101)
297-297: Use of assert detected
(S101)
298-298: Use of assert detected
(S101)
301-301: Use of assert detected
(S101)
302-302: Use of assert detected
(S101)
src/axolotl/utils/quantization.py
31-31: Do not use bare except
(E722)
31-32: try-except-pass detected, consider logging the exception
(S110)
40-40: Do not use bare except
(E722)
40-41: try-except-pass detected, consider logging the exception
(S110)
66-66: Avoid specifying long messages outside the exception class
(TRY003)
75-77: Avoid specifying long messages outside the exception class
(TRY003)
82-84: Avoid specifying long messages outside the exception class
(TRY003)
104-104: Avoid specifying long messages outside the exception class
(TRY003)
106-108: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (8)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (17)
setup.py (1)
165-166: Extras entry looks good; verify runner compatibility.
The new "fbgemm-gpu" extra maps to fbgemm-gpu-genai>=1.2.0. Ensure the GPU runners used in CI (B200/H100) meet the runtime constraints for that wheel.
src/axolotl/cli/cloud/baseten/template/train_sft.py (1)
60-60: LGTM after accelerator normalization.
Using the resolved accelerator in AcceleratorSpec is correct.
tests/e2e/utils.py (1)
16-17: LGTM: import relocation is harmless.
.github/workflows/tests.yml (1)
307-307: No references to AXOLOTL_EXTRAS or install logic found in cicd/e2e_tests.py. Verify if this file delegates to Docker builds (which already handle extras) or manually installs axolotl[$AXOLOTL_EXTRAS]; otherwise, the extras value may never be applied.
.github/workflows/multi-gpu-e2e.yml (1)
47-49: fbgemm-gpu extra: check runner compatibility and Modal wiring.
Nice addition. Please verify the Modal image/env actually installs the fbgemm-gpu extra for CUDA 12.8 / PyTorch 2.8.0 and that cicd.multigpu respects AXOLOTL_EXTRAS; otherwise this matrix leg may silently run without the intended kernels.
Optionally add a quick echo for visibility:
   - name: Update env vars
     run: |
@@
       echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
       echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
+  - name: Show selected extras
+    run: echo "Using AXOLOTL_EXTRAS=$AXOLOTL_EXTRAS"
src/axolotl/train.py (3)
233-242: QAT conversion path looks correct.
The switch to convert_qat_model(...) and the updated user-facing log are appropriate for the new flow.
335-336: Import style cleanup—LGTM.
No functional change; matches surrounding style.
33-33: Re-export of loaders verified.
src/axolotl/loaders/__init__.py already imports ModelLoader, load_processor, and load_tokenizer; no changes needed.
examples/llama-3/3b-qat-fsdp2-nvfp4.yaml (1)
29-33: NVFP4 config: LGTM; group_size=16 constraint captured.
Config matches NVFP4 expectations.
examples/llama-3/3b-qat-fsdp2.yaml (3)
27-30: Flash attention + long context: good; packing disabled explicitly.
Settings look consistent for 8k context training.
31-35: QAT block alignment.
weight_dtype: int4 with group_size: 32 is a sensible default for common 4-bit paths.
76-78: Pad token change: ensure tokenizer contains it or will be added.
Switching to <|finetune_right_pad_id|> is fine; verify load_tokenizer adds it and that embeddings are resized before training starts (it usually does).
src/axolotl/cli/quantize.py (1)
110-118: Hub suffixing: good; ensure mapping covers all AO configs.
Suffix construction via quantization_config_to_str[type(quantization_config)] looks fine. Please ensure the mapping contains every config your CLI can emit, otherwise you'll KeyError at push time.
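One way to fail fast with a clearer message than a bare KeyError (sketch only; the exact join with the user's hub model ID is an assumption):

```python
# Sketch: look up the suffix defensively instead of indexing the mapping directly.
suffix = quantization_config_to_str.get(type(quantization_config))
if suffix is None:
    raise ValueError(
        f"No hub suffix registered for {type(quantization_config).__name__}; "
        "add it to quantization_config_to_str before pushing."
    )
hub_model_id = f"{base_hub_model_id}-{suffix}"  # base_hub_model_id: user-supplied ID (assumed name)
```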
src/axolotl/utils/schemas/quantization.py (2)
50-54: LGTM: single, centralized dtype validator.
Consolidating per-field validators through validate_ao_dtype is clean and reduces duplication.
77-81: LGTM: mirrored validation for PTQConfig.
Consistent behavior across QAT/PTQ schemas.
src/axolotl/utils/quantization.py (2)
149-191: LGTM: QAT prep embeds base config inside torchao.QATConfig and filters Embedding separately.
The signature and embedding path align with tests and the newer API.
193-208: LGTM: simple conversion step using QATConfig(step="convert").
Matches the intended swap-out of fake-quant modules.
@SalmanMohammadi does this supersede #3119?
Yep just closed #3119
winglian
left a comment
I mostly have a bunch of follow ups on previous comments to make sure they are addressed. Otherwise seems good to me.
Float8DynamicActivationInt4WeightConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Did we get a resolution on this?
and weight_dtype == TorchAOQuantDType.int8
):
    raise ValueError(
        "Int8DynamicActivationInt8WeightConfig is not supported by torchao QAT."
Any updates on this or will it need to be addressed in an ao patch release?
2.8.0 still failing on:
I'm able to manually run this on a B200 instance. I'm going to exclude this test for now and come back to it.
root@9a09ddf52b98:/workspace/axolotl# python
Python 3.11.13 (main, Jun 5 2025, 13:12:00) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from axolotl.utils.schemas.enums import TorchAOQuantDType
>>> import torch
>>> weight_dtype = TorchAOQuantDType.int4
>>> act_dtype = TorchAOQuantDType.float8_e4m3fn
>>> from axolotl.utils.quantization import quantize_model
TMA benchmarks will be running without grid constant TMA descriptor.
>>> model = torch.nn.Linear(2048, 2048)
>>> quantize_model(model, weight_dtype, None, act_dtype)
The QAT API in torchao has recently been updated (pytorch/ao#2629).
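Roughly, the migrated flow looks like the sketch below (assumptions: torchao ≥ 0.13 and a toy model; the exact base config axolotl builds depends on the qat settings):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(256, 256))
base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

# Before training: insert fake-quant observers into eligible modules.
quantize_(model, QATConfig(base_config, step="prepare"))

# ... QAT fine-tuning happens here ...

# After training: swap fake-quant modules for actually quantized weights.
quantize_(model, QATConfig(base_config, step="convert"))
```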
Summary by CodeRabbit
New Features
Documentation
Examples
Behavior Changes
Chores
Tests