
Commit 64c3330

up (#3391)
1 parent b5214fd commit 64c3330

File tree

1 file changed: +3 -2 lines changed

unsloth/models/_utils.py

Lines changed: 3 additions & 2 deletions
@@ -1660,11 +1660,12 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.Module:
     from torchao.quantization import (
         Float8DynamicActivationInt4WeightConfig,
         Float8DynamicActivationFloat8WeightConfig,
-        Int8DynamicActivationInt4WeightConfig,
+        Int8DynamicActivationIntxWeightConfig,
         Int4WeightOnlyConfig,
         PerRow,
         quantize_,
     )
+    from torchao.quantization.granularity import PerGroup
     from torchao.quantization.qat import QATConfig
     filter_fn = None
     if qat_scheme == "fp8-int4":
@@ -1675,7 +1676,7 @@ def _prepare_model_for_qat(model: torch.nn.Module, qat_scheme: str) -> torch.nn.Module:
         base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
     elif qat_scheme == "int8-int4":
         group_size = 32
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
+        base_config = Int8DynamicActivationIntxWeightConfig(weight_dtype=torch.int4, weight_granularity=PerGroup(group_size))
         filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size
     elif qat_scheme == "int4":
         group_size = 128
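
For context, a minimal sketch of how the updated "int8-int4" config from this diff could be applied with torchao's QAT flow. This is not the unsloth code itself: the toy model is hypothetical, and it assumes a recent torchao release where QATConfig(base_config, step="prepare") and quantize_(model, config, filter_fn=...) are available.

# Sketch only: apply the int8-int4 QAT scheme shown in the diff.
# Assumes recent torchao with QATConfig(base_config, step="prepare").
import torch
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.qat import QATConfig

# Hypothetical toy model for illustration.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 64),
)

group_size = 32
# int8 dynamic activations, int4 weights quantized per group of 32, as in the diff.
base_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(group_size),
)

# Only quantize Linear layers wide enough to hold at least one group,
# mirroring the filter_fn in the diff.
filter_fn = lambda m, _: isinstance(m, torch.nn.Linear) and m.in_features >= group_size

# Insert fake-quant modules for QAT; after training, the same base_config
# would be reused with step="convert" to produce real quantized weights.
quantize_(model, QATConfig(base_config, step="prepare"), filter_fn=filter_fn)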
