diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile
index 85a3e63680cb..c96b4cc79b3b 100644
--- a/docker/transformers-all-latest-gpu/Dockerfile
+++ b/docker/transformers-all-latest-gpu/Dockerfile
@@ -44,6 +44,8 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate
+RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/peft@main#egg=peft
+
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 9d1c33900c10..adb0f475ee3c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -19,6 +19,8 @@
title: Train with a script
- local: accelerate
title: Set up distributed training with 🤗 Accelerate
+ - local: peft
+ title: Load and train adapters with 🤗 PEFT
- local: model_sharing
title: Share your model
- local: transformers_agents
diff --git a/docs/source/en/peft.md b/docs/source/en/peft.md
new file mode 100644
index 000000000000..302b614e5f7b
--- /dev/null
+++ b/docs/source/en/peft.md
@@ -0,0 +1,216 @@
+
+
+# Load adapters with 🤗 PEFT
+
+[[open-in-colab]]
+
+[Parameter-Efficient Fine Tuning (PEFT)](https://huggingface.co/blog/peft) methods freeze the pretrained model parameters during fine-tuning and add a small number of trainable parameters (the adapters) on top of it. The adapters are trained to learn task-specific information. This approach has been shown to be very memory-efficient with lower compute usage while producing results comparable to a fully fine-tuned model.
+
+Adapters trained with PEFT are also usually an order of magnitude smaller than the full model, making it convenient to share, store, and load them.
+
+
+*The adapter weights for an OPTForCausalLM model stored on the Hub are only ~6MB compared to the full size of the model weights, which can be ~700MB.*
+
+If you're interested in learning more about the 🤗 PEFT library, check out the [documentation](https://huggingface.co/docs/peft/index).
+
+## Setup
+
+Get started by installing 🤗 PEFT:
+
+```bash
+pip install peft
+```
+
+If you want to try out the brand new features, you might be interested in installing the library from source:
+
+```bash
+pip install git+https://github.com/huggingface/peft.git
+```
+
+## Supported PEFT models
+
+🤗 Transformers natively supports some PEFT methods, meaning you can load adapter weights stored locally or on the Hub and easily run or train them with a few lines of code. The following methods are supported:
+
+- [Low Rank Adapters](https://huggingface.co/docs/peft/conceptual_guides/lora)
+- [IA3](https://huggingface.co/docs/peft/conceptual_guides/ia3)
+- [AdaLoRA](https://arxiv.org/abs/2303.10512)
+
+If you want to use other PEFT methods, such as prompt learning or prompt tuning, or to learn more about the 🤗 PEFT library in general, please refer to the [documentation](https://huggingface.co/docs/peft/index).
+
+
+## Load a PEFT adapter
+
+To load and use a PEFT adapter model from 🤗 Transformers, make sure the Hub repository or local directory contains an `adapter_config.json` file and the adapter weights. Then you can load the PEFT adapter model using the `AutoModelFor` class. For example, to load a PEFT adapter model for causal language modeling:
+
+1. specify the PEFT model id
+2. pass it to the [`AutoModelForCausalLM`] class
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+peft_model_id = "ybelkada/opt-350m-lora"
+model = AutoModelForCausalLM.from_pretrained(peft_model_id)
+```
+
+<Tip>
+
+You can load a PEFT adapter with either an `AutoModelFor` class or a base model class like `OPTForCausalLM` or `LlamaForCausalLM`.
+
+</Tip>
+
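+For example, here is a minimal sketch of loading the same adapter checkpoint through the base model class directly (assuming the `ybelkada/opt-350m-lora` repository used above):
+
+```py
+from transformers import OPTForCausalLM
+
+model = OPTForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
+```
+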
+You can also load a PEFT adapter by calling the `load_adapter` method:
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "facebook/opt-350m"
+peft_model_id = "ybelkada/opt-350m-lora"
+
+model = AutoModelForCausalLM.from_pretrained(model_id)
+model.load_adapter(peft_model_id)
+```
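+
+Once the adapter is loaded, you can run the model as usual; a short sketch, assuming the tokenizer of the base `facebook/opt-350m` checkpoint:
+
+```py
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+inputs = tokenizer("Hello", return_tensors="pt")
+
+# generation runs through the base model with the adapter weights applied
+output = model.generate(**inputs)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+```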
+
+## Load in 8bit or 4bit
+
+The `bitsandbytes` integration supports 8bit and 4bit precision data types, which are useful for loading large models because they save memory (see the `bitsandbytes` integration [guide](./quantization#bitsandbytes-integration) to learn more). Add the `load_in_8bit` or `load_in_4bit` parameters to [`~PreTrainedModel.from_pretrained`] and set `device_map="auto"` to effectively distribute the model to your hardware:
+
+```py
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+peft_model_id = "ybelkada/opt-350m-lora"
+model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_8bit=True)
+```
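+
+The same pattern works for 4bit precision; a minimal sketch, assuming the same adapter repository:
+
+```py
+model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_4bit=True)
+```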
+
+## Add a new adapter
+
+You can use [`~peft.PeftModel.add_adapter`] to add a new adapter to a model with an existing adapter as long as the new adapter is the same type as the current one. For example, if you have an existing LoRA adapter attached to a model:
+
+```py
+from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
+from peft import LoraConfig
+
+model_id = "facebook/opt-350m"
+model = AutoModelForCausalLM.from_pretrained(model_id)
+
+lora_config = LoraConfig(
+ target_modules=["q_proj", "k_proj"],
+ init_lora_weights=False
+)
+
+model.add_adapter(lora_config, adapter_name="adapter_1")
+```
+
+To add a new adapter:
+
+```py
+# attach new adapter with same config
+model.add_adapter(lora_config, adapter_name="adapter_2")
+```
+
+Now you can use [`~peft.PeftModel.set_adapter`] to set which adapter to use:
+
+```py
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+inputs = tokenizer("Hello", return_tensors="pt")
+
+# use adapter_1
+model.set_adapter("adapter_1")
+output = model.generate(**inputs)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+
+# use adapter_2
+model.set_adapter("adapter_2")
+output = model.generate(**inputs)
+print(tokenizer.decode(output[0], skip_special_tokens=True))
+```
+
+## Enable and disable adapters
+
+Once you've added an adapter to a model, you can enable or disable the adapter module. To enable the adapter module:
+
+```py
+from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
+from peft import PeftConfig
+
+model_id = "facebook/opt-350m"
+adapter_model_id = "ybelkada/opt-350m-lora"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+text = "Hello"
+inputs = tokenizer(text, return_tensors="pt")
+
+model = AutoModelForCausalLM.from_pretrained(model_id)
+peft_config = PeftConfig.from_pretrained(adapter_model_id)
+
+# to initialize with random weights
+peft_config.init_lora_weights = False
+
+model.add_adapter(peft_config)
+model.enable_adapters()
+output = model.generate(**inputs)
+```
+
+To disable the adapter module:
+
+```py
+model.disable_adapters()
+output = model.generate(**inputs)
+```
+
+## Train a PEFT adapter
+
+PEFT adapters are supported by the [`Trainer`] class so that you can train an adapter for your specific use case. It only requires adding a few more lines of code. For example, to train a LoRA adapter:
+
+<Tip>
+
+If you aren't familiar with fine-tuning a model with [`Trainer`], take a look at the [Fine-tune a pretrained model](training) tutorial.
+
+</Tip>
+
+1. Define your adapter configuration with the task type and hyperparameters (see [`~peft.LoraConfig`] for more details about what the hyperparameters do).
+
+```py
+from peft import LoraConfig
+
+peft_config = LoraConfig(
+ lora_alpha=16,
+ lora_dropout=0.1,
+ r=64,
+ bias="none",
+ task_type="CAUSAL_LM",
+)
+```
+
+2. Add adapter to the model.
+
+```py
+model.add_adapter(peft_config)
+```
+
+3. Now you can pass the model to [`Trainer`]!
+
+```py
+trainer = Trainer(model=model, ...)
+trainer.train()
+```
+
+To save your trained adapter and load it back:
+
+```py
+model.save_pretrained(save_dir)
+model = AutoModelForCausalLM.from_pretrained(save_dir)
+```
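+
+Because only the adapter weights and adapter config are written to `save_dir`, you can also attach the saved adapter to a freshly loaded base model instead; a sketch, assuming the adapter was trained on top of `facebook/opt-350m`:
+
+```py
+base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
+base_model.load_adapter(save_dir)
+```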
+
+
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index b06254a1c0ad..0a9bc3257da1 100644
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -111,6 +111,8 @@
"is_tensorboard_available",
"is_wandb_available",
],
+ "lib_integrations": [],
+ "lib_integrations.peft": [],
"modelcard": ["ModelCard"],
"modeling_tf_pytorch_utils": [
"convert_tf_weight_name_to_pt_weight_name",
diff --git a/src/transformers/lib_integrations/__init__.py b/src/transformers/lib_integrations/__init__.py
new file mode 100644
index 000000000000..0a2b0329f696
--- /dev/null
+++ b/src/transformers/lib_integrations/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .peft import PeftAdapterMixin
diff --git a/src/transformers/lib_integrations/peft/__init__.py b/src/transformers/lib_integrations/peft/__init__.py
new file mode 100644
index 000000000000..a6c1f0afd7e3
--- /dev/null
+++ b/src/transformers/lib_integrations/peft/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .peft_mixin import PeftAdapterMixin
diff --git a/src/transformers/lib_integrations/peft/peft_mixin.py b/src/transformers/lib_integrations/peft/peft_mixin.py
new file mode 100644
index 000000000000..82afe8db00e3
--- /dev/null
+++ b/src/transformers/lib_integrations/peft/peft_mixin.py
@@ -0,0 +1,390 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Optional
+
+from ...utils import (
+ check_peft_version,
+ find_adapter_config_file,
+ is_accelerate_available,
+ is_peft_available,
+ logging,
+)
+
+
+if is_accelerate_available():
+ from accelerate import dispatch_model
+ from accelerate.utils import get_balanced_memory, infer_auto_device_map
+
+
+logger = logging.get_logger(__name__)
+
+
+class PeftAdapterMixin:
+ """
+    A class containing all functions for loading and using adapter weights that are supported in the PEFT library. For
+ more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT
+ library: https://huggingface.co/docs/peft/index
+
+ Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods
+ that anyone can load, train and run with this mixin class:
+ - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
+ - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
+ - AdaLora: https://arxiv.org/abs/2303.10512
+
+    Other PEFT models, such as prompt tuning or prompt learning, are out of scope as these adapters are not "injectable"
+    into a torch module. For using these methods, please refer to the usage guide of the PEFT library.
+
+ With this mixin, if the correct PEFT version is installed, it is possible to:
+
+ - Load an adapter stored on a local path or in a remote Hub repository, and inject it in the model
+    - Attach new adapters to the model and train them with Trainer or on your own.
+ - Attach multiple adapters and iteratively activate / deactivate them
+ - Activate / deactivate all adapters from the model.
+ - Get the `state_dict` of the active adapter.
+ """
+
+ _hf_peft_config_loaded = False
+
+ def load_adapter(
+ self,
+ peft_model_id: str,
+ adapter_name: Optional[str] = None,
+ revision: Optional[str] = None,
+ token: Optional[str] = None,
+ device_map: Optional[str] = "auto",
+ max_memory: Optional[str] = None,
+ offload_folder: Optional[str] = None,
+ offload_index: Optional[int] = None,
+ ) -> None:
+ """
+ Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
+ invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft
+
+ Requires peft as a backend to load the adapter weights.
+
+ Args:
+ peft_model_id (`str`):
+ The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
+ and adapter weights.
+ adapter_name (`str`, *optional*):
+ The adapter name to use. If not set, will use the default adapter.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+                <Tip>
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
+
+                </Tip>
+
+            token (`str`, *optional*):
+                Whether to use an authentication token to load the remote folder. Useful to load private repositories
+                that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your token to
+                cache it.
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
+ like `1`) on which the model will be allocated, the device map will map the entire model to this
+ device. Passing `device_map = 0` means put the whole model on GPU 0.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+            offload_index (`int`, *optional*):
+ `offload_index` argument to be passed to `accelerate.dispatch_model` method.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ adapter_name = adapter_name if adapter_name is not None else "default"
+
+ from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
+ from peft.utils import set_peft_model_state_dict
+
+ if not self._hf_peft_config_loaded:
+ self._hf_peft_config_loaded = True
+ elif adapter_name in self.peft_config:
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+ adapter_config_file = find_adapter_config_file(
+ peft_model_id,
+ revision=revision,
+ token=token,
+ )
+
+ if adapter_config_file is None:
+ raise ValueError(
+ f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
+ "adapter model."
+ )
+
+ loaded_peft_config = PeftConfig.from_pretrained(
+ peft_model_id,
+ revision=revision,
+ use_auth_token=token,
+ )
+
+ # Create and add fresh new adapters into the model.
+ inject_adapter_in_model(loaded_peft_config, self, adapter_name)
+
+ adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)
+
+ # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
+ processed_adapter_state_dict = {}
+ prefix = "base_model.model."
+ for key, value in adapter_state_dict.items():
+ if key.startswith(prefix):
+ new_key = key[len(prefix) :]
+ else:
+ new_key = key
+ processed_adapter_state_dict[new_key] = value
+
+ # Load state dict
+ incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
+ logger.warning(
+ f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
+ f" {incompatible_keys.unexpected_keys}. "
+ )
+
+ # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
+ if (
+ (getattr(self, "hf_device_map", None) is not None)
+ and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
+ and len(self.peft_config) == 1
+ ):
+ self._dispatch_accelerate_model(
+ device_map=device_map,
+ max_memory=max_memory,
+ offload_folder=offload_folder,
+ offload_index=offload_index,
+ )
+
+ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
+ r"""
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+        Adds a fresh new adapter to the current model for training purposes. If no adapter name is passed, a default
+ name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
+ default adapter name).
+
+ Args:
+ adapter_config (`~peft.PeftConfig`):
+                The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompt
+                methods.
+ adapter_name (`str`, *optional*, defaults to `"default"`):
+ The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ from peft import PeftConfig, inject_adapter_in_model
+
+ adapter_name = adapter_name or "default"
+
+ if not self._hf_peft_config_loaded:
+ self._hf_peft_config_loaded = True
+ elif adapter_name in self.peft_config:
+ raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+ if not isinstance(adapter_config, PeftConfig):
+ raise ValueError(
+ f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+ )
+
+ inject_adapter_in_model(adapter_config, self, adapter_name)
+
+ self.set_adapter(adapter_name)
+
+ def set_adapter(self, adapter_name: str) -> None:
+ """
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+        Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.
+
+ Args:
+ adapter_name (`str`):
+ The name of the adapter to set.
+ """
+ check_peft_version(min_version="0.4.0")
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+ elif adapter_name not in self.peft_config:
+ raise ValueError(
+ f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
+ )
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ _adapters_has_been_set = False
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ module.active_adapter = adapter_name
+ _adapters_has_been_set = True
+
+ if not _adapters_has_been_set:
+ raise ValueError(
+ "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
+ )
+
+ def disable_adapters(self) -> None:
+ r"""
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+ Disable all adapters that are attached to the model. This leads to inferring with the base model only.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ module.disable_adapters = True
+
+ def enable_adapters(self) -> None:
+ """
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+        Enable adapters that are attached to the model. The model will use `self.active_adapter()`.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ module.disable_adapters = False
+
+ def active_adapter(self) -> str:
+ """
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+ Gets the current active adapter of the model.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ if not is_peft_available():
+ raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ for _, module in self.named_modules():
+ if isinstance(module, BaseTunerLayer):
+ return module.active_adapter
+
+ def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
+ """
+ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+ official documentation: https://huggingface.co/docs/peft
+
+        Gets the adapter state dict, which should only contain the weight tensors of the adapter specified by `adapter_name`.
+ If no adapter_name is passed, the active adapter is used.
+
+ Args:
+ adapter_name (`str`, *optional*):
+ The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
+ """
+ check_peft_version(min_version="0.4.0")
+
+ if not self._hf_peft_config_loaded:
+ raise ValueError("No adapter loaded. Please load an adapter first.")
+
+ from peft import get_peft_model_state_dict
+
+ if adapter_name is None:
+ adapter_name = self.active_adapter()
+
+ adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
+ return adapter_state_dict
+
+ def _dispatch_accelerate_model(
+ self,
+ device_map: str,
+ max_memory: Optional[int] = None,
+ offload_folder: Optional[str] = None,
+ offload_index: Optional[int] = None,
+ ) -> None:
+ """
+        Optionally re-dispatches the model and attaches new hooks to it in case the model has been loaded with
+        accelerate (i.e. with `device_map=xxx`).
+
+ Args:
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+ same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
+ like `1`) on which the model will be allocated, the device map will map the entire model to this
+ device. Passing `device_map = 0` means put the whole model on GPU 0.
+
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+ more information about each option see [designing a device
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each
+ GPU and the available CPU RAM if unset.
+ offload_folder (`str` or `os.PathLike`, *optional*):
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+ offload_index (`int`, *optional*):
+ The offload_index argument to be passed to `accelerate.dispatch_model` method.
+ """
+ dispatch_model_kwargs = {}
+ # Safety checker for previous `accelerate` versions
+ # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
+ if "offload_index" in inspect.signature(dispatch_model).parameters:
+ dispatch_model_kwargs["offload_index"] = offload_index
+
+ no_split_module_classes = self._no_split_modules
+
+ if device_map != "sequential":
+ max_memory = get_balanced_memory(
+ self,
+ max_memory=max_memory,
+ no_split_module_classes=no_split_module_classes,
+ low_zero=(device_map == "balanced_low_0"),
+ )
+ if isinstance(device_map, str):
+ device_map = infer_auto_device_map(
+ self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
+ )
+ dispatch_model(
+ self,
+ device_map=device_map,
+ offload_dir=offload_folder,
+ **dispatch_model_kwargs,
+ )
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 43f9a434fa8f..ea2d70ec4a8b 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -38,6 +38,7 @@
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
+from .lib_integrations import PeftAdapterMixin
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
@@ -48,6 +49,8 @@
prune_linear_layer,
)
from .utils import (
+ ADAPTER_SAFE_WEIGHTS_NAME,
+ ADAPTER_WEIGHTS_NAME,
DUMMY_INPUTS,
FLAX_WEIGHTS_NAME,
SAFE_WEIGHTS_INDEX_NAME,
@@ -68,6 +71,7 @@
is_bitsandbytes_available,
is_offline_mode,
is_optimum_available,
+ is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_tpu_available,
@@ -123,6 +127,9 @@ def is_fsdp_enabled_and_dist_rank_0():
else:
IS_SAGEMAKER_MP_POST_1_10 = False
+if is_peft_available():
+ from .utils import find_adapter_config_file
+
@contextmanager
def no_init_weights(_enable=True):
@@ -1039,7 +1046,7 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.
@@ -1738,6 +1745,7 @@ def save_pretrained(
safe_serialization: bool = False,
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
+ save_peft_format: bool = True,
**kwargs,
):
"""
@@ -1780,6 +1788,10 @@ def save_pretrained(
token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ save_peft_format (`bool`, *optional*, defaults to `True`):
+                For backward compatibility with the PEFT library, in case adapter weights are attached to the model, all
+                keys of the adapter state dict need to be prepended with `base_model.model`. Advanced users can
+                disable this behaviour by setting `save_peft_format` to `False`.
kwargs (`Dict[str, Any]`, *optional*):
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -1847,12 +1859,33 @@ def save_pretrained(
if self._auto_class is not None:
custom_object_save(self, save_directory, config=self.config)
+ _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
+
# Save the config
if is_main_process:
- model_to_save.config.save_pretrained(save_directory)
+ if not _hf_peft_config_loaded:
+ model_to_save.config.save_pretrained(save_directory)
if self.can_generate():
model_to_save.generation_config.save_pretrained(save_directory)
+ if _hf_peft_config_loaded:
+ logger.info(
+ "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
+ )
+ state_dict = model_to_save.get_adapter_state_dict()
+
+ if save_peft_format:
+ logger.info(
+ "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
+ )
+ peft_state_dict = {}
+ for key, value in state_dict.items():
+ peft_state_dict[f"base_model.model.{key}"] = value
+ state_dict = peft_state_dict
+
+ current_peft_config = self.peft_config[self.active_adapter()]
+ current_peft_config.save_pretrained(save_directory)
+
# Save the model
if state_dict is None:
state_dict = model_to_save.state_dict()
@@ -1907,8 +1940,11 @@ def save_pretrained(
)
# Shard the model if it is too big.
- weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
- weights_name = _add_variant(weights_name, variant)
+ if not _hf_peft_config_loaded:
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+ weights_name = _add_variant(weights_name, variant)
+ else:
+ weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
@@ -2295,6 +2331,8 @@ def from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
+ _adapter_model_path = kwargs.pop("_adapter_model_path", None)
+ adapter_name = kwargs.pop("adapter_name", "default")
if is_fsdp_enabled():
low_cpu_mem_usage = True
@@ -2323,6 +2361,29 @@ def from_pretrained(
" ignored."
)
+ if is_peft_available() and _adapter_model_path is None:
+ maybe_adapter_model_path = find_adapter_config_file(
+ pretrained_model_name_or_path,
+ revision=revision,
+ subfolder=subfolder,
+ token=token,
+ commit_hash=commit_hash,
+ )
+ elif is_peft_available() and _adapter_model_path is not None:
+ maybe_adapter_model_path = _adapter_model_path
+ else:
+ maybe_adapter_model_path = None
+
+ has_adapter_config = maybe_adapter_model_path is not None
+
+ if has_adapter_config:
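+            # An adapter config was found: remember where the adapter lives so it can be loaded at the end of
+            # `from_pretrained`, and resolve the base model id from the adapter config if the caller has not already.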
+ if _adapter_model_path is not None:
+ adapter_model_id = _adapter_model_path
+ else:
+ with open(maybe_adapter_model_path, "r", encoding="utf-8") as f:
+ adapter_model_id = pretrained_model_name_or_path
+ pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+
# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
@@ -3153,6 +3214,14 @@ def from_pretrained(
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.post_init_model(model)
+ if has_adapter_config:
+ model.load_adapter(
+ adapter_model_id,
+ adapter_name=adapter_name,
+ revision=revision,
+ token=token,
+ )
+
if output_loading_info:
if loading_info is None:
loading_info = {
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index a5e4fcbb90ee..fc58170b75ed 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -15,13 +15,14 @@
"""Factory function to build auto-model classes."""
import copy
import importlib
+import json
import os
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...utils import copy_func, logging, requires_backends
+from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
@@ -469,6 +470,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if token is not None:
hub_kwargs["token"] = token
+ if is_peft_available():
+ revision = kwargs.get("revision", None)
+ subfolder = kwargs.get("subfolder", None)
+
+ maybe_adapter_path = find_adapter_config_file(
+ pretrained_model_name_or_path,
+ revision=revision,
+ token=use_auth_token,
+ subfolder=subfolder,
+ )
+
+ if maybe_adapter_path is not None:
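+                # The provided id points at an adapter repo: resolve the base model from the adapter config and
+                # pass the adapter location down so `from_pretrained` can load the adapter weights on top of it.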
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
+ adapter_config = json.load(f)
+
+ kwargs["_adapter_model_path"] = pretrained_model_name_or_path
+ pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
+
if not isinstance(config, PretrainedConfig):
kwargs_orig = copy.deepcopy(kwargs)
# ensure not to pollute the config object with torch_dtype="auto" - since it's
diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py
index b1d6ec4bdac3..746089b4e5cf 100755
--- a/src/transformers/pipelines/__init__.py
+++ b/src/transformers/pipelines/__init__.py
@@ -34,8 +34,10 @@
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+ find_adapter_config_file,
is_kenlm_available,
is_offline_mode,
+ is_peft_available,
is_pyctcdecode_available,
is_tf_available,
is_torch_available,
@@ -721,6 +723,21 @@ def pipeline(
config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
elif config is None and isinstance(model, str):
+ # Check for an adapter file in the model path if PEFT is available
+ if is_peft_available():
+ subfolder = hub_kwargs.get("subfolder", None)
+ maybe_adapter_path = find_adapter_config_file(
+ model,
+ revision=revision,
+ token=use_auth_token,
+ subfolder=subfolder,
+ )
+
+ if maybe_adapter_path is not None:
+ with open(maybe_adapter_path, "r", encoding="utf-8") as f:
+ adapter_config = json.load(f)
+ model = adapter_config["base_model_name_or_path"]
+
config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
hub_kwargs["_commit_hash"] = config._commit_hash
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 18d5880a1729..9a59327af61f 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -69,6 +69,7 @@
is_onnx_available,
is_optimum_available,
is_pandas_available,
+ is_peft_available,
is_phonemizer_available,
is_pyctcdecode_available,
is_pytesseract_available,
@@ -369,6 +370,16 @@ def require_torch(test_case):
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+def require_peft(test_case):
+ """
+ Decorator marking a test that requires PEFT.
+
+ These tests are skipped when PEFT isn't installed.
+
+ """
+ return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
+
+
def require_torchvision(test_case):
"""
Decorator marking a test that requires Torchvision.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4f4982553687..85d4fd5a5252 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -396,7 +396,7 @@ def __init__(
)
# At this stage the model is already loaded
- if getattr(model, "is_quantized", False):
+ if getattr(model, "is_quantized", False) and not getattr(model, "_hf_peft_config_loaded", False):
if getattr(model, "_is_quantized_training_enabled", False):
logger.info(
"The model is quantized. To train this model you need to add additional modules"
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 83b0128fbc58..422464ce5bd0 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -179,13 +179,17 @@
requires_backends,
torch_only_method,
)
+from .peft_utils import (
+ ADAPTER_CONFIG_NAME,
+ ADAPTER_SAFE_WEIGHTS_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ check_peft_version,
+ find_adapter_config_file,
+)
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-ADAPTER_CONFIG_NAME = "adapter_config.json"
-ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
-ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 54ed4030a2b6..8172ba23401f 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1001,6 +1001,11 @@ def is_jieba_available():
jieba`. Please note that you may need to restart your runtime after installation.
"""
+PEFT_IMPORT_ERROR = """
+{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
+peft`. Please note that you may need to restart your runtime after installation.
+"""
+
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -1034,6 +1039,7 @@ def is_jieba_available():
("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
+ ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
]
)
diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py
new file mode 100644
index 000000000000..1dd8e14dbb86
--- /dev/null
+++ b/src/transformers/utils/peft_utils.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib.metadata
+import os
+from typing import Optional
+
+from packaging import version
+
+from .hub import cached_file
+from .import_utils import is_peft_available
+
+
+ADAPTER_CONFIG_NAME = "adapter_config.json"
+ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
+ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
+
+
+def find_adapter_config_file(
+ model_id: str,
+ revision: str = None,
+ subfolder: str = None,
+ token: Optional[str] = None,
+ commit_hash: Optional[str] = None,
+) -> Optional[str]:
+ r"""
+    Simply checks if the model stored on the Hub or locally is an adapter model or not, returning the path to the
+    adapter config file if it is, and None otherwise.
+
+ Args:
+ model_id (`str`):
+ The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+            <Tip>
+
+            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
+
+            </Tip>
+
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+        token (`str`, *optional*):
+            Whether to use an authentication token to load the remote folder. Useful to load private repositories that
+            are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your token to cache it.
+ """
+ adapter_cached_filename = None
+ if os.path.isdir(model_id):
+ list_remote_files = os.listdir(model_id)
+ if ADAPTER_CONFIG_NAME in list_remote_files:
+ adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
+ else:
+ adapter_cached_filename = cached_file(
+ model_id,
+ ADAPTER_CONFIG_NAME,
+ revision=revision,
+ token=token,
+ _commit_hash=commit_hash,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ _raise_exceptions_for_connection_errors=False,
+ )
+
+ return adapter_cached_filename
+
+
+def check_peft_version(min_version: str) -> None:
+ r"""
+ Checks if the version of PEFT is compatible.
+
+ Args:
+        min_version (`str`):
+            The minimum version of PEFT required for compatibility.
+ """
+ if not is_peft_available():
+ raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
+
+    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)
+
+ if not is_peft_version_compatible:
+ raise ValueError(
+ f"The version of PEFT you are using is not compatible, please use a version that is greater"
+ f" than {min_version}"
+ )
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
new file mode 100644
index 000000000000..b80912607c02
--- /dev/null
+++ b/tests/peft_integration/test_peft_integration.py
@@ -0,0 +1,236 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import tempfile
+import unittest
+
+from transformers import AutoModelForCausalLM, OPTForCausalLM
+from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
+from transformers.utils import is_torch_available
+
+
+if is_torch_available():
+ import torch
+
+
+@require_peft
+@require_torch
+class PeftTesterMixin:
+ peft_test_model_ids = ("peft-internal-testing/tiny-OPTForCausalLM-lora",)
+ transformers_test_model_ids = ("hf-internal-testing/tiny-random-OPTForCausalLM",)
+ transformers_test_model_classes = (AutoModelForCausalLM, OPTForCausalLM)
+
+
+# TODO: run it with CI after PEFT release.
+@slow
+class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
+ """
+ A testing suite that makes sure that the PeftModel class is correctly integrated into the transformers library.
+ """
+
+ def _check_lora_correctly_converted(self, model):
+ """
+        Utility method to check if the model has adapters correctly injected into it.
+ """
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ is_peft_loaded = False
+
+ for _, m in model.named_modules():
+ if isinstance(m, BaseTunerLayer):
+ is_peft_loaded = True
+ break
+
+ return is_peft_loaded
+
+ def test_peft_from_pretrained(self):
+ """
+ Simple test that tests the basic usage of PEFT model through `from_pretrained`.
+        This checks that, if we pass a remote folder that contains an adapter config and adapter weights, it
+        correctly loads a model that has adapters injected into it.
+ """
+ for model_id in self.peft_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ self.assertTrue(self._check_lora_correctly_converted(peft_model))
+ self.assertTrue(peft_model._hf_peft_config_loaded)
+ # dummy generation
+ _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
+
+ def test_peft_state_dict(self):
+ """
+ Simple test that checks if the returned state dict of `get_adapter_state_dict()` method contains
+ the expected keys.
+ """
+ for model_id in self.peft_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ state_dict = peft_model.get_adapter_state_dict()
+
+ for key in state_dict.keys():
+ self.assertTrue("lora" in key)
+
+ def test_peft_save_pretrained(self):
+ """
+ Test that checks various combinations of `save_pretrained` with a model that has adapters loaded
+ on it. This checks if the saved model contains the expected files (adapter weights and adapter config).
+ """
+ for model_id in self.peft_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ peft_model.save_pretrained(tmpdirname)
+
+ self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
+ self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
+
+ self.assertTrue("config.json" not in os.listdir(tmpdirname))
+ self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))
+
+ peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+ self.assertTrue(self._check_lora_correctly_converted(peft_model))
+
+ peft_model.save_pretrained(tmpdirname, safe_serialization=True)
+ self.assertTrue("adapter_model.safetensors" in os.listdir(tmpdirname))
+ self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))
+
+ peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+ self.assertTrue(self._check_lora_correctly_converted(peft_model))
+
+ def test_peft_enable_disable_adapters(self):
+ """
+ A test that checks if `enable_adapters` and `disable_adapters` methods work as expected.
+ """
+ from peft import LoraConfig
+
+ dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
+
+ for model_id in self.transformers_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ peft_model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ peft_config = LoraConfig(init_lora_weights=False)
+
+ peft_model.add_adapter(peft_config)
+
+ peft_logits = peft_model(dummy_input).logits
+
+ peft_model.disable_adapters()
+
+ peft_logits_disabled = peft_model(dummy_input).logits
+
+ peft_model.enable_adapters()
+
+ peft_logits_enabled = peft_model(dummy_input).logits
+
+ self.assertTrue(torch.allclose(peft_logits, peft_logits_enabled, atol=1e-12, rtol=1e-12))
+ self.assertFalse(torch.allclose(peft_logits_enabled, peft_logits_disabled, atol=1e-12, rtol=1e-12))
+
+ def test_peft_add_adapter(self):
+ """
+ Simple test that tests if `add_adapter` works as expected
+ """
+ from peft import LoraConfig
+
+ for model_id in self.transformers_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ peft_config = LoraConfig(init_lora_weights=False)
+
+ model.add_adapter(peft_config)
+
+ self.assertTrue(self._check_lora_correctly_converted(model))
+ # dummy generation
+ _ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
+
+ def test_peft_add_multi_adapter(self):
+ """
+        Simple test that tests the basic usage of a PEFT model through `from_pretrained`. This checks that
+        `add_adapter` works as expected in a multi-adapter setting.
+ """
+ from peft import LoraConfig
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
+
+ for model_id in self.transformers_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ is_peft_loaded = False
+ model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+ logits_original_model = model(dummy_input).logits
+
+ peft_config = LoraConfig(init_lora_weights=False)
+
+ model.add_adapter(peft_config)
+
+ logits_adapter_1 = model(dummy_input)
+
+ model.add_adapter(peft_config, adapter_name="adapter-2")
+
+ logits_adapter_2 = model(dummy_input)
+
+ for _, m in model.named_modules():
+ if isinstance(m, BaseTunerLayer):
+ is_peft_loaded = True
+ break
+
+ self.assertTrue(is_peft_loaded)
+
+ # dummy generation
+ _ = model.generate(input_ids=dummy_input)
+
+ model.set_adapter("default")
+ self.assertTrue(model.active_adapter() == "default")
+
+ model.set_adapter("adapter-2")
+ self.assertTrue(model.active_adapter() == "adapter-2")
+
+ # Logits comparison
+ self.assertFalse(
+ torch.allclose(logits_adapter_1.logits, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)
+ )
+ self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
+
+ @require_torch_gpu
+ def test_peft_from_pretrained_kwargs(self):
+ """
+ Simple test that tests the basic usage of PEFT model through `from_pretrained` + additional kwargs
+        and checks that the integration behaves as expected.
+ """
+ for model_id in self.peft_test_model_ids:
+ for transformers_class in self.transformers_test_model_classes:
+ peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+
+ module = peft_model.model.decoder.layers[0].self_attn.v_proj
+ self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
+ self.assertTrue(peft_model.hf_device_map is not None)
+
+ # dummy generation
+ _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
+
+ def test_peft_pipeline(self):
+ """
+ Simple test that tests the basic usage of PEFT model + pipeline
+ """
+ from transformers import pipeline
+
+ for model_id in self.peft_test_model_ids:
+ pipe = pipeline("text-generation", model_id)
+ _ = pipe("Hello")