diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md index b14abecc3a..c03a8f50e7 100644 --- a/docs/source/developer_guides/quantization.md +++ b/docs/source/developer_guides/quantization.md @@ -272,12 +272,32 @@ model = convert(model) An example demonstrating how to load a PEFT LoRA adapter into an INC-quantized FLUX text-to-image model for HPU devices is provided [here](https://github.com/huggingface/peft/blob/main/examples/stable_diffusion/inc_flux_lora_hpu.py). - ### Caveats: - `merge()` and `unmerge()` methods are currently not supported for INC-quantized models. - Currently, only **Linear** INC-quantized layers are supported when loading PEFT adapters. +## Optimum-quanto + +PEFT supports models quantized with [optimum-quanto](https://github.com/huggingface/optimum-quanto). This has been tested with 2-bit, 4-bit, and 8-bit int quantization. Optimum-quanto also works on CPU and MPS. + +```python +from transformers import AutoModelForCausalLM, QuantoConfig +from peft import LoraConfig, get_peft_model + +model_id = ... +quantization_config = QuantoConfig(weights="int4") # or "int2" or "int8" +base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) +peft_config = LoraConfig(...) +model = get_peft_model(base_model, peft_config) +``` + +### Caveats: + +- Use optimum-quanto v0.2.5 or above, otherwise saving and loading won't work properly. +- If you want to use optimum-quanto via transformers, install transformers v4.46.0 or above. +- Float8 is discouraged as it can easily produce NaNs. +- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, non-LoRA methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**; see the sketch below. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect.
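+
+For illustration, here is a minimal sketch of merging a LoRA adapter into a quanto-quantized model (the model ID and LoRA settings are placeholders, not a recommendation):
+
+```python
+from transformers import AutoModelForCausalLM, QuantoConfig
+from peft import LoraConfig, get_peft_model
+
+base_model = AutoModelForCausalLM.from_pretrained(
+    "facebook/opt-125m", quantization_config=QuantoConfig(weights="int8")
+)
+model = get_peft_model(base_model, LoraConfig(task_type="CAUSAL_LM"))
+# ... train the LoRA adapter ...
+model.merge_adapter()              # re-quantizes the base weights with the LoRA delta added
+model.unmerge_adapter()            # restores the original quantized weights
+merged = model.merge_and_unload()  # returns the base model with the LoRA weights merged in
+```
+
+Because merging re-quantizes the weights, expect small numerical differences between the merged and unmerged model outputs.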
+ ## Other Supported PEFT Methods Besides LoRA, the following PEFT methods also support quantization: diff --git a/setup.py b/setup.py index 1ca8f6efa2..2fba127c99 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "scipy", "protobuf", "sentencepiece", + "optimum-quanto", ] setup( diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 6aa69a8519..735740baa4 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -170,3 +170,10 @@ def is_xpu_available(check_device=False): @lru_cache def is_diffusers_available(): return importlib.util.find_spec("diffusers") is not None + + +@lru_cache +def is_quanto_available(): + return (importlib.util.find_spec("optimum") is not None) and ( + importlib.util.find_spec("optimum.quanto") is not None + ) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 7ab4c6558c..41b797d778 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -54,6 +54,7 @@ from .hqq import dispatch_hqq from .inc import dispatch_inc from .layer import Conv2d, LoraLayer, ParamWrapper, dispatch_default +from .quanto import dispatch_quanto from .torchao import dispatch_torchao from .tp_layer import dispatch_megatron @@ -355,6 +356,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs): dispatch_hqq, dispatch_inc, dispatch_torchao, + dispatch_quanto, dispatch_megatron, dispatch_default, ] diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py new file mode 100644 index 0000000000..f7eb1c5047 --- /dev/null +++ b/src/peft/tuners/lora/quanto.py @@ -0,0 +1,422 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
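+
+# Note on this module's structure: it provides the LoRA integration for optimum-quanto
+# quantized layers. `QuantoLoraLinear` and `QuantoLoraConv2d` wrap quanto's `QLinear` and
+# `QConv2d` base layers, and `dispatch_quanto` (registered in `lora/model.py` above) selects
+# them when a targeted module is quanto-quantized. Merging/unmerging re-quantizes the base
+# weight after adding/subtracting the LoRA delta, so merged results are only approximate.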
+from __future__ import annotations + +import math +import warnings +from typing import Any, Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from peft.import_utils import is_quanto_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + + +if is_quanto_available(): + # ensure that there are no quanto imports unless optimum.quanto is installed + from optimum.quanto import QConv2d, QLinear +else: + QConv2d, QLinear = None, None + + +class QuantoLoraLinear(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QLinear""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. + result = self.base_layer(x, *args, **kwargs) + # some quanto quantizations may require cloning or else will fail later when assigning the lora output in-place + result = result.clone() + torch_result_dtype = result.dtype + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = self._cast_input_dtype(x, lora_A.weight.dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling + result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype) + + return result + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = 
self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + return ( + transpose(self.lora_B[adapter].weight @ self.lora_A[adapter].weight, fan_in_fan_out=self.fan_in_fan_out) + * self.scaling[adapter] + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[adapter_names[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + if getattr(base_layer, "bias", None) is not None: + new_module.bias.data += base_layer.bias + + for active_adapter in adapter_names: + new_module.weight.data += self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + if safe_merge and not torch.isfinite(quantized).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + self.merged_adapters.extend(adapter_names) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[self.active_adapters[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + if getattr(base_layer, "bias", None) is not None: + new_module.bias.data += base_layer.bias + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + new_module.weight.data -= self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +class QuantoLoraConv2d(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QConv2d""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + # same as lora.layer.Conv2d + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + # call this before dora_init + self._move_adapter_to_device_of_base_layer(adapter_name) + + if use_dora: + # TODO: Implement DoRA + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + result = self.base_layer(x) + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is not None: + raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + # same as lora.layer.Conv2d + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users want to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + 
# (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + F.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + + for active_adapter in adapter_names: + delta_weight = self.get_delta_weight(active_adapter) + # note: no in-place for safe_merge=False + new_weight_data = orig_weight + delta_weight + if safe_merge: + if not torch.isfinite(new_weight_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + new_weight_data = orig_weight - self.get_delta_weight(active_adapter) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +def dispatch_quanto( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_quanto_available() and isinstance(target_base_layer, QLinear): + new_module = QuantoLoraLinear(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + elif is_quanto_available() and isinstance(target_base_layer, QConv2d): + new_module = QuantoLoraConv2d(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + + return new_module diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py index a3d9a7b792..6dbf0d0356 100644 --- a/src/peft/utils/integrations.py +++ b/src/peft/utils/integrations.py @@ -139,15 +139,22 @@ def get_layer_device_map(model): """ Derive the device map for the layers of the model. """ - main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] + if not hasattr(model, "hf_device_map"): + return None + + if (len(model.hf_device_map) == 1) and hasattr(model, "device"): + # E.g. with quanto, when the model is loaded as: + # `model = AutoModel.from_pretrained(model_id, quantization_config=quanto_config)` + # Then the model.hf_device_map is set to {'': 'cpu'}, even if model.to(0) is called later. Thus we can't fully + # rely on the hf_device_map. + main_device = model.device + else: + main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] execution_device_map = { name: main_device if device in ["cpu", "disk"] else device for name, device in model.hf_device_map.items() } - if execution_device_map is None: - return None - if len(execution_device_map) == 1 and "" in execution_device_map: return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)} @@ -177,6 +184,9 @@ def map_cache_to_layer_device_map(model, cache) -> None: return layer_device_map = get_layer_device_map(model) + if layer_device_map is None: + return + for idx in range(model.config.num_hidden_layers): layer_device = layer_device_map[idx] cache.key_cache[idx] = cache.key_cache[idx].to(layer_device) diff --git a/tests/test_quanto.py b/tests/test_quanto.py new file mode 100644 index 0000000000..3eee78ecca --- /dev/null +++ b/tests/test_quanto.py @@ -0,0 +1,664 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
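+
+# These tests load a tiny causal LM through a small proxy class that applies `QuantoConfig`
+# (int2/int4/int8 weights) and then exercise LoRA and prompt-learning adapters on top of it:
+# training, saving/loading, generation, and merging. Because quantization makes exact equality
+# checks unreliable, the merging tests compare logits via correlation and outlier-trimmed
+# tolerances instead of strict `allclose`.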
+import copy +import platform +import shutil +import tempfile +from unittest.mock import Mock, call, patch + +import pytest +import torch +from torch import nn +from transformers import AutoModelForCausalLM, AutoTokenizer + +from peft import ( + LoraConfig, + PrefixTuningConfig, + PromptTuningConfig, + PromptTuningInit, + TaskType, + get_peft_model, +) + +from .testing_common import PeftCommonTester +from .testing_utils import set_init_weights_false + + +# only test a small subset of models and PEFT methods, testing exhaustively would be slow for little benefit +MODELS_TO_TEST = [ + "trl-internal-testing/tiny-random-LlamaForCausalLM", +] + + +ALL_CONFIGS = [ + ( + LoraConfig, + { + "r": 8, + "lora_alpha": 32, + "target_modules": None, + "lora_dropout": 0.05, + "bias": "none", + "task_type": TaskType.CAUSAL_LM, + }, + ), + ( + PrefixTuningConfig, + { + "num_virtual_tokens": 10, + "task_type": TaskType.CAUSAL_LM, + }, + ), + ( + PromptTuningConfig, + { + "num_virtual_tokens": 10, + "task_type": TaskType.CAUSAL_LM, + }, + ), +] + + +def _skip_if_merging_not_supported(model_id, config_cls): + if config_cls in (PrefixTuningConfig, PromptTuningConfig): + pytest.skip("This PEFT method does not support merging") + + +def make_automodel_proxy(weights: str): + """Instantiate a quanto-quantized transformers model.""" + from transformers import QuantoConfig + + class QuantoModelProxy: + @classmethod + def from_pretrained(self, *args, **kwargs): + quantization_config = QuantoConfig(weights=weights) + model = AutoModelForCausalLM.from_pretrained(*args, quantization_config=quantization_config, **kwargs) + return model + + return QuantoModelProxy + + +# Seeing issues on CI with MacOS and Windows, so skipping them for now +@pytest.mark.skipif(platform.system() != "Linux", reason="Tests are skipped on macOS and Windows") +class BasePeftQuantoModelTester: + r"""Base class implementing tests for quanto-quantized models. + + This class is based on PeftDecoderModelTester with some quanto-specific edits, especially for the merging tests, + which are less precise due to the quantization. + + Subclasses should implement the attributes below. + """ + + # The weights argument for quanto, should be "int2", "int4", or "int8" + weights = "MISSING" + # transformers class should be make_automodel_proxy(weights=weights) + transformers_class = "MISSING" + # expected minimum correlation between logits before and after merging + # subclasses should override this with a float between 0 and 1 + min_correlation = "MISSING" + # the allowed tolerance for comparing the output tensors + tol = "MISSING" + + def skipTest(self, reason=""): + # for backwards compatibility with unittest style test classes + pytest.skip(reason) + + def _get_correlation_matrix(self, *tensors): + return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) + + def check_tensors_approximately_equal(self, *tensors): + # Strict equality checks will fail due to the quantization, so we check: + # 1. The correlation between the tensors is high + # 2. 
Tensor equality after removing 1% of highest and lowest outliers + cc_matrix = self._get_correlation_matrix(*tensors) + assert cc_matrix.min() > self.min_correlation + + for tensor0, tensor1 in zip(tensors, tensors[1:]): + tensor0, tensor1 = tensor0.flatten(), tensor1.flatten() + diff = tensor0 - tensor1 + indices = torch.argsort(diff) + # remove 1% outliers on both ends + indices = indices[len(indices) // 100 : -len(indices) // 100] + tensor0, tensor1 = tensor0[indices], tensor1[indices] + assert torch.allclose(tensor0, tensor1, atol=self.tol, rtol=self.tol) + + def prepare_inputs_for_testing(self): + input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) + attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + + input_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + return input_dict + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_attributes_parametrized(self, model_id, config_cls, config_kwargs): + self._test_model_attr(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_adapter_name(self, model_id, config_cls, config_kwargs): + self._test_adapter_name(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_prepare_for_training_parametrized(self, model_id, config_cls, config_kwargs): + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_prompt_tuning_text_prepare_for_training(self, model_id, config_cls, config_kwargs): + # Test that prompt tuning works with text init + if config_cls != PromptTuningConfig: + return pytest.skip(f"This test does not apply to {config_cls}") + + config_kwargs = config_kwargs.copy() + config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT + config_kwargs["prompt_tuning_init_text"] = "This is a test prompt." 
+ config_kwargs["tokenizer_name_or_path"] = model_id + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + def test_prompt_tuning_text_tokenizer_kwargs(self): + # Allow users to pass additional arguments to Tokenizer.from_pretrained + # Fix for #1032 + mock = Mock() + orig_from_pretrained = AutoTokenizer.from_pretrained + + def mock_autotokenizer_from_pretrained(*args, **kwargs): + mock(*args, **kwargs) + return orig_from_pretrained(config.tokenizer_name_or_path) + + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + config = PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + prompt_tuning_init=PromptTuningInit.TEXT, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained): + model = get_peft_model(model, config) + + expected_call = call(model_id, trust_remote_code=True, foo="bar") + assert mock.call_args == expected_call + + def test_prompt_tuning_config_invalid_args(self): + # Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no + # function in that case + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + with pytest.raises(ValueError, match="tokenizer_kwargs only valid when using prompt_tuning_init='TEXT'."): + PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + prompt_tuning_init=PromptTuningInit.RANDOM, # <= should not be used together with tokenizer_kwargs + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained(self, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_pickle(self, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_save_pretrained_selected_adapters_pickle(self, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs): + self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def 
test_merge_layers(self, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + logits = model(**dummy_input)[0] + + model.merge_adapter() + logits_merged = model(**dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") + def test_merge_layers_multi(self, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal + # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + + logits_adapter_1 = model(**dummy_input)[0] + + model.add_adapter("adapter-2", config) + model.set_adapter("adapter-2") + model.eval() + + logits_adapter_2 = model(**dummy_input)[0] + + assert not torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-3, rtol=1e-3) + + model.set_adapter("default") + + logits_adapter_1_after_set = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_1, logits_adapter_1_after_set) + + model_copy = copy.deepcopy(model) + model_copy_2 = copy.deepcopy(model) + model_merged_all = model.merge_and_unload(adapter_names=["adapter-2", "default"]) + + logits_merged_all = model_merged_all(**dummy_input)[0] + + assert not torch.allclose(logits_merged_all, logits_adapter_2, atol=1e-3, rtol=1e-3) + assert not torch.allclose(logits_merged_all, logits_adapter_1, atol=1e-3, rtol=1e-3) + + model_merged_adapter_2 = model_copy.merge_and_unload(adapter_names=["adapter-2"]) + + logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_2, logits_merged_adapter_2) + + model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"]) + logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_1, logits_merged_adapter_default) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def 
test_merge_layers_nan(self, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + + model.eval() + + # This should work + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_unmerged, logits_merged) + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + prefixes = ["lora_A", "boft_R", "fourierft_spectrum", "hra_u", "hada_w1", "lokr_w1", "ia3_l", "oft_r"] + prefixes += ["vera_lambda_b"] + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.nan + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.inf + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + @pytest.mark.xfail(strict=True) + def test_load_merge_and_unloaded_model(self, model_id, config_cls, config_kwargs): + # Saving and loading a quanto model that has been merged and unloaded does not work correctly. Here is the + # reason: Quanto requires its own save_pretrained method, which, among others, saves the quantization map. + # Without it, the model cannot be correctly loaded. To make use of this, we should thus use a quanto + # QuantizedModel instance instead of a PretrainedModel instance. However, the QuantizedModel instance cannot be + # used for anything else, e.g. it has no __call__ method. Therefore, we cannot use that in PEFT. Therefore, + # users need to pass the PretrainedModel instance to get_peft_model, thus we don't have the modified + # save_pretrained, thus loading the merged and unloaded model does not work. 
+ from optimum.quanto import QuantizedModelForCausalLM + + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + model = model.merge_and_unload() + model.eval() + + dummy_input = self.prepare_inputs_for_testing() + logits = model(**dummy_input)[0] + + # model is a transformers model + tmp_dirname = tempfile.mkdtemp() + # note: not using the context manager here because it fails on Windows CI for some reason + try: + model.save_pretrained(tmp_dirname) + # Carefully: must use QuantizedModelForCausalLM.from_pretrained, not AutoModelForCausalLM.from_pretrained + model_from_pretrained = QuantizedModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device) + finally: + try: + shutil.rmtree(tmp_dirname) + except PermissionError: + # windows error + pass + + logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] + self.check_tensors_approximately_equal(logits, logits_merged_from_pretrained) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + config_kwargs = set_init_weights_false(config_cls, config_kwargs) + self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_generate(self, model_id, config_cls, config_kwargs): + self._test_generate(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_generate_pos_args(self, model_id, config_cls, config_kwargs): + # positional args are supported for PeftModelForCausalLM + self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_merge_layers_fp16(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) + + # this fails for a couple of methods (IA³, LoRA, prefix tuning) with segfault on GH CI + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_generate_half_prec(self, model_id, config_cls, config_kwargs): + self._test_generate_half_prec(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") + def test_prefix_tuning_half_prec_conversion(self, model_id, config_cls, config_kwargs): + self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders(self, model_id, config_cls, config_kwargs): + self._test_training(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + 
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders_layer_indexing(self, model_id, config_cls, config_kwargs): + self._test_training_layer_indexing(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs): + self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_inference_safetensors(self, model_id, config_cls, config_kwargs): + self._test_inference_safetensors(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_peft_model_device_map(self, model_id, config_cls, config_kwargs): + self._test_peft_model_device_map(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_delete_adapter(self, model_id, config_cls, config_kwargs): + self._test_delete_adapter(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs): + self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, config_kwargs): + self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_unload_adapter(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + config_kwargs = set_init_weights_false(config_cls, config_kwargs) + self._test_unload_adapter(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + self._test_weighted_combination_of_adapters(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs): + self._test_training_prompt_learning_tasks(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_disable_adapter(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + config_kwargs = set_init_weights_false(config_cls, config_kwargs) + self._test_disable_adapter(model_id, config_cls, config_kwargs) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_passing_input_embeds_works(self, model_id, config_cls, config_kwargs): + 
self._test_passing_input_embeds_works(model_id, config_cls, config_kwargs) + + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") + def test_lora_layer_replication(self): + model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" + config_kwargs = { + "target_modules": ["down_proj", "up_proj"], + "task_type": "CAUSAL_LM", + "lora_dropout": 0.0, + "layer_replication": [[0, 1], [0, 2], [1, 2]], + } + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = LoraConfig( + base_model_name_or_path=model_id, + **config_kwargs, + ) + assert len(model.model.layers) == 2, "Expected 2 layers in original model." + model = get_peft_model(model, config) + layers = model.base_model.model.model.layers + assert len(layers) == 4, "Expected 4 layers in adapted model." + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[1].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[3].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0-1 and 2-3 to share weights" + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + != layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0 and 2 to have different weights" + assert ( + layers[0].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[1].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[3].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + ), "Expected all LoRA adapters to have distinct weights" + assert len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8, ( + "Expected 8 LoRA adapters since we are adding one each for up and down." 
+ ) + self._test_prepare_for_training(model_id, LoraConfig, config_kwargs) + self._test_generate(model_id, LoraConfig, config_kwargs) + + def test_prompt_learning_with_grouped_query_attention(self): + # See 1901, fixes a bug with handling GQA + model_id = "peft-internal-testing/tiny-dummy-qwen2" + base_model = AutoModelForCausalLM.from_pretrained(model_id) + peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM") + model = get_peft_model(base_model, peft_config) + x = torch.tensor([[1, 2, 3]]) + # does not raise + model(x) + + @pytest.mark.parametrize("model_id", MODELS_TO_TEST) + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) + def test_quanto_merge_conv2d(self, model_id, config_cls, config_kwargs): + _skip_if_merging_not_supported(model_id, config_cls) + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + + config.target_modules = {"seq.0", "seq.2", "seq.4"} + config.task_type = None + + class ModelConv2D(nn.Module): + def __init__(self): + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(3, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Flatten(), + nn.Linear(800, 64), + ) + + def forward(self, X): + return self.seq(X) + + model = ModelConv2D() + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = torch.randn(5, 3, 16, 16).to(self.torch_device) + model.eval() + logits = model(dummy_input)[0] + + model.merge_adapter() + logits_merged = model(dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(dummy_input)[0] + + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + + +class TestPeftQuanto2bitModel(PeftCommonTester, BasePeftQuantoModelTester): + weights = "int2" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.9 + tol = 0.3 + + +class TestPeftQuanto4bitModel(PeftCommonTester, BasePeftQuantoModelTester): + weights = "int4" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.95 + tol = 1e-2 + + +class TestPeftQuanto8bitModel(PeftCommonTester, BasePeftQuantoModelTester): + weights = "int8" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.95 + tol = 1e-2 diff --git a/tests/testing_common.py b/tests/testing_common.py index ce566db32f..3c4af3b54e 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -20,6 +20,7 @@ import shutil import tempfile import warnings +from contextlib import nullcontext from dataclasses import replace import pytest @@ -747,6 +748,10 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): if model_id == "trl-internal-testing/tiny-Llama4ForCausalLM": # also getting larger errors here, not exactly sure why atol, rtol = 0.3, 0.01 + if quant_method := getattr(model, "quantization_method", None): + if quant_method.value == "quanto": + atol, rtol = 5e-3, 5e-3 + assert torch.allclose(logits, logits_merged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_unmerged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_merged_unloaded, atol=atol, rtol=rtol) @@ -936,6 +941,8 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): if config_cls not in (LoraConfig,): return pytest.skip(f"Mixed adapter batches not supported for {config_cls}") + from 
transformers.quantizers.quantizer_quanto import QuantoHfQuantizer + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -955,18 +962,27 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): # ensure that we have at least 3 samples for this test dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()} - with torch.inference_mode(): + # Using quanto with inference mode raises an error: + # > RuntimeError: Cannot set version_counter for inference tensor + # https://github.com/huggingface/optimum-quanto/issues/304 + # TODO: remove when/if this is fixed + if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer): + inference_mode = nullcontext + else: + inference_mode = torch.inference_mode + + with inference_mode(): with model.disable_adapter(): output_base = model(**dummy_input)[0] logits_base = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter0") - with torch.inference_mode(): + with inference_mode(): output_adapter0 = model(**dummy_input)[0] logits_adapter0 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter1") - with torch.inference_mode(): + with inference_mode(): output_adapter1 = model(**dummy_input)[0] logits_adapter1 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] @@ -985,7 +1001,7 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): adapters = ["__base__", "adapter0", "adapter1"] dummy_input["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))] - with torch.inference_mode(): + with inference_mode(): output_mixed = model(**dummy_input)[0] logits_mixed = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] @@ -1556,11 +1572,24 @@ def _test_delete_unknown_adapter_raises(self, model_id, config_cls, config_kwarg model.delete_adapter("unknown-adapter") def _test_unload_adapter(self, model_id, config_cls, config_kwargs): + from transformers.quantizers.quantizer_quanto import QuantoHfQuantizer + with hub_online_once(model_id): model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + model = model.to(self.torch_device) + + # Using quanto with inference mode raises an error: + # > RuntimeError: Cannot set version_counter for inference tensor + # https://github.com/huggingface/optimum-quanto/issues/304 + # TODO: remove when/if this is fixed + if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer): + inference_mode = nullcontext + else: + inference_mode = torch.inference_mode + + num_params_base = len(model.state_dict()) dummy_input = self.prepare_inputs_for_testing() - with torch.inference_mode(): + with inference_mode(): logits_transformers = model(**dummy_input)[0] config = config_cls( @@ -1568,7 +1597,6 @@ def _test_unload_adapter(self, model_id, config_cls, config_kwargs): **config_kwargs, ) model = get_peft_model(model, config) - model = model.to(self.torch_device) if isinstance(config, PromptLearningConfig): # prompt learning does not support unloading model = model.unload() else: self.perturb_trainable_token_weights_if_used(model, config_kwargs) - with torch.inference_mode(): + with inference_mode(): logits_with_adapter = model(**dummy_input)[0] model.eval() model = 
model.unload() num_params_unloaded = len(model.state_dict()) - with torch.inference_mode(): + with inference_mode(): logits_unload = model(**dummy_input)[0] # check that PEFT layers are completely removed