From bc962f814d15aec39f5af7ed93e17660874ce734 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Mon, 29 Sep 2025 17:09:09 -0700
Subject: [PATCH] up

---
 unsloth/models/_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index 60abcea70..9cea44b69 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -1653,11 +1653,12 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
     from torchao.quantization import (
         Float8DynamicActivationInt4WeightConfig,
         Float8DynamicActivationFloat8WeightConfig,
-        Int8DynamicActivationInt4WeightConfig,
+        Int8DynamicActivationIntxWeightConfig,
         Int4WeightOnlyConfig,
         PerRow,
         quantize_,
     )
+    from torchao.quantization.granularity import PerGroup
     from torchao.quantization.qat import QATConfig
     filter_fn = None
     if qat_scheme == "fp8-int4":
@@ -1668,7 +1669,7 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.
         base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
     elif qat_scheme == "int8-int4":
         group_size = 32
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        base_config = Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4, weight_granularity=PerGroup(group_size))
         filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
     elif qat_scheme == "int4":
         group_size = 128