diff --git a/docker/transformers-all-latest-gpu/Dockerfile b/docker/transformers-all-latest-gpu/Dockerfile index 85a3e63680cb..c96b4cc79b3b 100644 --- a/docker/transformers-all-latest-gpu/Dockerfile +++ b/docker/transformers-all-latest-gpu/Dockerfile @@ -44,6 +44,8 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0" RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate +RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/peft@main#egg=peft + # Add bitsandbytes for mixed int8 testing RUN python3 -m pip install --no-cache-dir bitsandbytes diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 9d1c33900c10..adb0f475ee3c 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -19,6 +19,8 @@ title: Train with a script - local: accelerate title: Set up distributed training with 🤗 Accelerate + - local: peft + title: Load and train adapters with 🤗 PEFT - local: model_sharing title: Share your model - local: transformers_agents diff --git a/docs/source/en/peft.md b/docs/source/en/peft.md new file mode 100644 index 000000000000..302b614e5f7b --- /dev/null +++ b/docs/source/en/peft.md @@ -0,0 +1,216 @@ + + +# Load adapters with 🤗 PEFT + +[[open-in-colab]] + +[Parameter-Efficient Fine Tuning (PEFT)](https://huggingface.co/blog/peft) methods freeze the pretrained model parameters during fine-tuning and add a small number of trainable parameters (the adapters) on top of it. The adapters are trained to learn task-specific information. This approach has been shown to be very memory-efficient with lower compute usage while producing results comparable to a fully fine-tuned model. + +Adapters trained with PEFT are also usually an order of magnitude smaller than the full model, making it convenient to share, store, and load them. + +
The adapter weights for an OPTForCausalLM model stored on the Hub are only ~6MB compared to the full size of the model weights, which can be ~700MB.
+ +If you're interested in learning more about the 🤗 PEFT library, check out the [documentation](https://huggingface.co/docs/peft/index). + +## Setup + +Get started by installing 🤗 PEFT: + +```bash +pip install peft +``` + +If you want to try out the brand new features, you might be interested in installing the library from source: + +```bash +pip install git+https://github.com/huggingface/peft.git +``` + +## Supported PEFT models + +🤗 Transformers natively supports some PEFT methods, meaning you can load adapter weights stored locally or on the Hub and easily run or train them with a few lines of code. The following methods are supported: + +- [Low Rank Adapters](https://huggingface.co/docs/peft/conceptual_guides/lora) +- [IA3](https://huggingface.co/docs/peft/conceptual_guides/ia3) +- [AdaLoRA](https://arxiv.org/abs/2303.10512) + +If you want to use other PEFT methods, such as prompt learning or prompt tuning, or about the 🤗 PEFT library in general, please refer to the [documentation](https://huggingface.co/docs/peft/index). + + +## Load a PEFT adapter + +To load and use a PEFT adapter model from 🤗 Transformers, make sure the Hub repository or local directory contains an `adapter_config.json` file and the adapter weights, as shown in the example image above. Then you can load the PEFT adapter model using the `AutoModelFor` class. For example, to load a PEFT adapter model for causal language modeling: + +1. specify the PEFT model id +2. pass it to the [`AutoModelForCausalLM`] class + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer + +peft_model_id = "ybelkada/opt-350m-lora" +model = AutoModelForCausalLM.from_pretrained(peft_model_id) +``` + + + +You can load a PEFT adapter with either an `AutoModelFor` class or the base model class like `OPTForCausalLM` or `LlamaForCausalLM`. + + + +You can also load a PEFT adapter by calling the `load_adapter` method: + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_id = "facebook/opt-350m" +peft_model_id = "ybelkada/opt-350m-lora" + +model = AutoModelForCausalLM.from_pretrained(model_id) +model.load_adapter(peft_model_id) +``` + +## Load in 8bit or 4bit + +The `bitsandbytes` integration supports 8bit and 4bit precision data types, which are useful for loading large models because it saves memory (see the `bitsandbytes` integration [guide](./quantization#bitsandbytes-integration) to learn more). Add the `load_in_8bit` or `load_in_4bit` parameters to [`~PreTrainedModel.from_pretrained`] and set `device_map="auto"` to effectively distribute the model to your hardware: + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer + +peft_model_id = "ybelkada/opt-350m-lora" +model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_8bit=True) +``` + +## Add a new adapter + +You can use [`~peft.PeftModel.add_adapter`] to add a new adapter to a model with an existing adapter as long as the new adapter is the same type as the current one. 
For example, if you have an existing LoRA adapter attached to a model:

```py
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import LoraConfig

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    init_lora_weights=False
)

model.add_adapter(lora_config, adapter_name="adapter_1")
```

To add a new adapter:

```py
# attach new adapter with same config
model.add_adapter(lora_config, adapter_name="adapter_2")
```

Now you can use [`~peft.PeftModel.set_adapter`] to set which adapter to use:

```py
# use adapter_1
model.set_adapter("adapter_1")
output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))

# use adapter_2
model.set_adapter("adapter_2")
output_enabled = model.generate(**inputs)
print(tokenizer.decode(output_enabled[0], skip_special_tokens=True))
```

## Enable and disable adapters

Once you've added an adapter to a model, you can enable or disable the adapter module. To enable the adapter module:

```py
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)

# to initialize with random weights
peft_config.init_lora_weights = False

model.add_adapter(peft_config)
model.enable_adapters()
output = model.generate(**inputs)
```

To disable the adapter module:

```py
model.disable_adapters()
output = model.generate(**inputs)
```

## Train a PEFT adapter

PEFT adapters are supported by the [`Trainer`] class so that you can train an adapter for your specific use case. It only requires adding a few more lines of code. For example, to train a LoRA adapter:

If you aren't familiar with fine-tuning a model with [`Trainer`], take a look at the [Fine-tune a pretrained model](training) tutorial.

1. Define your adapter configuration with the task type and hyperparameters (see [`~peft.LoraConfig`] for more details about what the hyperparameters do).

```py
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
```

2. Add the adapter to the model.

```py
model.add_adapter(peft_config)
```

3. Now you can pass the model to [`Trainer`]!

```py
trainer = Trainer(model=model, ...)
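# `...` stands for the usual Trainer arguments (a TrainingArguments instance, the
# training dataset, a data collator, and so on); they are the same as for full fine-tuning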
+trainer.train() +``` + +To save your trained adapter and load it back: + +```py +model.save_pretrained(save_dir) +model = AutoModelForCausalLM.from_pretrained(save_dir) +``` + + diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index b06254a1c0ad..0a9bc3257da1 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -111,6 +111,8 @@ "is_tensorboard_available", "is_wandb_available", ], + "lib_integrations": [], + "lib_integrations.peft": [], "modelcard": ["ModelCard"], "modeling_tf_pytorch_utils": [ "convert_tf_weight_name_to_pt_weight_name", diff --git a/src/transformers/lib_integrations/__init__.py b/src/transformers/lib_integrations/__init__.py new file mode 100644 index 000000000000..0a2b0329f696 --- /dev/null +++ b/src/transformers/lib_integrations/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .peft import PeftAdapterMixin diff --git a/src/transformers/lib_integrations/peft/__init__.py b/src/transformers/lib_integrations/peft/__init__.py new file mode 100644 index 000000000000..a6c1f0afd7e3 --- /dev/null +++ b/src/transformers/lib_integrations/peft/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .peft_mixin import PeftAdapterMixin diff --git a/src/transformers/lib_integrations/peft/peft_mixin.py b/src/transformers/lib_integrations/peft/peft_mixin.py new file mode 100644 index 000000000000..82afe8db00e3 --- /dev/null +++ b/src/transformers/lib_integrations/peft/peft_mixin.py @@ -0,0 +1,390 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +from typing import Optional + +from ...utils import ( + check_peft_version, + find_adapter_config_file, + is_accelerate_available, + is_peft_available, + logging, +) + + +if is_accelerate_available(): + from accelerate import dispatch_model + from accelerate.utils import get_balanced_memory, infer_auto_device_map + + +logger = logging.get_logger(__name__) + + +class PeftAdapterMixin: + """ + A class containing all functions for loading and using adapters weights that are supported in PEFT library. For + more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT + library: https://huggingface.co/docs/peft/index + + Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods + that anyone can load, train and run with this mixin class: + - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora + - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3 + - AdaLora: https://arxiv.org/abs/2303.10512 + + Other PEFT models such as prompt tuning, prompt learning are out of scope as these adapters are not "injectable" + into a torch module. For using these methods, please refer to the usage guide of PEFT library. + + With this mixin, if the correct PEFT version is installed, it is possible to: + + - Load an adapter stored on a local path or in a remote Hub repository, and inject it in the model + - Attach new adapters in the model and train them with Trainer or by your own. + - Attach multiple adapters and iteratively activate / deactivate them + - Activate / deactivate all adapters from the model. + - Get the `state_dict` of the active adapter. + """ + + _hf_peft_config_loaded = False + + def load_adapter( + self, + peft_model_id: str, + adapter_name: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + device_map: Optional[str] = "auto", + max_memory: Optional[str] = None, + offload_folder: Optional[str] = None, + offload_index: Optional[int] = None, + ) -> None: + """ + Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we + invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft + + Requires peft as a backend to load the adapter weights. + + Args: + peft_model_id (`str`): + The identifier of the model to look for on the Hub, or a local path to the saved adapter config file + and adapter weights. + adapter_name (`str`, *optional*): + The adapter name to use. If not set, will use the default adapter. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". + + + + token (`str`, `optional`): + Whether to use authentication token to load the remote folder. Userful to load private repositories + that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to + cache it. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. 
It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, `optional`): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_index (`int`, `optional`): + `offload_index` argument to be passed to `accelerate.dispatch_model` method. + """ + check_peft_version(min_version="0.4.0") + + adapter_name = adapter_name if adapter_name is not None else "default" + + from peft import PeftConfig, inject_adapter_in_model, load_peft_weights + from peft.utils import set_peft_model_state_dict + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + adapter_config_file = find_adapter_config_file( + peft_model_id, + revision=revision, + token=token, + ) + + if adapter_config_file is None: + raise ValueError( + f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " + "adapter model." + ) + + loaded_peft_config = PeftConfig.from_pretrained( + peft_model_id, + revision=revision, + use_auth_token=token, + ) + + # Create and add fresh new adapters into the model. + inject_adapter_in_model(loaded_peft_config, self, adapter_name) + + adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token) + + # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility + processed_adapter_state_dict = {} + prefix = "base_model.model." + for key, value in adapter_state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix) :] + else: + new_key = key + processed_adapter_state_dict[new_key] = value + + # Load state dict + incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: + logger.warning( + f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: " + f" {incompatible_keys.unexpected_keys}. " + ) + + # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. 
+ if ( + (getattr(self, "hf_device_map", None) is not None) + and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) + and len(self.peft_config) == 1 + ): + self._dispatch_accelerate_model( + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_index=offload_index, + ) + + def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None: + r""" + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default + name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the + default adapter name). + + Args: + adapter_config (`~peft.PeftConfig`): + The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts + methods + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. + """ + check_peft_version(min_version="0.4.0") + + from peft import PeftConfig, inject_adapter_in_model + + adapter_name = adapter_name or "default" + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." + ) + + inject_adapter_in_model(adapter_config, self, adapter_name) + + self.set_adapter(adapter_name) + + def set_adapter(self, adapter_name: str) -> None: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. + + Args: + adapter_name (`str`): + The name of the adapter to set. + """ + check_peft_version(min_version="0.4.0") + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + elif adapter_name not in self.peft_config: + raise ValueError( + f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" + ) + + from peft.tuners.tuners_utils import BaseTunerLayer + + _adapters_has_been_set = False + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + module.active_adapter = adapter_name + _adapters_has_been_set = True + + if not _adapters_has_been_set: + raise ValueError( + "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." + ) + + def disable_adapters(self) -> None: + r""" + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Disable all adapters that are attached to the model. This leads to inferring with the base model only. + """ + check_peft_version(min_version="0.4.0") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. 
Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + module.disable_adapters = True + + def enable_adapters(self) -> None: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Enable adapters that are attached to the model. The model will use `self.active_adapter()` + """ + check_peft_version(min_version="0.4.0") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + module.disable_adapters = False + + def active_adapter(self) -> str: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the current active adapter of the model. + """ + check_peft_version(min_version="0.4.0") + + if not is_peft_available(): + raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + return module.active_adapter + + def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. + If no adapter_name is passed, the active adapter is used. + + Args: + adapter_name (`str`, *optional*): + The name of the adapter to get the state dict from. If no name is passed, the active adapter is used. + """ + check_peft_version(min_version="0.4.0") + + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from peft import get_peft_model_state_dict + + if adapter_name is None: + adapter_name = self.active_adapter() + + adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name) + return adapter_state_dict + + def _dispatch_accelerate_model( + self, + device_map: str, + max_memory: Optional[int] = None, + offload_folder: Optional[str] = None, + offload_index: Optional[int] = None, + ) -> None: + """ + Optionnal re-dispatch the model and attach new hooks to the model in case the model has been loaded with + accelerate (i.e. with `device_map=xxx`) + + Args: + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. 
+ + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_index (`int`, *optional*): + The offload_index argument to be passed to `accelerate.dispatch_model` method. + """ + dispatch_model_kwargs = {} + # Safety checker for previous `accelerate` versions + # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/ + if "offload_index" in inspect.signature(dispatch_model).parameters: + dispatch_model_kwargs["offload_index"] = offload_index + + no_split_module_classes = self._no_split_modules + + if device_map != "sequential": + max_memory = get_balanced_memory( + self, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + low_zero=(device_map == "balanced_low_0"), + ) + if isinstance(device_map, str): + device_map = infer_auto_device_map( + self, max_memory=max_memory, no_split_module_classes=no_split_module_classes + ) + dispatch_model( + self, + device_map=device_map, + offload_dir=offload_folder, + **dispatch_model_kwargs, + ) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 43f9a434fa8f..ea2d70ec4a8b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -38,6 +38,7 @@ from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled from .dynamic_module_utils import custom_object_save from .generation import GenerationConfig, GenerationMixin +from .lib_integrations import PeftAdapterMixin from .pytorch_utils import ( # noqa: F401 Conv1D, apply_chunking_to_forward, @@ -48,6 +49,8 @@ prune_linear_layer, ) from .utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, DUMMY_INPUTS, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, @@ -68,6 +71,7 @@ is_bitsandbytes_available, is_offline_mode, is_optimum_available, + is_peft_available, is_remote_url, is_safetensors_available, is_torch_tpu_available, @@ -123,6 +127,9 @@ def is_fsdp_enabled_and_dist_rank_0(): else: IS_SAGEMAKER_MP_POST_1_10 = False +if is_peft_available(): + from .utils import find_adapter_config_file + @contextmanager def no_init_weights(_enable=True): @@ -1039,7 +1046,7 @@ def floating_point_ops( return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) -class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. @@ -1738,6 +1745,7 @@ def save_pretrained( safe_serialization: bool = False, variant: Optional[str] = None, token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, **kwargs, ): """ @@ -1780,6 +1788,10 @@ def save_pretrained( token (`str` or `bool`, *optional*): The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). 
+ save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. kwargs (`Dict[str, Any]`, *optional*): Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -1847,12 +1859,33 @@ def save_pretrained( if self._auto_class is not None: custom_object_save(self, save_directory, config=self.config) + _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False) + # Save the config if is_main_process: - model_to_save.config.save_pretrained(save_directory) + if not _hf_peft_config_loaded: + model_to_save.config.save_pretrained(save_directory) if self.can_generate(): model_to_save.generation_config.save_pretrained(save_directory) + if _hf_peft_config_loaded: + logger.info( + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." + ) + state_dict = model_to_save.get_adapter_state_dict() + + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict + + current_peft_config = self.peft_config[self.active_adapter()] + current_peft_config.save_pretrained(save_directory) + # Save the model if state_dict is None: state_dict = model_to_save.state_dict() @@ -1907,8 +1940,11 @@ def save_pretrained( ) # Shard the model if it is too big. - weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME - weights_name = _add_variant(weights_name, variant) + if not _hf_peft_config_loaded: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + else: + weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) @@ -2295,6 +2331,8 @@ def from_pretrained( subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + _adapter_model_path = kwargs.pop("_adapter_model_path", None) + adapter_name = kwargs.pop("adapter_name", "default") if is_fsdp_enabled(): low_cpu_mem_usage = True @@ -2323,6 +2361,29 @@ def from_pretrained( " ignored." 
) + if is_peft_available() and _adapter_model_path is None: + maybe_adapter_model_path = find_adapter_config_file( + pretrained_model_name_or_path, + revision=revision, + subfolder=subfolder, + token=token, + commit_hash=commit_hash, + ) + elif is_peft_available() and _adapter_model_path is not None: + maybe_adapter_model_path = _adapter_model_path + else: + maybe_adapter_model_path = None + + has_adapter_config = maybe_adapter_model_path is not None + + if has_adapter_config: + if _adapter_model_path is not None: + adapter_model_id = _adapter_model_path + else: + with open(maybe_adapter_model_path, "r", encoding="utf-8") as f: + adapter_model_id = pretrained_model_name_or_path + pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + # change device_map into a map if we passed an int, a str or a torch.device if isinstance(device_map, torch.device): device_map = {"": device_map} @@ -3153,6 +3214,14 @@ def from_pretrained( if quantization_method_from_config == QuantizationMethod.GPTQ: model = quantizer.post_init_model(model) + if has_adapter_config: + model.load_adapter( + adapter_model_id, + adapter_name=adapter_name, + revision=revision, + token=token, + ) + if output_loading_info: if loading_info is None: loading_info = { diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index a5e4fcbb90ee..fc58170b75ed 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -15,13 +15,14 @@ """Factory function to build auto-model classes.""" import copy import importlib +import json import os import warnings from collections import OrderedDict from ...configuration_utils import PretrainedConfig from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code -from ...utils import copy_func, logging, requires_backends +from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings @@ -469,6 +470,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if token is not None: hub_kwargs["token"] = token + if is_peft_available(): + revision = kwargs.get("revision", None) + subfolder = kwargs.get("subfolder", None) + + maybe_adapter_path = find_adapter_config_file( + pretrained_model_name_or_path, + revision=revision, + token=use_auth_token, + subfolder=subfolder, + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + + kwargs["_adapter_model_path"] = pretrained_model_name_or_path + pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] + if not isinstance(config, PretrainedConfig): kwargs_orig = copy.deepcopy(kwargs) # ensure not to pollute the config object with torch_dtype="auto" - since it's diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index b1d6ec4bdac3..746089b4e5cf 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -34,8 +34,10 @@ from ..tokenization_utils import PreTrainedTokenizer from ..utils import ( HUGGINGFACE_CO_RESOLVE_ENDPOINT, + find_adapter_config_file, is_kenlm_available, is_offline_mode, + is_peft_available, is_pyctcdecode_available, is_tf_available, is_torch_available, @@ -721,6 +723,21 @@ def pipeline( config = AutoConfig.from_pretrained(config, _from_pipeline=task, 
**hub_kwargs, **model_kwargs) hub_kwargs["_commit_hash"] = config._commit_hash elif config is None and isinstance(model, str): + # Check for an adapter file in the model path if PEFT is available + if is_peft_available(): + subfolder = hub_kwargs.get("subfolder", None) + maybe_adapter_path = find_adapter_config_file( + model, + revision=revision, + token=use_auth_token, + subfolder=subfolder, + ) + + if maybe_adapter_path is not None: + with open(maybe_adapter_path, "r", encoding="utf-8") as f: + adapter_config = json.load(f) + model = adapter_config["base_model_name_or_path"] + config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs) hub_kwargs["_commit_hash"] = config._commit_hash diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 18d5880a1729..9a59327af61f 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -69,6 +69,7 @@ is_onnx_available, is_optimum_available, is_pandas_available, + is_peft_available, is_phonemizer_available, is_pyctcdecode_available, is_pytesseract_available, @@ -369,6 +370,16 @@ def require_torch(test_case): return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) +def require_peft(test_case): + """ + Decorator marking a test that requires PEFT. + + These tests are skipped when PEFT isn't installed. + + """ + return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case) + + def require_torchvision(test_case): """ Decorator marking a test that requires Torchvision. diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4f4982553687..85d4fd5a5252 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -396,7 +396,7 @@ def __init__( ) # At this stage the model is already loaded - if getattr(model, "is_quantized", False): + if getattr(model, "is_quantized", False) and not getattr(model, "_hf_peft_config_loaded", False): if getattr(model, "_is_quantized_training_enabled", False): logger.info( "The model is quantized. To train this model you need to add additional modules" diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 83b0128fbc58..422464ce5bd0 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -179,13 +179,17 @@ requires_backends, torch_only_method, ) +from .peft_utils import ( + ADAPTER_CONFIG_NAME, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + check_peft_version, + find_adapter_config_file, +) WEIGHTS_NAME = "pytorch_model.bin" WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" -ADAPTER_CONFIG_NAME = "adapter_config.json" -ADAPTER_WEIGHTS_NAME = "adapter_model.bin" -ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" TF2_WEIGHTS_NAME = "tf_model.h5" TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" TF_WEIGHTS_NAME = "model.ckpt" diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 54ed4030a2b6..8172ba23401f 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1001,6 +1001,11 @@ def is_jieba_available(): jieba`. Please note that you may need to restart your runtime after installation. """ +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install +peft`. Please note that you may need to restart your runtime after installation. 
+""" + BACKENDS_MAPPING = OrderedDict( [ ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), @@ -1034,6 +1039,7 @@ def is_jieba_available(): ("decord", (is_decord_available, DECORD_IMPORT_ERROR)), ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)), ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), ] ) diff --git a/src/transformers/utils/peft_utils.py b/src/transformers/utils/peft_utils.py new file mode 100644 index 000000000000..1dd8e14dbb86 --- /dev/null +++ b/src/transformers/utils/peft_utils.py @@ -0,0 +1,98 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import os +from typing import Optional + +from packaging import version + +from .hub import cached_file +from .import_utils import is_peft_available + + +ADAPTER_CONFIG_NAME = "adapter_config.json" +ADAPTER_WEIGHTS_NAME = "adapter_model.bin" +ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors" + + +def find_adapter_config_file( + model_id: str, + revision: str = None, + subfolder: str = None, + token: Optional[str] = None, + commit_hash: Optional[str] = None, +) -> Optional[str]: + r""" + Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path the the adapter + config file if it is, None otherwise. + + Args: + model_id (`str`): + The identifier of the model to look for, can be either a local path or an id to the repository on the Hub. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". + + + + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + token (`str`, `optional`): + Whether to use authentication token to load the remote folder. Userful to load private repositories that + are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to cache it. + """ + adapter_cached_filename = None + if os.path.isdir(model_id): + list_remote_files = os.listdir(model_id) + if ADAPTER_CONFIG_NAME in list_remote_files: + adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME) + else: + adapter_cached_filename = cached_file( + model_id, + ADAPTER_CONFIG_NAME, + revision=revision, + token=token, + _commit_hash=commit_hash, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + + return adapter_cached_filename + + +def check_peft_version(min_version: str) -> None: + r""" + Checks if the version of PEFT is compatible. + + Args: + version (`str`): + The version of PEFT to check against. 
+ """ + if not is_peft_available(): + raise ValueError("PEFT is not installed. Please install it with `pip install peft`") + + is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) <= version.parse(min_version) + + if not is_peft_version_compatible: + raise ValueError( + f"The version of PEFT you are using is not compatible, please use a version that is greater" + f" than {min_version}" + ) diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py new file mode 100644 index 000000000000..b80912607c02 --- /dev/null +++ b/tests/peft_integration/test_peft_integration.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +import unittest + +from transformers import AutoModelForCausalLM, OPTForCausalLM +from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + +@require_peft +@require_torch +class PeftTesterMixin: + peft_test_model_ids = ("peft-internal-testing/tiny-OPTForCausalLM-lora",) + transformers_test_model_ids = ("hf-internal-testing/tiny-random-OPTForCausalLM",) + transformers_test_model_classes = (AutoModelForCausalLM, OPTForCausalLM) + + +# TODO: run it with CI after PEFT release. +@slow +class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): + """ + A testing suite that makes sure that the PeftModel class is correctly integrated into the transformers library. + """ + + def _check_lora_correctly_converted(self, model): + """ + Utility method to check if the model has correctly adapters injected on it. + """ + from peft.tuners.tuners_utils import BaseTunerLayer + + is_peft_loaded = False + + for _, m in model.named_modules(): + if isinstance(m, BaseTunerLayer): + is_peft_loaded = True + break + + return is_peft_loaded + + def test_peft_from_pretrained(self): + """ + Simple test that tests the basic usage of PEFT model through `from_pretrained`. + This checks if we pass a remote folder that contains an adapter config and adapter weights, it + should correctly load a model that has adapters injected on it. + """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + self.assertTrue(self._check_lora_correctly_converted(peft_model)) + self.assertTrue(peft_model._hf_peft_config_loaded) + # dummy generation + _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)) + + def test_peft_state_dict(self): + """ + Simple test that checks if the returned state dict of `get_adapter_state_dict()` method contains + the expected keys. 
+ """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + state_dict = peft_model.get_adapter_state_dict() + + for key in state_dict.keys(): + self.assertTrue("lora" in key) + + def test_peft_save_pretrained(self): + """ + Test that checks various combinations of `save_pretrained` with a model that has adapters loaded + on it. This checks if the saved model contains the expected files (adapter weights and adapter config). + """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + with tempfile.TemporaryDirectory() as tmpdirname: + peft_model.save_pretrained(tmpdirname) + + self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname)) + self.assertTrue("adapter_config.json" in os.listdir(tmpdirname)) + + self.assertTrue("config.json" not in os.listdir(tmpdirname)) + self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname)) + + peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device) + self.assertTrue(self._check_lora_correctly_converted(peft_model)) + + peft_model.save_pretrained(tmpdirname, safe_serialization=True) + self.assertTrue("adapter_model.safetensors" in os.listdir(tmpdirname)) + self.assertTrue("adapter_config.json" in os.listdir(tmpdirname)) + + peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device) + self.assertTrue(self._check_lora_correctly_converted(peft_model)) + + def test_peft_enable_disable_adapters(self): + """ + A test that checks if `enable_adapters` and `disable_adapters` methods work as expected. + """ + from peft import LoraConfig + + dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device) + + for model_id in self.transformers_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id).to(torch_device) + + peft_config = LoraConfig(init_lora_weights=False) + + peft_model.add_adapter(peft_config) + + peft_logits = peft_model(dummy_input).logits + + peft_model.disable_adapters() + + peft_logits_disabled = peft_model(dummy_input).logits + + peft_model.enable_adapters() + + peft_logits_enabled = peft_model(dummy_input).logits + + self.assertTrue(torch.allclose(peft_logits, peft_logits_enabled, atol=1e-12, rtol=1e-12)) + self.assertFalse(torch.allclose(peft_logits_enabled, peft_logits_disabled, atol=1e-12, rtol=1e-12)) + + def test_peft_add_adapter(self): + """ + Simple test that tests if `add_adapter` works as expected + """ + from peft import LoraConfig + + for model_id in self.transformers_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + model = transformers_class.from_pretrained(model_id).to(torch_device) + + peft_config = LoraConfig(init_lora_weights=False) + + model.add_adapter(peft_config) + + self.assertTrue(self._check_lora_correctly_converted(model)) + # dummy generation + _ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)) + + def test_peft_add_multi_adapter(self): + """ + Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if + add_adapter works as expected in multi-adapter setting. 
+ """ + from peft import LoraConfig + from peft.tuners.tuners_utils import BaseTunerLayer + + dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device) + + for model_id in self.transformers_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + is_peft_loaded = False + model = transformers_class.from_pretrained(model_id).to(torch_device) + + logits_original_model = model(dummy_input).logits + + peft_config = LoraConfig(init_lora_weights=False) + + model.add_adapter(peft_config) + + logits_adapter_1 = model(dummy_input) + + model.add_adapter(peft_config, adapter_name="adapter-2") + + logits_adapter_2 = model(dummy_input) + + for _, m in model.named_modules(): + if isinstance(m, BaseTunerLayer): + is_peft_loaded = True + break + + self.assertTrue(is_peft_loaded) + + # dummy generation + _ = model.generate(input_ids=dummy_input) + + model.set_adapter("default") + self.assertTrue(model.active_adapter() == "default") + + model.set_adapter("adapter-2") + self.assertTrue(model.active_adapter() == "adapter-2") + + # Logits comparison + self.assertFalse( + torch.allclose(logits_adapter_1.logits, logits_adapter_2.logits, atol=1e-6, rtol=1e-6) + ) + self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)) + + @require_torch_gpu + def test_peft_from_pretrained_kwargs(self): + """ + Simple test that tests the basic usage of PEFT model through `from_pretrained` + additional kwargs + and see if the integraiton behaves as expected. + """ + for model_id in self.peft_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto") + + module = peft_model.model.decoder.layers[0].self_attn.v_proj + self.assertTrue(module.__class__.__name__ == "Linear8bitLt") + self.assertTrue(peft_model.hf_device_map is not None) + + # dummy generation + _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)) + + def test_peft_pipeline(self): + """ + Simple test that tests the basic usage of PEFT model + pipeline + """ + from transformers import pipeline + + for model_id in self.peft_test_model_ids: + pipe = pipeline("text-generation", model_id) + _ = pipe("Hello")
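
Taken together, the mixin, the `from_pretrained` hook, and the `save_pretrained` changes compose into a single workflow. Below is a minimal, illustrative sketch of that workflow, reusing the `facebook/opt-350m` / `ybelkada/opt-350m-lora` checkpoints from the documentation above; the output directory name is an arbitrary placeholder and the exact generations depend on the checkpoints being reachable on the Hub.

```py
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Loading an adapter-only repo resolves the base model from adapter_config.json
# and injects the LoRA weights into it.
model = AutoModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
inputs = tokenizer("Hello", return_tensors="pt")

# Adapters can be switched off and on around generation.
model.disable_adapters()
base_out = model.generate(**inputs)
model.enable_adapters()
adapter_out = model.generate(**inputs)

# A fresh, randomly initialized LoRA adapter can be attached for training.
model.add_adapter(LoraConfig(init_lora_weights=False), adapter_name="new_adapter")
model.set_adapter("new_adapter")

# save_pretrained now writes only the adapter files
# (adapter_model.bin and adapter_config.json).
model.save_pretrained("opt-350m-new-adapter")
```

Loading the saved directory back with `AutoModelForCausalLM.from_pretrained` re-injects the adapter into the base model, as exercised by `test_peft_save_pretrained`.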