diff --git a/docs/source/package_reference/config.mdx b/docs/source/package_reference/config.mdx index 6866685500..47f5a00707 100644 --- a/docs/source/package_reference/config.mdx +++ b/docs/source/package_reference/config.mdx @@ -4,7 +4,7 @@ The configuration classes stores the configuration of a [`PeftModel`], PEFT adap ## PeftConfigMixin -[[autodoc]] utils.config.PeftConfigMixin +[[autodoc]] config.PeftConfigMixin - all ## PeftConfig diff --git a/src/peft/__init__.py b/src/peft/__init__.py index 59d5024946..68092b49a6 100644 --- a/src/peft/__init__.py +++ b/src/peft/__init__.py @@ -28,7 +28,13 @@ AutoPeftModelForQuestionAnswering, AutoPeftModelForFeatureExtraction, ) -from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model +from .mapping import ( + MODEL_TYPE_TO_PEFT_MODEL_MAPPING, + PEFT_TYPE_TO_CONFIG_MAPPING, + get_peft_config, + get_peft_model, + inject_adapter_in_model, +) from .peft_model import ( PeftModel, PeftModelForCausalLM, @@ -58,9 +64,7 @@ ) from .utils import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, - PeftConfig, PeftType, - PromptLearningConfig, TaskType, bloom_model_postprocess_past_key_value, get_peft_model_state_dict, @@ -68,4 +72,6 @@ prepare_model_for_kbit_training, set_peft_model_state_dict, shift_tokens_right, + load_peft_weights, ) +from .config import PeftConfig, PromptLearningConfig diff --git a/src/peft/auto.py b/src/peft/auto.py index 26f69e5cf5..080d329e77 100644 --- a/src/peft/auto.py +++ b/src/peft/auto.py @@ -27,6 +27,7 @@ AutoModelForTokenClassification, ) +from .config import PeftConfig from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING from .peft_model import ( PeftModel, @@ -37,7 +38,6 @@ PeftModelForSequenceClassification, PeftModelForTokenClassification, ) -from .utils import PeftConfig class _BaseAutoPeftModel: diff --git a/src/peft/utils/config.py b/src/peft/config.py similarity index 81% rename from src/peft/utils/config.py rename to src/peft/config.py index f5191f702e..8cb72b5a80 100644 --- a/src/peft/utils/config.py +++ b/src/peft/config.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import inspect import json import os @@ -22,26 +21,7 @@ from huggingface_hub import hf_hub_download from transformers.utils import PushToHubMixin -from .other import CONFIG_NAME - - -class PeftType(str, enum.Enum): - PROMPT_TUNING = "PROMPT_TUNING" - P_TUNING = "P_TUNING" - PREFIX_TUNING = "PREFIX_TUNING" - LORA = "LORA" - ADALORA = "ADALORA" - ADAPTION_PROMPT = "ADAPTION_PROMPT" - IA3 = "IA3" - - -class TaskType(str, enum.Enum): - SEQ_CLS = "SEQ_CLS" - SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" - CAUSAL_LM = "CAUSAL_LM" - TOKEN_CLS = "TOKEN_CLS" - QUESTION_ANS = "QUESTION_ANS" - FEATURE_EXTRACTION = "FEATURE_EXTRACTION" +from .utils import CONFIG_NAME, PeftType, TaskType @dataclass @@ -102,6 +82,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs kwargs (additional keyword arguments, *optional*): Additional keyword arguments passed along to the child class initialization. """ + # Avoid circular dependency .. 
TODO: fix this with a larger refactor + from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING + path = ( os.path.join(pretrained_model_name_or_path, subfolder) if subfolder is not None @@ -122,7 +105,27 @@ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs loaded_attributes = cls.from_json_file(config_file) - config = cls(**class_kwargs) + # TODO: this hack is needed to fix the following issue (on commit 702f937): + # if someone saves a default config and loads it back with `PeftConfig` class it yields to + # not loading the correct config class. + + # from peft import AdaLoraConfig, PeftConfig + # peft_config = AdaLoraConfig() + # print(peft_config) + # >>> AdaLoraConfig(peft_type=, auto_mapping=None, base_model_name_or_path=None, + # revision=None, task_type=None, inference_mode=False, r=8, target_modules=None, lora_alpha=8, lora_dropout=0.0, ... + # + # peft_config.save_pretrained("./test_config") + # peft_config = PeftConfig.from_pretrained("./test_config") + # print(peft_config) + # >>> PeftConfig(peft_type='ADALORA', auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=None, inference_mode=False) + if "peft_type" in loaded_attributes: + peft_type = loaded_attributes["peft_type"] + config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type] + else: + config_cls = cls + + config = config_cls(**class_kwargs) for key, value in loaded_attributes.items(): if hasattr(config, key): @@ -185,6 +188,18 @@ def _get_peft_type( loaded_attributes = cls.from_json_file(config_file) return loaded_attributes["peft_type"] + @property + def is_prompt_learning(self): + r""" + Utility method to check if the configuration is for prompt learning. + """ + return False + + @property + def is_adaption_prompt(self) -> bool: + """Return True if this is an adaption prompt config.""" + return False + @dataclass class PeftConfig(PeftConfigMixin): @@ -227,3 +242,10 @@ class PromptLearningConfig(PeftConfig): ) num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) + + @property + def is_prompt_learning(self): + r""" + Utility method to check if the configuration is for prompt learning. 
+ """ + return True diff --git a/src/peft/mapping.py b/src/peft/mapping.py index 50fea66653..b0732fff4c 100644 --- a/src/peft/mapping.py +++ b/src/peft/mapping.py @@ -17,6 +17,9 @@ from typing import TYPE_CHECKING, Any, Dict +import torch + +from .config import PeftConfig from .peft_model import ( PeftModel, PeftModelForCausalLM, @@ -28,21 +31,22 @@ ) from .tuners import ( AdaLoraConfig, + AdaLoraModel, AdaptionPromptConfig, IA3Config, + IA3Model, LoraConfig, + LoraModel, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig, ) -from .utils import PromptLearningConfig, _prepare_prompt_learning_config +from .utils import _prepare_prompt_learning_config if TYPE_CHECKING: from transformers import PreTrainedModel - from .utils.config import PeftConfig - MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { "SEQ_CLS": PeftModelForSequenceClassification, @@ -63,6 +67,12 @@ "IA3": IA3Config, } +PEFT_TYPE_TO_TUNER_MAPPING = { + "LORA": LoraModel, + "ADALORA": AdaLoraModel, + "IA3": IA3Model, +} + def get_peft_config(config_dict: Dict[str, Any]): """ @@ -89,10 +99,38 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) - if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( - peft_config, PromptLearningConfig - ): + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: return PeftModel(model, peft_config, adapter_name=adapter_name) - if isinstance(peft_config, PromptLearningConfig): + if peft_config.is_prompt_learning: peft_config = _prepare_prompt_learning_config(peft_config, model_config) return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) + + +def inject_adapter_in_model(peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str): + r""" + A simple API to create and inject adapter in-place into a model. Currently the API does not support prompt learning + methods and adaption prompt. Make sure to have the correct `target_names` set in the `peft_config` object. The API + calls `get_peft_model` under the hood but would be restricted only to non-prompt learning methods. + + Args: + peft_config (`PeftConfig`): + Configuration object containing the parameters of the Peft model. + model (`torch.nn.Module`): + The input model where the adapter will be injected. + adapter_name (`str`): + The name of the adapter to be injected. + """ + if peft_config.is_prompt_learning or peft_config.is_adaption_prompt: + raise ValueError("`create_and_replace` does not support prompt learning and adaption prompt yet.") + + if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys(): + raise ValueError( + f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`." + ) + + tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type] + + # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules. 
+ peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name) + + return peft_model.model diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index c606cfa41e..a71059fc81 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -27,8 +27,6 @@ from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules from accelerate.utils import get_balanced_memory from huggingface_hub import hf_hub_download -from huggingface_hub.utils import EntryNotFoundError -from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PreTrainedModel @@ -36,6 +34,7 @@ from transformers.utils import PushToHubMixin from . import __version__ +from .config import PeftConfig from .tuners import ( AdaLoraModel, AdaptionPromptModel, @@ -49,9 +48,7 @@ SAFETENSORS_WEIGHTS_NAME, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, WEIGHTS_NAME, - PeftConfig, PeftType, - PromptLearningConfig, TaskType, _get_batch_size, _prepare_prompt_learning_config, @@ -59,8 +56,8 @@ _set_trainable, add_library_to_model_card, get_peft_model_state_dict, - hub_file_exists, infer_device, + load_peft_weights, set_peft_model_state_dict, shift_tokens_right, ) @@ -109,7 +106,7 @@ def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name self.peft_config = {} self.active_adapter = adapter_name self.peft_type = peft_config.peft_type - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: self.peft_config[adapter_name] = peft_config self.base_model = PEFT_TYPE_TO_MODEL_MAPPING[peft_config.peft_type]( self.base_model, self.peft_config, adapter_name @@ -186,7 +183,7 @@ def save_pretrained( if peft_config.base_model_name_or_path is None: peft_config.base_model_name_or_path = ( self.base_model.__dict__.get("name_or_path", None) - if isinstance(peft_config, PromptLearningConfig) + if peft_config.is_prompt_learning else self.base_model.model.__dict__.get("name_or_path", None) ) inference_mode = peft_config.inference_mode @@ -195,7 +192,7 @@ def save_pretrained( if peft_config.task_type is None: # deal with auto mapping base_model_class = self._get_base_model_class( - is_prompt_tuning=isinstance(peft_config, PromptLearningConfig) + is_prompt_tuning=peft_config.is_prompt_learning, ) parent_library = base_model_class.__module__ @@ -267,7 +264,7 @@ def from_pretrained( ) > 0: remove_hook_from_submodules(model) - if isinstance(config, PromptLearningConfig) and is_trainable: + if config.is_prompt_learning and is_trainable: raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") else: config.inference_mode = not is_trainable @@ -452,7 +449,7 @@ def disable_adapter(self): Disables the adapter module. """ try: - if isinstance(self.peft_config[self.active_adapter], PromptLearningConfig): + if self.peft_config[self.active_adapter].is_prompt_learning: # TODO: consider replacing this patching of methods with a more robust mechanism: setting a flag and # letting the underyling methods deal with it, same as how LoRA does it. 
old_forward = self.forward @@ -463,7 +460,7 @@ def disable_adapter(self): self.base_model.disable_adapter_layers() yield finally: - if isinstance(self.peft_config[self.active_adapter], PromptLearningConfig): + if self.peft_config[self.active_adapter].is_prompt_learning: self.forward = old_forward self.old_prepare_inputs_for_generation = old_prepare_inputs_for_generation else: @@ -473,7 +470,7 @@ def get_base_model(self): """ Returns the base model. """ - return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model + return self.base_model if self.active_peft_config.is_prompt_learning else self.base_model.model def add_adapter(self, adapter_name: str, peft_config: PeftConfig): if peft_config.peft_type != self.peft_type: @@ -482,7 +479,7 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig): f"Found {self.peft_type} and {peft_config.peft_type}." ) self.peft_config[adapter_name] = peft_config - if isinstance(peft_config, PromptLearningConfig): + if peft_config.is_prompt_learning: if hasattr(self.config, "to_dict"): dict_config = self.config.to_dict() else: @@ -490,8 +487,10 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig): peft_config = _prepare_prompt_learning_config(peft_config, dict_config) self._setup_prompt_encoder(adapter_name) - else: + elif peft_config.is_adaption_prompt: self.base_model.add_adapter(adapter_name, peft_config) + else: + self.inject_adapter(self, adapter_name) self.set_additional_trainable_modules(peft_config, adapter_name) @@ -534,54 +533,13 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa model_id, **hf_hub_download_kwargs, ) - if isinstance(peft_config, PromptLearningConfig) and is_trainable: + if peft_config.is_prompt_learning and is_trainable: raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") else: peft_config.inference_mode = not is_trainable self.add_adapter(adapter_name, peft_config) - # load weights if any - path = ( - os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) - if hf_hub_download_kwargs.get("subfolder", None) is not None - else model_id - ) - - if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): - filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) - use_safetensors = True - elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): - filename = os.path.join(path, WEIGHTS_NAME) - use_safetensors = False - else: - has_remote_safetensors_file = hub_file_exists( - model_id, - SAFETENSORS_WEIGHTS_NAME, - revision=hf_hub_download_kwargs.get("revision", None), - repo_type=hf_hub_download_kwargs.get("repo_type", None), - ) - use_safetensors = has_remote_safetensors_file - - if has_remote_safetensors_file: - # Priority 1: load safetensors weights - filename = hf_hub_download( - model_id, - SAFETENSORS_WEIGHTS_NAME, - **hf_hub_download_kwargs, - ) - else: - try: - filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs) - except EntryNotFoundError: - raise ValueError( - f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. " - f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}." 
- ) - - if use_safetensors: - adapters_weights = safe_load_file(filename, device=torch_device) - else: - adapters_weights = torch.load(filename, map_location=torch.device(torch_device)) + adapters_weights = load_peft_weights(model_id, device=torch_device, **hf_hub_download_kwargs) # load the weights into the model load_result = set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name) @@ -621,7 +579,7 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa **dispatch_model_kwargs, ) hook = AlignDevicesHook(io_same_device=True) - if isinstance(self.peft_config[adapter_name], PromptLearningConfig): + if self.peft_config[adapter_name].is_prompt_learning: remove_hook_from_submodules(self.prompt_encoder) add_hook_to_module(self.get_base_model(), hook) @@ -637,7 +595,7 @@ def set_adapter(self, adapter_name: str): if adapter_name not in self.peft_config: raise ValueError(f"Adapter {adapter_name} not found.") self.active_adapter = adapter_name - if not isinstance(self.peft_config[adapter_name], PromptLearningConfig): + if not self.peft_config[adapter_name].is_prompt_learning: self.base_model.set_adapter(adapter_name) _set_adapter(self, adapter_name) @@ -758,7 +716,7 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict peft_config = self.active_peft_config - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: return self.base_model( input_ids=input_ids, attention_mask=attention_mask, @@ -931,7 +889,7 @@ def forward( **kwargs, ): peft_config = self.active_peft_config - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: if self.base_model.config.model_type == "mpt": if inputs_embeds is not None: raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds") @@ -1013,7 +971,7 @@ def generate(self, **kwargs): def prepare_inputs_for_generation(self, *args, **kwargs): peft_config = self.active_peft_config model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs) - if isinstance(peft_config, PromptLearningConfig): + if peft_config.is_prompt_learning: if model_kwargs.get("attention_mask", None) is not None: prefix_attention_mask = torch.ones( model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens @@ -1104,7 +1062,7 @@ def forward( **kwargs, ): peft_config = self.active_peft_config - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: return self.base_model( input_ids=input_ids, attention_mask=attention_mask, @@ -1216,7 +1174,7 @@ def generate(self, **kwargs): self._prepare_encoder_decoder_kwargs_for_generation ) try: - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: outputs = self.base_model.generate(**kwargs) else: if "input_ids" not in kwargs: @@ -1354,7 +1312,7 @@ def forward( peft_config = self.active_peft_config return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: return self.base_model( input_ids=input_ids, attention_mask=attention_mask, @@ -1527,7 +1485,7 @@ def forward( peft_config = self.active_peft_config return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: return self.base_model( input_ids=input_ids, 
attention_mask=attention_mask, @@ -1698,7 +1656,7 @@ def forward( **kwargs, ): peft_config = self.active_peft_config - if not isinstance(peft_config, PromptLearningConfig): + if not peft_config.is_prompt_learning: return self.base_model( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index eeaa81d607..0c661feedd 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -24,3 +24,10 @@ from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType from .prefix_tuning import PrefixEncoder, PrefixTuningConfig from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit + +# Mapping of tuners that support direct plugging +TUNERS_MAPPING = { + "LORA": LoraModel, + "IA3": IA3Model, + "ADALORA": AdaLoraModel, +} diff --git a/src/peft/tuners/adalora.py b/src/peft/tuners/adalora.py index f31c33fc30..b494373030 100644 --- a/src/peft/tuners/adalora.py +++ b/src/peft/tuners/adalora.py @@ -1,4 +1,3 @@ -import re import warnings from dataclasses import dataclass, field from typing import Optional @@ -20,7 +19,6 @@ LoraConfig, LoraLayer, LoraModel, - mark_only_lora_as_trainable, ) @@ -61,6 +59,50 @@ def __post_init__(self): self.peft_type = PeftType.ADALORA +class AdaLoraLayer(LoraLayer): + def __init__( + self, + in_features: int, + out_features: int, + ): + super().__init__(in_features, out_features) + self.lora_E = nn.ParameterDict({}) + self.lora_A = nn.ParameterDict({}) + self.lora_B = nn.ParameterDict({}) + self.ranknum = nn.ParameterDict({}) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + # Right singular vectors + self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, self.in_features))})) + # Singular values + self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, 1))})) + # Left singular vectors + self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(self.out_features, r))})) + # The current rank + self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(1), requires_grad=False)})) + self.ranknum[adapter_name].data.fill_(float(r)) + self.ranknum[adapter_name].requires_grad = False + self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r) + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + + def reset_lora_parameters(self, adapter_name): + if adapter_name in self.lora_A.keys(): + nn.init.normal_(self.lora_E[adapter_name], mean=0.0, std=0.02) + nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02) + nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02) + + class AdaLoraModel(LoraModel): """ Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. 
Paper: @@ -88,17 +130,8 @@ class AdaLoraModel(LoraModel): """ def __init__(self, model, config, adapter_name): - nn.Module.__init__(self) - self.model = model - self.peft_config = config - self.add_adapter(adapter_name, self.peft_config[adapter_name]) - - def add_adapter(self, adapter_name, config=None): - if config is not None: - model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config - config = self._prepare_adalora_config(config, model_config) - self.peft_config[adapter_name] = config - self._find_and_replace(adapter_name) + super().__init__(model, config, adapter_name) + if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": raise ValueError( "AdaLoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." @@ -114,111 +147,118 @@ def add_adapter(self, adapter_name, config=None): "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train." ) - mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) if self.peft_config[adapter_name].inference_mode: _freeze_adapter(self.model, adapter_name) else: self.trainable_adapter_name = adapter_name self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name) - def _find_and_replace(self, adapter_name): - lora_config = self.peft_config[adapter_name] - loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) - loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + **optionnal_kwargs, + ): + loaded_in_8bit = optionnal_kwargs.get("loaded_in_8bit", False) + loaded_in_4bit = optionnal_kwargs.get("loaded_in_4bit", False) if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available(): raise ImportError( "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " "You can install it with `pip install bitsandbytes`." 
) - is_target_modules_in_base_model = False kwargs = { "r": lora_config.init_r, "lora_alpha": lora_config.lora_alpha, "lora_dropout": lora_config.lora_dropout, "fan_in_fan_out": lora_config.fan_in_fan_out, "init_lora_weights": lora_config.init_lora_weights, + "loaded_in_8bit": loaded_in_8bit, + "loaded_in_4bit": loaded_in_4bit, } - key_list = [key for key, _ in self.model.named_modules()] - for key in key_list: - if isinstance(lora_config.target_modules, str): - target_module_found = re.fullmatch(lora_config.target_modules, key) - else: - target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) - if target_module_found: - if not is_target_modules_in_base_model: - is_target_modules_in_base_model = True - parent, target, target_name = _get_submodules(self.model, key) - bias = hasattr(target, "bias") and target.bias is not None - if isinstance(target, LoraLayer): - target.update_layer( - adapter_name, - lora_config.init_r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - else: - if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): - kwargs.update( - { - "has_fp16_weights": target.state.has_fp16_weights, - "memory_efficient_backward": target.state.memory_efficient_backward, - "threshold": target.state.threshold, - "index": target.index, - } - ) - new_module = SVDLinear8bitLt( - adapter_name, target.in_features, target.out_features, bias=bias, **kwargs - ) - elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): - fourbit_kwargs = kwargs.copy() - fourbit_kwargs.update( - { - "compute_dtype": target.compute_dtype, - "compress_statistics": target.weight.compress_statistics, - "quant_type": target.weight.quant_type, - } - ) - new_module = SVDLinear4bit( - adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs - ) - elif isinstance(target, (nn.ModuleList, nn.ModuleDict)): - # it's not applicable to replace whole module lists or module dicts - continue - else: - if isinstance(target, torch.nn.Linear): - in_features, out_features = target.in_features, target.out_features - if kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " - "Setting fan_in_fan_out to False." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False - elif isinstance(target, Conv1D): - in_features, out_features = ( - target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape - ) - if not kwargs["fan_in_fan_out"]: - warnings.warn( - "fan_in_fan_out is set to False but the target module is `Conv1D`. " - "Setting fan_in_fan_out to True." - ) - kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True - else: - raise ValueError( - f"Target module {target} is not supported. " - f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." - ) - new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs) - - self._replace_module(parent, target_name, new_module, target) - if not is_target_modules_in_base_model: - raise ValueError( - f"Target modules {lora_config.target_modules} not found in the base model. " - f"Please check the target modules and try again." 
+ + # If it is not a LoraLayer, create a new module, else update it with new adapters + if not isinstance(target, AdaLoraLayer): + new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + lora_config.init_r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, ) + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + bias = target.bias is not None + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + + if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): + kwargs.update( + { + "has_fp16_weights": target.state.has_fp16_weights, + "memory_efficient_backward": target.state.memory_efficient_backward, + "threshold": target.state.threshold, + "index": target.index, + } + ) + new_module = SVDLinear8bitLt(adapter_name, target.in_features, target.out_features, bias=bias, **kwargs) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target.compute_dtype, + "compress_statistics": target.weight.compress_statistics, + "quant_type": target.weight.quant_type, + } + ) + new_module = SVDLinear4bit( + adapter_name, target.in_features, target.out_features, bias=bias, **fourbit_kwargs + ) + else: + if isinstance(target, torch.nn.Linear): + in_features, out_features = target.in_features, target.out_features + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + elif isinstance(target, Conv1D): + in_features, out_features = ( + target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape + ) + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." 
+ ) + new_module = SVDLinear(adapter_name, in_features, out_features, bias=bias, **kwargs) + + return new_module + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] + return peft_config + def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" try: @@ -321,60 +361,6 @@ def update_and_allocate(self, global_step): else: return None - @staticmethod - def _prepare_adalora_config(peft_config, model_config): - if peft_config.target_modules is None: - if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: - raise ValueError("Please specify `target_modules` in `peft_config`") - peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[ - model_config["model_type"] - ] - return peft_config - - -class AdaLoraLayer(LoraLayer): - def __init__( - self, - in_features: int, - out_features: int, - ): - super().__init__(in_features, out_features) - self.lora_E = nn.ParameterDict({}) - self.lora_A = nn.ParameterDict({}) - self.lora_B = nn.ParameterDict({}) - self.ranknum = nn.ParameterDict({}) - - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) - # Actual trainable parameters - # Right singular vectors - self.lora_A.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, self.in_features))})) - # Singular values - self.lora_E.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(r, 1))})) - # Left singular vectors - self.lora_B.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.randn(self.out_features, r))})) - # The current rank - self.ranknum.update(nn.ParameterDict({adapter_name: nn.Parameter(torch.zeros(1), requires_grad=False)})) - self.ranknum[adapter_name].data.fill_(float(r)) - self.ranknum[adapter_name].requires_grad = False - self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r) - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - self.to(self.weight.device) - - def reset_lora_parameters(self, adapter_name): - if adapter_name in self.lora_A.keys(): - nn.init.zeros_(self.lora_E[adapter_name]) - nn.init.normal_(self.lora_A[adapter_name], mean=0.0, std=0.02) - nn.init.normal_(self.lora_B[adapter_name], mean=0.0, std=0.02) - class SVDLinear(nn.Linear, AdaLoraLayer): # SVD-based adaptation by a dense layer diff --git a/src/peft/tuners/adaption_prompt.py b/src/peft/tuners/adaption_prompt.py index 3ce7d8b7fa..042f5e4c96 100644 --- a/src/peft/tuners/adaption_prompt.py +++ b/src/peft/tuners/adaption_prompt.py @@ -22,8 +22,8 @@ import torch.nn as nn import torch.nn.functional as F -from peft.utils.config import PeftConfig, PeftType -from peft.utils.other import _freeze_adapter, _get_submodules +from ..config import PeftConfig +from ..utils import PeftType, _freeze_adapter, _get_submodules def llama_rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -114,6 +114,11 @@ class AdaptionPromptConfig(PeftConfig): 
def __post_init__(self): self.peft_type = PeftType.ADAPTION_PROMPT + @property + def is_adaption_prompt(self) -> bool: + """Return True if this is an adaption prompt config.""" + return True + def prepare_config( peft_config: AdaptionPromptConfig, diff --git a/src/peft/tuners/ia3.py b/src/peft/tuners/ia3.py index 245fb7f1c0..5843de3449 100644 --- a/src/peft/tuners/ia3.py +++ b/src/peft/tuners/ia3.py @@ -23,17 +23,18 @@ import torch.nn.functional as F from transformers.pytorch_utils import Conv1D +from ..config import PeftConfig from ..import_utils import is_bnb_available from ..utils import ( TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, - PeftConfig, PeftType, - _freeze_adapter, _get_submodules, + _is_valid_match, transpose, ) +from .tuners_utils import BaseTuner, BaseTunerLayer if is_bnb_available(): @@ -91,7 +92,40 @@ def __post_init__(self): self.peft_type = PeftType.IA3 -class IA3Model(torch.nn.Module): +class IA3Layer(BaseTunerLayer): + def __init__( + self, + in_features: int, + out_features: int, + is_feedforward: bool, + ): + self.scaling = {} + self.ia3_l = nn.ParameterDict({}) + # Mark the weight as unmerged + self.merged = False + self.disable_adapters = False + self.in_features = in_features + self.out_features = out_features + self.is_feedforward = is_feedforward + + def update_layer(self, adapter_name, init_ia3_weights): + # Actual trainable parameters + if self.is_feedforward: + weight = torch.randn((1, self.in_features)) + else: + weight = torch.randn((self.out_features, 1)) + self.ia3_l.update(nn.ParameterDict({adapter_name: nn.Parameter(weight)})) + if init_ia3_weights: + self.reset_ia3_parameters(adapter_name) + self.to(self.weight.device) + + def reset_ia3_parameters(self, adapter_name): + if adapter_name in self.ia3_l.keys(): + # initialize learned vector with torch.ones + nn.init.constant_(self.ia3_l[adapter_name], 1.0) + + +class IA3Model(BaseTuner): """ Creates a Infused Adapter by Inhibiting and Amplifying Inner Activations ((IA)^3) model from a pretrained transformers model. The method is described in detail in https://arxiv.org/abs/2205.05638 @@ -126,43 +160,13 @@ class IA3Model(torch.nn.Module): """ def __init__(self, model, config, adapter_name): - super().__init__() - self.model = model - self.forward = self.model.forward - self.peft_config = config - self.add_adapter(adapter_name, self.peft_config[adapter_name]) - - def add_adapter(self, adapter_name, config=None): - if config is not None: - model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config - config = self._prepare_ia3_config(config, model_config) - self.peft_config[adapter_name] = config - self._find_and_replace(adapter_name) - - mark_only_ia3_as_trainable(self.model) - if self.peft_config[adapter_name].inference_mode: - _freeze_adapter(self.model, adapter_name) - - def _check_quantization_dependency(self): - loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) - if loaded_in_4bit: - raise NotImplementedError( - "4-bit quantization is not supported for IA3 yet, 8-bit quantization can be used instead." - ) - loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) - if loaded_in_8bit and not is_bnb_available(): - raise ImportError( - "To use (IA)^3 with 8-bit quantization, please install the `bitsandbytes` package. " - "You can install it with `pip install bitsandbytes`." 
- ) + super().__init__(model, config, adapter_name) - def _create_new_module(self, ia3_config, adapter_name, target, is_feedforward): - kwargs = { - "fan_in_fan_out": ia3_config.fan_in_fan_out, - "init_ia3_weights": ia3_config.init_ia3_weights, - } + @staticmethod + def _create_new_module(ia3_config, adapter_name, target, **kwargs): bias = hasattr(target, "bias") and target.bias is not None - loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + is_feedforward = kwargs.pop("is_feedforward", False) if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() @@ -213,75 +217,67 @@ def _create_new_module(self, ia3_config, adapter_name, target, is_feedforward): ) return new_module - def _check_target_module_exists(self, ia3_config, key): + @staticmethod + def _check_target_module_exists(ia3_config, key): if isinstance(ia3_config.target_modules, str): target_module_found = re.fullmatch(ia3_config.target_modules, key) else: - target_module_found = any( - self._is_valid_match(key, target_key) for target_key in ia3_config.target_modules - ) + target_module_found = any(_is_valid_match(key, target_key) for target_key in ia3_config.target_modules) return target_module_found - def _find_and_replace(self, adapter_name): - ia3_config = self.peft_config[adapter_name] - if not ia3_config.feedforward_modules: - ia3_config.feedforward_modules = [] # convert to list if None - self._check_quantization_dependency() - is_target_modules_in_base_model = False + def _mark_only_adapters_as_trainable(self) -> None: + for n, p in self.model.named_parameters(): + if "ia3_" not in n: + p.requires_grad = False - key_list = [key for key, _ in self.model.named_modules()] - for key in key_list: - if not self._check_target_module_exists(ia3_config, key): - continue - # check if target module is in feedforward_modules - if isinstance(ia3_config.feedforward_modules, str): - is_feedforward = re.fullmatch(ia3_config.feedforward_modules, key) - else: - is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules) + def _create_and_replace( + self, + ia3_config, + adapter_name, + target, + target_name, + parent, + **optionnal_kwargs, + ): + loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"] + current_key = optionnal_kwargs["current_key"] - if not is_target_modules_in_base_model: - is_target_modules_in_base_model = True - parent, target, target_name = _get_submodules(self.model, key) + # check if target module is in feedforward_modules + if isinstance(ia3_config.feedforward_modules, str): + is_feedforward = re.fullmatch(ia3_config.feedforward_modules, current_key) + else: + is_feedforward = any(current_key.endswith(target_key) for target_key in ia3_config.feedforward_modules) - if isinstance(target, IA3Layer): - target.update_layer( - adapter_name, - ia3_config.init_ia3_weights, - ) - else: - new_module = self._create_new_module(ia3_config, adapter_name, target, is_feedforward) - self._replace_module(parent, target_name, new_module, target) - if not is_target_modules_in_base_model: - raise ValueError( - f"Target modules {ia3_config.target_modules} not found in the base model. " - f"Please check the target modules and try again." 
+ kwargs = { + "fan_in_fan_out": ia3_config.fan_in_fan_out, + "init_ia3_weights": ia3_config.init_ia3_weights, + "loaded_in_8bit": loaded_in_8bit, + "is_feedforward": is_feedforward, + } + + if isinstance(target, IA3Layer): + target.update_layer( + adapter_name, + ia3_config.init_ia3_weights, ) + else: + new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) @staticmethod - def _is_valid_match(key: str, target_key: str): - """ - Helper function to match module names target_key and key. Makes sure that either the key is exactly the - target_key or the target_key is a submodule of key - """ - if key.endswith(target_key): - if len(key) > len(target_key): - return key.endswith("." + target_key) # must be a sub module - return True - return False - - def _replace_module(self, parent_module, child_name, new_module, old_module): - setattr(parent_module, child_name, new_module) - new_module.weight = old_module.weight - if old_module.bias is not None: - new_module.bias = old_module.bias - if getattr(old_module, "state", None) is not None: - new_module.state = old_module.state - new_module.to(old_module.weight.device) + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + new_module.weight = child.weight + if child.bias is not None: + new_module.bias = child.bias + if getattr(child, "state", None) is not None: + new_module.state = child.state + new_module.to(child.weight.device) # dispatch to correct device for name, module in new_module.named_modules(): if "ia3_" in name: - module.to(old_module.weight.device) + module.to(child.weight.device) def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" @@ -320,8 +316,7 @@ def set_adapter(self, adapter_name): module.unmerge() module.active_adapter = adapter_name - @staticmethod - def _prepare_ia3_config(peft_config, model_config): + def _prepare_adapter_config(self, peft_config, model_config): if peft_config.target_modules is None: if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING: raise ValueError("Please specify `target_modules` in `peft_config`") @@ -374,45 +369,6 @@ def merge_and_unload(self): # ------------------------------------------------------------------------------------------ -def mark_only_ia3_as_trainable(model: nn.Module) -> None: - for n, p in model.named_parameters(): - if "ia3_" not in n: - p.requires_grad = False - - -class IA3Layer: - def __init__( - self, - in_features: int, - out_features: int, - is_feedforward: bool, - ): - self.scaling = {} - self.ia3_l = nn.ParameterDict({}) - # Mark the weight as unmerged - self.merged = False - self.disable_adapters = False - self.in_features = in_features - self.out_features = out_features - self.is_feedforward = is_feedforward - - def update_layer(self, adapter_name, init_ia3_weights): - # Actual trainable parameters - if self.is_feedforward: - weight = torch.randn((1, self.in_features)) - else: - weight = torch.randn((self.out_features, 1)) - self.ia3_l.update(nn.ParameterDict({adapter_name: nn.Parameter(weight)})) - if init_ia3_weights: - self.reset_ia3_parameters(adapter_name) - self.to(self.weight.device) - - def reset_ia3_parameters(self, adapter_name): - if adapter_name in self.ia3_l.keys(): - # initialize learned vector with torch.ones - nn.init.constant_(self.ia3_l[adapter_name], 1.0) - - class Linear(nn.Linear, IA3Layer): # (IA)^3 implemented in a dense layer def __init__( 
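Both `IA3Model` above and `LoraModel` below now derive from `BaseTuner` and only override per-tuner hooks (`_check_target_module_exists`, `_create_and_replace`, `_replace_module`, `_mark_only_adapters_as_trainable`, `_prepare_adapter_config`); the module-walking loop that each `_find_and_replace` used to duplicate moves into the shared `inject_adapter` machinery in the new `src/peft/tuners/tuners_utils.py`, whose contents are not reproduced in this diff. The sketch below is only an illustration of that shared loop, reconstructed from the removed `_find_and_replace` bodies and the hook signatures visible here; it is not the actual `BaseTuner` implementation.

```python
# Illustrative sketch of the shared injection loop that the new BaseTuner
# (src/peft/tuners/tuners_utils.py, not shown in this diff) is expected to
# provide. Hook names and signatures are taken from the subclass overrides
# in this patch; everything else here is an assumption.
from peft.utils import _get_submodules


def inject_adapter_sketch(tuner, model, adapter_name):
    peft_config = tuner.peft_config[adapter_name]
    is_target_modules_in_base_model = False

    for key in [key for key, _ in model.named_modules()]:
        if not tuner._check_target_module_exists(peft_config, key):
            continue
        is_target_modules_in_base_model = True
        parent, target, target_name = _get_submodules(model, key)
        # Each tuner (LoraModel, AdaLoraModel, IA3Model) decides how to wrap
        # or update the matched module via its own _create_and_replace hook.
        tuner._create_and_replace(
            peft_config,
            adapter_name,
            target,
            target_name,
            parent,
            current_key=key,
            loaded_in_8bit=getattr(model, "is_loaded_in_8bit", False),
            loaded_in_4bit=getattr(model, "is_loaded_in_4bit", False),
        )

    if not is_target_modules_in_base_model:
        raise ValueError(
            f"Target modules {peft_config.target_modules} not found in the base model. "
            f"Please check the target modules and try again."
        )
    tuner._mark_only_adapters_as_trainable()
```

Keeping the loop in one place means target matching and the 8-bit/4-bit bookkeeping no longer have to be repeated per tuner; each subclass only implements how a matched module is wrapped or updated.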
diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index f83816cddb..7c78993e66 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -25,18 +25,19 @@ from tqdm import tqdm from transformers.pytorch_utils import Conv1D +from ..config import PeftConfig from ..import_utils import is_bnb_4bit_available, is_bnb_available from ..utils import ( CLAMP_QUANTILE, COMMON_LAYERS_PATTERN, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, - PeftConfig, PeftType, _freeze_adapter, _get_submodules, transpose, ) +from .tuners_utils import BaseTuner, BaseTunerLayer if is_bnb_available(): @@ -119,7 +120,99 @@ def __post_init__(self): self.peft_type = PeftType.LORA -class LoraModel(torch.nn.Module): +class LoraLayer(BaseTunerLayer): + def __init__(self, in_features: int, out_features: int, **kwargs): + self.r = {} + self.lora_alpha = {} + self.scaling = {} + self.lora_dropout = nn.ModuleDict({}) + self.lora_A = nn.ModuleDict({}) + self.lora_B = nn.ModuleDict({}) + # For Embedding layer + self.lora_embedding_A = nn.ParameterDict({}) + self.lora_embedding_B = nn.ParameterDict({}) + # Mark the weight as unmerged + self.merged = False + self.disable_adapters = False + self.in_features = in_features + self.out_features = out_features + self.kwargs = kwargs + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + if r > 0: + self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)})) + self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)})) + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + + def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + if r > 0: + kernel_size = self.kwargs["kernel_size"] + stride = self.kwargs["stride"] + padding = self.kwargs["padding"] + self.lora_A.update( + nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)}) + ) + self.lora_B.update( + nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)}) + ) + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + + def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) + # Actual trainable parameters + if r > 0: + weight_A = torch.randn((r, self.in_features), dtype=self.weight.dtype, device=self.weight.device) + weight_B = torch.randn((self.out_features, r), dtype=self.weight.dtype, 
device=self.weight.device) + self.lora_embedding_A.update(nn.ParameterDict({adapter_name: nn.Parameter(weight_A)})) + self.lora_embedding_B.update(nn.ParameterDict({adapter_name: nn.Parameter(weight_B)})) + self.scaling[adapter_name] = lora_alpha / r + if init_lora_weights: + self.reset_lora_parameters(adapter_name) + self.to(self.weight.device) + + def reset_lora_parameters(self, adapter_name): + if adapter_name in self.lora_A.keys(): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B[adapter_name].weight) + if adapter_name in self.lora_embedding_A.keys(): + # initialize a the same way as the default for nn.linear and b to zero + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + +class LoraModel(BaseTuner): """ Creates Low Rank Adapter (Lora) model from a pretrained transformers model. @@ -176,47 +269,16 @@ class LoraModel(torch.nn.Module): """ def __init__(self, model, config, adapter_name): - super().__init__() - self.model = model - self.forward = self.model.forward - self.peft_config = config - self.add_adapter(adapter_name, self.peft_config[adapter_name]) - - # transformers models have a .config attribute, whose presence is assumed later on - if not hasattr(self, "config"): - self.config = {"model_type": "custom"} - - def add_adapter(self, adapter_name, config=None): - if config is not None: - model_config = getattr(self.model, "config", {"model_type": "custom"}) - if hasattr(model_config, "to_dict"): - model_config = model_config.to_dict() - - config = self._prepare_lora_config(config, model_config) - self.peft_config[adapter_name] = config - self._find_and_replace(adapter_name) - if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none": - raise ValueError( - "LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters." - ) - mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) - if self.peft_config[adapter_name].inference_mode: - _freeze_adapter(self.model, adapter_name) - - def _check_quantization_dependency(self): - loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) - loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) - if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available(): - raise ImportError( - "To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. " - "You can install it with `pip install bitsandbytes`." 
- ) + super().__init__(model, config, adapter_name) - def _check_target_module_exists(self, lora_config, key): + @staticmethod + def _check_target_module_exists(lora_config, key): if isinstance(lora_config.target_modules, str): target_module_found = re.fullmatch(lora_config.target_modules, key) else: - target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules) + target_module_found = any( + re.match(f".*\.{target_key}$", key) for target_key in lora_config.target_modules + ) or any(target_key == key for target_key in lora_config.target_modules) is_using_layer_indexes = getattr(lora_config, "layers_to_transform", None) is not None layer_indexing_pattern = getattr(lora_config, "layers_pattern", None) @@ -238,7 +300,15 @@ def _check_target_module_exists(self, lora_config, key): target_module_found = False return target_module_found - def _create_new_module(self, lora_config, adapter_name, target): + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + **optionnal_kwargs, + ): bias = hasattr(target, "bias") and target.bias is not None kwargs = { "r": lora_config.r, @@ -247,8 +317,85 @@ def _create_new_module(self, lora_config, adapter_name, target): "fan_in_fan_out": lora_config.fan_in_fan_out, "init_lora_weights": lora_config.init_lora_weights, } - loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False) - loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False) + + kwargs["loaded_in_8bit"] = optionnal_kwargs.pop("loaded_in_8bit", False) + kwargs["loaded_in_4bit"] = optionnal_kwargs.pop("loaded_in_4bit", False) + kwargs["bias"] = bias + + # TODO: better deal with that + if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): + target.update_layer_conv2d( + adapter_name, + lora_config.r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding): + target.update_layer_embedding( + adapter_name, + lora_config.r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + + elif isinstance(target, LoraLayer): + target.update_layer( + adapter_name, + lora_config.r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + else: + new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + self._replace_module(parent, target_name, new_module, target) + + @staticmethod + def _replace_module(parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + new_module.weight = child.weight + if hasattr(child, "bias"): + if child.bias is not None: + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if "lora_" in name: + module.to(child.weight.device) + if "ranknum" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self) -> None: + active_adapter = self._get_active_adapter() + bias = self.peft_config[active_adapter].bias + + for n, p in self.model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False + if bias == "none": + return + elif bias == "all": + for n, p in self.model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in self.model.modules(): + if isinstance(m, LoraLayer) and 
hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + bias = kwargs.pop("bias", False) if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt): eightbit_kwargs = kwargs.copy() @@ -313,72 +460,6 @@ def _create_new_module(self, lora_config, adapter_name, target): return new_module - def _find_and_replace(self, adapter_name): - lora_config = self.peft_config[adapter_name] - self._check_quantization_dependency() - is_target_modules_in_base_model = False - key_list = [key for key, _ in self.model.named_modules()] - - for key in key_list: - if not self._check_target_module_exists(lora_config, key): - continue - - is_target_modules_in_base_model = True - parent, target, target_name = _get_submodules(self.model, key) - - if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): - target.update_layer_conv2d( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding): - target.update_layer_embedding( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - - elif isinstance(target, LoraLayer): - target.update_layer( - adapter_name, - lora_config.r, - lora_config.lora_alpha, - lora_config.lora_dropout, - lora_config.init_lora_weights, - ) - else: - new_module = self._create_new_module(lora_config, adapter_name, target) - self._replace_module(parent, target_name, new_module, target) - - if not is_target_modules_in_base_model: - raise ValueError( - f"Target modules {lora_config.target_modules} not found in the base model. " - f"Please check the target modules and try again." - ) - - def _replace_module(self, parent_module, child_name, new_module, old_module): - setattr(parent_module, child_name, new_module) - new_module.weight = old_module.weight - if hasattr(old_module, "bias"): - if old_module.bias is not None: - new_module.bias = old_module.bias - - if getattr(old_module, "state", None) is not None: - new_module.state = old_module.state - new_module.to(old_module.weight.device) - - # dispatch to correct device - for name, module in new_module.named_modules(): - if "lora_" in name: - module.to(old_module.weight.device) - if "ranknum" in name: - module.to(old_module.weight.device) - def __getattr__(self, name: str): """Forward missing attributes to the wrapped module.""" try: @@ -436,24 +517,8 @@ def set_adapter(self, adapter_name): module.unmerge() module.active_adapter = adapter_name - def merge_adapter(self): - """ - This method merges the LoRa layers into the base model. - """ - for module in self.model.modules(): - if isinstance(module, LoraLayer): - module.merge() - - def unmerge_adapter(self): - """ - This method unmerges the LoRa layers from the base model. 
- """ - for module in self.model.modules(): - if isinstance(module, LoraLayer): - module.unmerge() - @staticmethod - def _prepare_lora_config(peft_config, model_config): + def _prepare_adapter_config(peft_config, model_config): if peft_config.target_modules is None: if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: raise ValueError("Please specify `target_modules` in `peft_config`") @@ -509,6 +574,7 @@ def add_weighted_adapter(self, adapters, weights, adapter_name, combination_type adapter_name (str): Name of the new adapter. combination_type (str): Type of merging. Can be one of [`svd`, `linear`] """ + if adapter_name in list(self.peft_config.keys()): return for adapter in adapters: @@ -530,9 +596,11 @@ def add_weighted_adapter(self, adapters, weights, adapter_name, combination_type raise ValueError(f"Invalid combination_type: {combination_type}") self.peft_config[adapter_name] = replace(self.peft_config[adapters[0]], r=new_rank, lora_alpha=new_rank) - self._find_and_replace(adapter_name) - mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? _freeze_adapter(self.model, adapter_name) + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] for key in key_list: _, target, _ = _get_submodules(self.model, key) @@ -664,117 +732,6 @@ def unload(self): # ------------------------------------------------------------------------------------------ -# had to adapt it for `lora_only` to work -def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: - for n, p in model.named_parameters(): - if "lora_" not in n: - p.requires_grad = False - if bias == "none": - return - elif bias == "all": - for n, p in model.named_parameters(): - if "bias" in n: - p.requires_grad = True - elif bias == "lora_only": - for m in model.modules(): - if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: - m.bias.requires_grad = True - else: - raise NotImplementedError - - -class LoraLayer: - def __init__(self, in_features: int, out_features: int, **kwargs): - self.r = {} - self.lora_alpha = {} - self.scaling = {} - self.lora_dropout = nn.ModuleDict({}) - self.lora_A = nn.ModuleDict({}) - self.lora_B = nn.ModuleDict({}) - # For Embedding layer - self.lora_embedding_A = nn.ParameterDict({}) - self.lora_embedding_B = nn.ParameterDict({}) - # Mark the weight as unmerged - self.merged = False - self.disable_adapters = False - self.in_features = in_features - self.out_features = out_features - self.kwargs = kwargs - - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) - # Actual trainable parameters - if r > 0: - self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)})) - self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)})) - self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - self.to(self.weight.device) - - def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 
0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) - # Actual trainable parameters - if r > 0: - kernel_size = self.kwargs["kernel_size"] - stride = self.kwargs["stride"] - padding = self.kwargs["padding"] - self.lora_A.update( - nn.ModuleDict({adapter_name: nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False)}) - ) - self.lora_B.update( - nn.ModuleDict({adapter_name: nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False)}) - ) - self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - self.to(self.weight.device) - - def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): - self.r[adapter_name] = r - self.lora_alpha[adapter_name] = lora_alpha - if lora_dropout > 0.0: - lora_dropout_layer = nn.Dropout(p=lora_dropout) - else: - lora_dropout_layer = nn.Identity() - - self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer})) - # Actual trainable parameters - if r > 0: - weight_A = torch.randn((r, self.in_features), dtype=self.weight.dtype, device=self.weight.device) - weight_B = torch.randn((self.out_features, r), dtype=self.weight.dtype, device=self.weight.device) - self.lora_embedding_A.update(nn.ParameterDict({adapter_name: nn.Parameter(weight_A)})) - self.lora_embedding_B.update(nn.ParameterDict({adapter_name: nn.Parameter(weight_B)})) - self.scaling[adapter_name] = lora_alpha / r - if init_lora_weights: - self.reset_lora_parameters(adapter_name) - self.to(self.weight.device) - - def reset_lora_parameters(self, adapter_name): - if adapter_name in self.lora_A.keys(): - # initialize A the same way as the default for nn.Linear and B to zero - nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5)) - nn.init.zeros_(self.lora_B[adapter_name].weight) - if adapter_name in self.lora_embedding_A.keys(): - # initialize a the same way as the default for nn.linear and b to zero - nn.init.zeros_(self.lora_embedding_A[adapter_name]) - nn.init.normal_(self.lora_embedding_B[adapter_name]) - - class Linear(nn.Linear, LoraLayer): # Lora implemented in a dense layer def __init__( diff --git a/src/peft/tuners/p_tuning.py b/src/peft/tuners/p_tuning.py index de52c2fae9..b142a13f65 100644 --- a/src/peft/tuners/p_tuning.py +++ b/src/peft/tuners/p_tuning.py @@ -20,7 +20,8 @@ import torch -from ..utils import PeftType, PromptLearningConfig +from ..config import PromptLearningConfig +from ..utils import PeftType class PromptEncoderReparameterizationType(str, enum.Enum): diff --git a/src/peft/tuners/prefix_tuning.py b/src/peft/tuners/prefix_tuning.py index d18000e0fb..e5212dd59b 100644 --- a/src/peft/tuners/prefix_tuning.py +++ b/src/peft/tuners/prefix_tuning.py @@ -18,7 +18,8 @@ import torch -from ..utils import PeftType, PromptLearningConfig +from ..config import PromptLearningConfig +from ..utils import PeftType @dataclass diff --git a/src/peft/tuners/prompt_tuning.py b/src/peft/tuners/prompt_tuning.py index 6880ff7d0c..0b834c26bf 100644 --- a/src/peft/tuners/prompt_tuning.py +++ b/src/peft/tuners/prompt_tuning.py @@ -20,7 +20,8 @@ import torch -from ..utils import PeftType, PromptLearningConfig +from ..config import PromptLearningConfig +from ..utils import PeftType class PromptTuningInit(str, enum.Enum): diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py new file mode 100644 index 0000000000..7f6c3f9d39 
--- /dev/null
+++ b/src/peft/tuners/tuners_utils.py
@@ -0,0 +1,246 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from abc import ABC, abstractmethod
+from typing import Any
+
+from torch import nn
+
+from ..config import PeftConfig
+from ..utils import _get_submodules
+
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTuner(nn.Module, ABC):
+    r"""
+    A base tuner model that provides the common methods and attributes for all tuners that are injectable into a
+    torch.nn.Module.
+
+    To add a new Tuner class, one needs to override the following methods:
+
+    - **_prepare_adapter_config**:
+        A private method to eventually prepare the adapter config, for example in case the field `target_modules` is
+        missing.
+    - **_check_target_module_exists**:
+        A private helper method to check if the passed module's key name matches any of the target modules in the
+        adapter config.
+    - **_create_and_replace**:
+        A private method to create and replace the target module with the adapter module.
+    - **_mark_only_adapters_as_trainable**:
+        A private helper method that freezes every parameter of the base model and leaves only the newly injected
+        adapter parameters trainable.
+
+    The easiest way to get started is to check what is done in the `peft.tuners.lora.LoraModel` class.
+
+    Attributes:
+        model (`torch.nn.Module`):
+            The model to which the adapter tuner layers will be attached.
+        forward (`Callable`):
+            The forward method of the model.
+        peft_config (`Union[PeftConfig, dict[str, PeftConfig]]`):
+            The adapter configuration object; it should be a dictionary mapping adapter names (`str`) to `PeftConfig`
+            objects. A single `PeftConfig` object can also be passed, in which case it is registered under the
+            `adapter_name` given to the constructor.
+        config (`dict[str, Any]`):
+            The model configuration object; it should be a dictionary of `str` to `Any` objects.
+    """
+
+    def __init__(self, model, peft_config, adapter_name):
+        super().__init__()
+
+        self.model = model
+        self.forward = self.model.forward
+
+        # For advanced developers: if you want to attach multiple adapters to your
+        # model, just add a `peft_config` dict attribute to your model.
+        if not hasattr(self, "peft_config"):
+            self.peft_config = {adapter_name: peft_config} if isinstance(peft_config, PeftConfig) else peft_config
+        else:
+            logger.info(
+                "Already found a `peft_config` attribute in the model. This will lead to having multiple adapters"
+                " in the model. Make sure to know what you are doing!"
+            )
+            self.peft_config[adapter_name] = peft_config
+
+        # transformers models have a .config attribute, whose presence is assumed later on
+        if not hasattr(self, "config"):
+            self.config = {"model_type": "custom"}
+
+        self.inject_adapter(self.model, adapter_name)
+
+        # Copy the peft_config in the injected model.
+        self.model.peft_config = self.peft_config
+
+    @abstractmethod
+    def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> PeftConfig:
+        r"""
+        A private method to eventually prepare the adapter config. For transformers based models, if
+        `peft_config.target_modules` is None, we can automatically infer the target modules from the
+        `TRANSFORMERS_MODELS_TO_XXX_TARGET_MODULES_MAPPING`. This method can be further refactored in the future to
+        automatically infer it for all tuner models.
+
+        Check out `peft.tuners.lora.LoraModel._prepare_adapter_config` for an example.
+
+        Args:
+            peft_config (`PeftConfig`):
+                The adapter config.
+            model_config (`dict`):
+                The transformers model config; it should contain the `model_type` key.
+        """
+        ...
+
+    @abstractmethod
+    def _check_target_module_exists(peft_config: PeftConfig, key: str) -> bool:
+        r"""
+        A private helper method to check if the passed module's key name matches any of the target modules in the
+        `peft_config.target_modules` list. If it does, return `True`, else return `False`.
+
+        Args:
+            peft_config (`PeftConfig`):
+                The adapter config.
+            key (`str`):
+                The module's key name.
+        """
+        ...
+
+    @abstractmethod
+    def _create_and_replace(
+        self,
+        peft_config: PeftConfig,
+        adapter_name: str,
+        target: nn.Module,
+        target_name: str,
+        parent: nn.Module,
+        **optionnal_kwargs: Any,
+    ) -> None:
+        r"""
+        In-place replacement of the target module with the adapter layer. This method needs to be overridden by all
+        the tuner classes.
+
+        Check `peft.tuners.lora.LoraModel._create_and_replace` for an example.
+
+        Args:
+            peft_config (`PeftConfig`):
+                The adapter config.
+            adapter_name (`str`):
+                The adapter name.
+            target (`nn.Module`):
+                The target module.
+            target_name (`str`):
+                The target module's name.
+            parent (`nn.Module`):
+                The parent module.
+            **optionnal_kwargs (`dict`):
+                The optional keyword arguments to pass to deal with particular cases (e.g. 8bit, 4bit quantization).
+        """
+        ...
+
+    @abstractmethod
+    def _mark_only_adapters_as_trainable(self):
+        r"""
+        A helper method to mark only the adapter layers as trainable (i.e. set `requires_grad = False` on every
+        non-adapter parameter). This needs to be overridden by all tuner classes to match the correct key names.
+
+        Check `peft.tuners.lora.LoraModel._mark_only_adapters_as_trainable` for an example.
+        """
+        ...
+
+    def inject_adapter(self, model: nn.Module, adapter_name: str):
+        r"""
+        Creates adapter layers and replaces the target modules with the adapter layers. This method is called under
+        the hood by `peft.mapping.get_peft_model` if a non-prompt tuning adapter class is passed.
+
+        The corresponding PEFT config is directly retrieved from the `peft_config` attribute of the BaseTuner class.
+
+        Args:
+            model (`nn.Module`):
+                The model to be tuned.
+            adapter_name (`str`):
+                The adapter name.
+ """ + peft_config = self.peft_config[adapter_name] + + is_target_modules_in_base_model = False + key_list = [key for key, _ in model.named_modules()] + + model_config = getattr(model, "config", {"model_type": "custom"}) + if hasattr(model_config, "to_dict"): + model_config = model_config.to_dict() + + peft_config = self._prepare_adapter_config(peft_config, model_config) + + for key in key_list: + if not self._check_target_module_exists(peft_config, key): + continue + + is_target_modules_in_base_model = True + parent, target, target_name = _get_submodules(model, key) + + optionnal_kwargs = { + "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False), + "current_key": key, + } + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optionnal_kwargs) + + if not is_target_modules_in_base_model: + raise ValueError( + f"Target modules {peft_config.target_modules} not found in the base model. " + f"Please check the target modules and try again." + ) + + self._mark_only_adapters_as_trainable() + + if self.peft_config[adapter_name].inference_mode: + for n, p in self.model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + def merge_adapter(self): + """ + This method merges the LoRa layers into the base model. + """ + for module in self.model.modules(): + if isinstance(module, BaseTunerLayer): + module.merge() + + def unmerge_adapter(self): + """ + This method unmerges the LoRa layers from the base model. + """ + for module in self.model.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + + +class BaseTunerLayer(ABC): + r""" + A tuner layer mixin that provides the common methods and attributes for all tuners. + + Args: + is_plugable (`bool`, *optional*): + Whether the adapter layer can be plugged to any pytorch module + active_adapter (`str`, *optional*): + The name of the active adapter. + """ + active_adapter = None + + def merge(self): + raise NotImplementedError + + def unmerge(self): + raise NotImplementedError diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 19f74cf35f..83b2cd3c74 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -17,7 +17,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType +# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType +from .peft_types import PeftType, TaskType from .other import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, @@ -42,7 +43,8 @@ _freeze_adapter, ModulesToSaveWrapper, _prepare_prompt_learning_config, + _is_valid_match, infer_device, ) from .hub_utils import hub_file_exists -from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict +from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 03e0e88147..c886ef9eb5 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -297,6 +297,18 @@ def transpose(weight, fan_in_fan_out): return weight.T if fan_in_fan_out else weight +def _is_valid_match(key: str, target_key: str): + """ + Helper function to match module names target_key and key. 
+    Helper function to match the module name key against target_key. Returns True if key is exactly target_key,
+    or if target_key names the final submodule in key (i.e. key ends with "." + target_key).
+    """
+    if key.endswith(target_key):
+        if len(key) > len(target_key):
+            return key.endswith("." + target_key)  # must be a sub module
+        return True
+    return False
+
+
 def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int:
     """Get the batch size based on either input_ids or input_embeds
 
diff --git a/src/peft/utils/peft_types.py b/src/peft/utils/peft_types.py
new file mode 100644
index 0000000000..acac24d999
--- /dev/null
+++ b/src/peft/utils/peft_types.py
@@ -0,0 +1,38 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all
+
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import enum
+
+
+class PeftType(str, enum.Enum):
+    PROMPT_TUNING = "PROMPT_TUNING"
+    P_TUNING = "P_TUNING"
+    PREFIX_TUNING = "PREFIX_TUNING"
+    LORA = "LORA"
+    ADALORA = "ADALORA"
+    ADAPTION_PROMPT = "ADAPTION_PROMPT"
+    IA3 = "IA3"
+
+
+class TaskType(str, enum.Enum):
+    SEQ_CLS = "SEQ_CLS"
+    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
+    CAUSAL_LM = "CAUSAL_LM"
+    TOKEN_CLS = "TOKEN_CLS"
+    QUESTION_ANS = "QUESTION_ANS"
+    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"
diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index d67ff14627..617287e840 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -12,8 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os +from typing import Optional -from .config import PeftType, PromptLearningConfig +import torch +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +from .hub_utils import hub_file_exists +from .other import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, infer_device +from .peft_types import PeftType def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): @@ -59,7 +68,7 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): elif config.peft_type == PeftType.ADAPTION_PROMPT: to_return = {k: state_dict[k] for k in state_dict if k.split(".")[-1].startswith("adaption_")} - elif isinstance(config, PromptLearningConfig): + elif config.is_prompt_learning: to_return = {} if config.inference_mode: prompt_embeddings = model.prompt_encoder[adapter_name].embedding.weight @@ -70,7 +79,7 @@ def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"): to_return = {k: state_dict[k] for k in state_dict if "ia3_" in k} else: raise NotImplementedError - if model.modules_to_save is not None: + if getattr(model, "modules_to_save", None) is not None: for key, value in state_dict.items(): if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save): to_return[key.replace("modules_to_save.", "")] = value @@ -89,7 +98,7 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul """ config = model.peft_config[adapter_name] state_dict = {} - if model.modules_to_save is not None: + if getattr(model, "modules_to_save", None) is not None: for key, value in peft_model_state_dict.items(): if any(module_name in key for module_name in model.modules_to_save): for module_name in model.modules_to_save: @@ -118,14 +127,74 @@ def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="defaul rank_pattern = config.rank_pattern if rank_pattern is not None: model.resize_modules_by_rank_pattern(rank_pattern, adapter_name) - elif isinstance(config, PromptLearningConfig) or config.peft_type == PeftType.ADAPTION_PROMPT: + elif config.is_prompt_learning or config.peft_type == PeftType.ADAPTION_PROMPT: peft_model_state_dict = state_dict else: raise NotImplementedError load_result = model.load_state_dict(peft_model_state_dict, strict=False) - if isinstance(config, PromptLearningConfig): + if config.is_prompt_learning: model.prompt_encoder[adapter_name].embedding.load_state_dict( {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True ) return load_result + + +def load_peft_weights(model_id: str, device: Optional[str] = None, **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when loading from the HuggingFace Hub. 
+ """ + path = ( + os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) + if hf_hub_download_kwargs.get("subfolder", None) is not None + else model_id + ) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + has_remote_safetensors_file = hub_file_exists( + model_id, + SAFETENSORS_WEIGHTS_NAME, + revision=hf_hub_download_kwargs.get("revision", None), + repo_type=hf_hub_download_kwargs.get("repo_type", None), + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( + f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}." + ) + + if use_safetensors: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, map_location=torch.device(device)) + + return adapters_weights