Skip to content

Commit 1368978

Browse files
authored
improve qat (#3446)
* Update save.py
* Update vision.py
* Update save.py
1 parent bfad39b commit 1368978

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

unsloth/models/vision.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
 from ..kernels import (
     post_patch_loss_function,
 )
-from ._utils import __version__, importlib_version
+from ._utils import __version__, importlib_version, _prepare_model_for_qat
 from ._utils import *
 from ..save import patch_saving_functions
 from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model
@@ -796,6 +796,7 @@ def get_peft_model(
     loftq_config = {},
     task_type = TaskType.CAUSAL_LM,
     temporary_location = "_unsloth_temporary_saved_buffers",
+    qat_scheme = None,
     **kwargs
 ):
     if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1":
@@ -871,6 +872,11 @@ def get_peft_model(
         use_gradient_checkpointing = use_gradient_checkpointing,
     )
     model = _get_peft_model(model, lora_config)
+    # Apply QAT + LoRA if specified
+    if qat_scheme is not None:
+        print("Unsloth: Applying QAT to mitigate quantization degradation")
+        model = _prepare_model_for_qat(model, qat_scheme)
+    pass
     # Fix LoraConfig.auto_mapping is None
     fix_lora_auto_mapping(model)
     # Enable gradients on modules which are trainable

unsloth/save.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
     pass
 pass
 from pathlib import Path
+from peft import PeftModelForCausalLM, PeftModel

 __all__ = [
     "print_quantization_methods",
@@ -2522,7 +2523,12 @@ def unsloth_save_pretrained_torchao(
     arguments["save_method"] = "merged_16bit" # Must be 16bit
     del arguments["self"]
     del arguments["torchao_config"]
-    unsloth_generic_save(**arguments)
+
+    if not isinstance(self, PeftModelForCausalLM) and not isinstance(self, PeftModel):
+        self.save_pretrained(save_directory)
+        tokenizer.save_pretrained(save_directory)
+    else:
+        unsloth_generic_save(**arguments)
     for _ in range(3):
         gc.collect()
25282534

0 commit comments

Comments (0)