|
28 | 28 | from ..kernels import ( |
29 | 29 | post_patch_loss_function, |
30 | 30 | ) |
31 | | -from ._utils import __version__, importlib_version |
| 31 | +from ._utils import __version__, importlib_version, _prepare_model_for_qat |
32 | 32 | from ._utils import * |
33 | 33 | from ..save import patch_saving_functions |
34 | 34 | from peft import LoraConfig, TaskType, get_peft_model as _get_peft_model |
@@ -796,6 +796,7 @@ def get_peft_model( |
796 | 796 | loftq_config = {}, |
797 | 797 | task_type = TaskType.CAUSAL_LM, |
798 | 798 | temporary_location = "_unsloth_temporary_saved_buffers", |
| 799 | + qat_scheme = None, |
799 | 800 | **kwargs |
800 | 801 | ): |
801 | 802 | if os.environ.get("UNSLOTH_ENABLE_FULL_FINETUNING", "0") == "1": |
@@ -871,6 +872,11 @@ def get_peft_model( |
871 | 872 | use_gradient_checkpointing = use_gradient_checkpointing, |
872 | 873 | ) |
873 | 874 | model = _get_peft_model(model, lora_config) |
| 875 | + # Apply QAT + LoRA if specified |
| 876 | + if qat_scheme is not None: |
| 877 | + print("Unsloth: Applying QAT to mitigate quantization degradation") |
| 878 | + model = _prepare_model_for_qat(model, qat_scheme) |
| 879 | + pass |
874 | 880 | # Fix LoraConfig.auto_mapping is None |
875 | 881 | fix_lora_auto_mapping(model) |
876 | 882 | # Enable gradients on modules which are trainable |
|
0 commit comments