diff --git a/src/peft/tuners/adalora/bnb.py b/src/peft/tuners/adalora/bnb.py index cd61ad75dd..3ccfd91b2b 100644 --- a/src/peft/tuners/adalora/bnb.py +++ b/src/peft/tuners/adalora/bnb.py @@ -51,7 +51,7 @@ def __init__( init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: result = super().forward(x) @@ -112,7 +112,7 @@ def __init__( init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: result = super().forward(x) diff --git a/src/peft/tuners/adalora/gptq.py b/src/peft/tuners/adalora/gptq.py index 0ea9bc82fc..92de32ac15 100644 --- a/src/peft/tuners/adalora/gptq.py +++ b/src/peft/tuners/adalora/gptq.py @@ -35,7 +35,7 @@ def __init__( self.weight = quant_linear_module.qweight init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: result = self.quant_linear_module(x) diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py index 610d5403d8..a6e95369ad 100644 --- a/src/peft/tuners/adalora/layer.py +++ b/src/peft/tuners/adalora/layer.py @@ -24,6 +24,10 @@ class AdaLoraLayer(LoraLayer): + # List all names of layers that may contain adapter weights + # Note: ranknum doesn't need to be included as it is not an nn.Module + adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"] + def __init__( self, in_features: int, @@ -59,6 +63,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig if init_lora_weights: self.reset_lora_parameters(adapter_name) self.to(self.weight.device) + self.set_adapter(self.active_adapters) def reset_lora_parameters(self, adapter_name): if adapter_name in self.lora_A.keys(): @@ -92,7 +97,7 @@ def __init__( nn.Linear.reset_parameters(self) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def merge(self) -> None: if self.merged: diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index bc9932b381..a863acce31 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -141,6 +141,9 @@ def _create_and_replace( # If it is not a LoraLayer, create a new module, else update it with new adapters if not isinstance(target, AdaLoraLayer): new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + if adapter_name != self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) else: target.update_layer( diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index 8f5868b8e9..842588befb 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -40,14 +40,14 @@ def __init__( index=kwargs.get("index", None), ) IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + self.is_feedforward = is_feedforward # Freezing 
the pre-trained weight matrix self.weight.requires_grad = False init_ia3_weights = kwargs.pop("init_ia3_weights", True) self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - self.is_feedforward = is_feedforward + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.disable_adapters: diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index 19527cd5d2..270159935c 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -24,6 +24,9 @@ class IA3Layer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ["ia3_l"] + def __init__( self, in_features: int, @@ -34,8 +37,8 @@ def __init__( self.ia3_l = nn.ParameterDict({}) # Mark the weight as unmerged self.merged = False + self._disable_adapters = False self.merged_adapters = [] - self.disable_adapters = False self.in_features = in_features self.out_features = out_features self.is_feedforward = is_feedforward @@ -50,6 +53,7 @@ def update_layer(self, adapter_name, init_ia3_weights): if init_ia3_weights: self.reset_ia3_parameters(adapter_name) self.to(self.weight.device) + self.set_adapter(self.active_adapters) def reset_ia3_parameters(self, adapter_name): if adapter_name in self.ia3_l.keys(): @@ -72,6 +76,7 @@ def __init__( nn.Linear.__init__(self, in_features, out_features, **kwargs) IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward) + self.is_feedforward = is_feedforward # Freezing the pre-trained weight matrix self.weight.requires_grad = False @@ -81,9 +86,7 @@ def __init__( nn.Linear.reset_parameters(self) self.update_layer(adapter_name, init_ia3_weights) - self.active_adapter = adapter_name - - self.is_feedforward = is_feedforward + self.set_adapter(adapter_name) def merge(self) -> None: if self.merged: diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 106631058a..299fff4b89 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -178,6 +178,9 @@ def _create_and_replace( ) else: new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs) + if adapter_name != self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) @staticmethod @@ -213,10 +216,8 @@ def get_peft_config_as_dict(self, inference: bool = False): def _set_adapter_layers(self, enabled=True): for module in self.model.modules(): - if isinstance(module, IA3Layer): - module.disable_adapters = False if enabled else True - elif isinstance(module, ModulesToSaveWrapper): - module.disable_adapters = False if enabled else True + if isinstance(module, (IA3Layer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) def enable_adapter_layers(self): self._set_adapter_layers(enabled=True) @@ -230,7 +231,7 @@ def set_adapter(self, adapter_name): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.active_adapter = adapter_name + module.set_adapter(adapter_name) def _prepare_adapter_config(self, peft_config, model_config): if peft_config.target_modules is None: diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index 7fdb5140a5..3007564ad5 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -54,7 +54,7 @@ def __init__( self.weight.requires_grad = False init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def merge(self): if self.merged: @@ -195,7 +195,7 @@ def __init__( init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def merge(self): if self.merged: diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index d0b9315fe1..1505045a3e 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -36,7 +36,7 @@ def __init__( self.weight = quant_linear_module.qweight init_lora_weights = kwargs.pop("init_lora_weights", True) self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index a8682eb637..ccfa3cc1b7 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -26,6 +26,9 @@ class LoraLayer(BaseTunerLayer): + # List all names of layers that may contain adapter weights + adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"] + def __init__(self, in_features: int, out_features: int, **kwargs): self.r = {} self.lora_alpha = {} @@ -38,8 +41,8 @@ def __init__(self, in_features: int, out_features: int, **kwargs): self.lora_embedding_B = nn.ParameterDict({}) # Mark the weight as unmerged self.merged = False + self._disable_adapters = False self.merged_adapters = [] - self.disable_adapters = False self.in_features = in_features self.out_features = out_features self.kwargs = kwargs @@ -82,6 +85,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.to(weight.device, dtype=weight.dtype) else: self.to(weight.device) + self.set_adapter(self.active_adapters) def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): if r <= 0: @@ -197,8 +201,8 @@ def __init__( self.fan_in_fan_out = fan_in_fan_out self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name self.is_target_conv_1d_layer = is_target_conv_1d_layer + self.set_adapter(adapter_name) def merge(self) -> None: if self.merged: @@ -275,7 +279,7 @@ def __init__( self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs) LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim) self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - self.active_adapter = adapter_name + self.set_adapter(adapter_name) def merge(self) -> None: if self.merged: @@ -364,7 +368,7 @@ def __init__( ) self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) - 
self.active_adapter = adapter_name + self.set_adapter(adapter_name) def merge(self) -> None: if self.merged: diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 8b2c40dffe..027e45b7ff 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -17,7 +17,6 @@ from dataclasses import asdict, replace from enum import Enum from itertools import chain -from typing import List import torch from torch import nn @@ -25,7 +24,7 @@ from transformers.pytorch_utils import Conv1D from peft.import_utils import is_bnb_4bit_available, is_bnb_available -from peft.tuners.tuners_utils import BaseTuner +from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer from peft.utils import ( COMMON_LAYERS_PATTERN, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, @@ -216,6 +215,9 @@ def _create_and_replace( ) else: new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs) + if adapter_name != self.active_adapter: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) self._replace_module(parent, target_name, new_module, target) @staticmethod @@ -239,15 +241,16 @@ def _replace_module(parent, child_name, new_module, child): module.to(child.weight.device) def _mark_only_adapters_as_trainable(self) -> None: - for active_adapter in self._get_active_adapters(): - bias = self.peft_config[active_adapter].bias + for n, p in self.model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False - for n, p in self.model.named_parameters(): - if "lora_" not in n: - p.requires_grad = False + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias if bias == "none": - return - elif bias == "all": + continue + + if bias == "all": for n, p in self.model.named_parameters(): if "bias" in n: p.requires_grad = True @@ -351,28 +354,14 @@ def get_peft_config_as_dict(self, inference: bool = False): def _set_adapter_layers(self, enabled=True): for module in self.model.modules(): - if isinstance(module, LoraLayer): - module.disable_adapters = False if enabled else True - elif isinstance(module, ModulesToSaveWrapper): - module.disable_adapters = False if enabled else True + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) def enable_adapter_layers(self): self._set_adapter_layers(enabled=True) - def _get_active_adapters(self) -> List[str]: - active_adapters = None - for module in self.model.modules(): - if isinstance(module, LoraLayer): - active_adapters = module.active_adapters - - if active_adapters is None: - raise ValueError( - "Something went wrong, no active adapter could be found, please report the issue on GitHub" - ) - return active_adapters - def disable_adapter_layers(self): - for active_adapter in self._get_active_adapters(): + for active_adapter in self.active_adapters: val = self.peft_config[active_adapter].bias if val != "none": msg = ( @@ -388,7 +377,7 @@ def set_adapter(self, adapter_name): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.active_adapter = adapter_name + module.set_adapter(adapter_name) @staticmethod def _prepare_adapter_config(peft_config, model_config): @@ -626,7 +615,7 @@ def _svd_weighted_adapter( Vh = Vh.reshape(target_lora_A.data.shape) return Vh, U - def delete_adapter(self, adapter_name): + def delete_adapter(self, adapter_name: str): """ Deletes an existing adapter. 
@@ -636,6 +625,7 @@ def delete_adapter(self, adapter_name): if adapter_name not in list(self.peft_config.keys()): raise ValueError(f"Adapter {adapter_name} does not exist") del self.peft_config[adapter_name] + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] for key in key_list: _, target, _ = _get_submodules(self.model, key) @@ -659,7 +649,7 @@ def delete_adapter(self, adapter_name): warnings.warn( f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. " ) - target.active_adapter = resetting_active_adapter + target.set_adapter(resetting_active_adapter) def merge_and_unload(self, progressbar: bool = False): r""" diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index edabb7b6c6..82bac04c2f 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -81,6 +81,8 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], # user is adding a dict of PeftConfigs self.peft_config.update(peft_config) + self.active_adapter = adapter_name + # transformers models have a .config attribute, whose presence is assumed later on if not hasattr(self, "config"): self.config = {"model_type": "custom"} @@ -90,6 +92,13 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], # Copy the peft_config in the injected model. self.model.peft_config = self.peft_config + @property + def active_adapters(self) -> list[str]: + if isinstance(self.active_adapter, str): + return [self.active_adapter] + # is already a list of str + return self.active_adapter + def forward(self, *args: Any, **kwargs: Any): return self.model.forward(*args, **kwargs) @@ -260,15 +269,74 @@ class BaseTunerLayer(ABC): """ active_adapter = None + # List all names of layers that may contain adapter weights + adapter_layer_names: list[str] = [] + + # indicates whether all adapters should be disabled + _disable_adapters: bool = False + + # the currently active adapter(s) + _active_adapter: str | list[str] = "default" + def merge(self) -> None: raise NotImplementedError def unmerge(self) -> None: raise NotImplementedError + @property + def disable_adapters(self) -> bool: + # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method + return self._disable_adapters + + @property + def active_adapter(self) -> str: + # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method + return self._active_adapter + @property def active_adapters(self): if isinstance(self.active_adapter, str): return [self.active_adapter] # is already a list of str return self.active_adapter + + def enable_adapters(self, enabled: bool): + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. 
+ + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if enabled: + self.set_adapter(self.active_adapters) + self._disable_adapters = False + else: + # disable grads on all adapter layers + for layer_name in self.adapter_layer_names: + layer = getattr(self, layer_name) + layer.requires_grad_(False) + self._disable_adapters = True + + def set_adapter(self, adapter_names: str | list[str]): + """Set the active adapter + + Args: + adapter_name (str): The name of the adapter to set as active + """ + if isinstance(adapter_names, str): + adapter_names = [adapter_names] + + # Deactivate grads on the inactive adapter and activate grads on the active adapter + for layer_name in self.adapter_layer_names: + module_dict = getattr(self, layer_name) + for key, layer in module_dict.items(): + if key in adapter_names: + # Note: It is possible that not a single layer is called with requires_grad_(True) here. This may + # happen if a completely different adapter layer is being activated. + layer.requires_grad_(True) + else: + layer.requires_grad_(False) + + self._active_adapter = adapter_names diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 956c2c91c8..0723a00ccc 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -164,9 +164,19 @@ def __init__(self, module_to_save, adapter_name): super().__init__() self.original_module = module_to_save self.modules_to_save = torch.nn.ModuleDict({}) + self._active_adapter = adapter_name + self._disable_adapters = False self.update(adapter_name) - self.active_adapter = adapter_name - self.disable_adapters = False + + @property + def disable_adapters(self) -> bool: + # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method + return self._disable_adapters + + @property + def active_adapter(self) -> str: + # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method + return self._active_adapter def update(self, adapter_name): self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) @@ -177,6 +187,10 @@ def update(self, adapter_name): remove_hook_from_module(self.modules_to_save[adapter_name]) add_hook_to_module(self.modules_to_save[adapter_name], new_hook) + self.original_module.requires_grad_(False) + if adapter_name == self.active_adapter: + self.modules_to_save[adapter_name].requires_grad_(True) + def _create_new_hook(self, old_hook): r""" Creates a new hook based on the old hook. Use it only if you know what you are doing ! @@ -196,6 +210,40 @@ def forward(self, *args, **kwargs): return self.original_module(*args, **kwargs) return self.modules_to_save[self.active_adapter](*args, **kwargs) + def enable_adapters(self, enabled: bool): + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. 
+ + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if self._disable_adapters is not enabled: + # already in the desired state, do nothing + return + + if enabled: + self.original_module.requires_grad_(False) + self.modules_to_save[self.active_adapter].requires_grad_(True) + self._disable_adapters = False + else: + self.original_module.requires_grad_(True) + self.modules_to_save.requires_grad_(False) + self._disable_adapters = True + + def set_adapter(self, adapter_name: str): + """Set the active adapter + + Args: + adapter_name (str): The name of the adapter to set as active + """ + if adapter_name not in self.modules_to_save: + raise ValueError(f"Adapter {adapter_name} not found in {self.modules_to_save.keys()}") + + self.modules_to_save[self.active_adapter].requires_grad_(False) + self.modules_to_save[adapter_name].requires_grad_(True) + self._active_adapter = adapter_name + def _get_submodules(model, key): parent = model.get_submodule(".".join(key.split(".")[:-1])) @@ -218,16 +266,17 @@ def _set_trainable(model, adapter_name): parent, target, target_name = _get_submodules(model, key) if isinstance(target, ModulesToSaveWrapper): target.update(adapter_name) + target.set_adapter(target.active_adapter) else: - for param in target.parameters(): - param.requires_grad = True - setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name)) + new_module = ModulesToSaveWrapper(target, adapter_name) + new_module.set_adapter(adapter_name) + setattr(parent, target_name, new_module) def _set_adapter(model, adapter_name): for module in model.modules(): if isinstance(module, ModulesToSaveWrapper): - module.active_adapter = adapter_name + module.set_adapter(adapter_name) def _prepare_prompt_learning_config(peft_config, model_config): diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 61062a9cdb..18251a2458 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -23,7 +23,7 @@ from torch import nn from transformers.pytorch_utils import Conv1D -from peft import LoraConfig, PeftModel, get_peft_model +from peft import AdaLoraConfig, IA3Config, LoraConfig, PeftModel, get_peft_model from .testing_common import PeftCommonTester from .testing_utils import get_state_dict @@ -442,3 +442,333 @@ def test_repr_lora_conv2d(self): self.assertTrue("lora_A" in print_output) self.assertTrue("lora_B" in print_output) self.assertTrue("default" in print_output) + + +class RequiresGradTester(unittest.TestCase): + """Test that requires_grad is set correctly in specific circumstances + + # See issue #899. + + This is not specifically tied to custom models, it's just easy to test here and testing it on all types of models + would be overkill. 
+ + """ + + def test_requires_grad_modules_to_save_default(self): + config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model = get_peft_model(MLP(), config) + + self.assertTrue(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.bias.requires_grad) + + def test_requires_grad_modules_to_save_disabling(self): + config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model = get_peft_model(MLP(), config) + + # when disabling the adapter, the original module's grad should be enabled and vice versa + peft_model.disable_adapter_layers() + self.assertFalse(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertTrue(peft_model.model.lin1.original_module.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.original_module.bias.requires_grad) + + # when re-enabling the adapter, the original module's grad should be disabled and vice versa + peft_model.enable_adapter_layers() + self.assertTrue(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.bias.requires_grad) + + # when using the disable_adapter context, the original module's grad should be enabled and vice versa + with peft_model.disable_adapter(): + self.assertFalse(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertTrue(peft_model.model.lin1.original_module.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.original_module.bias.requires_grad) + + # after context is exited, return to the previous state + self.assertTrue(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.original_module.bias.requires_grad) + + def test_requires_grad_modules_to_save_multiple_adapters(self): + config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.assertTrue(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.adapter1.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.adapter1.bias.requires_grad) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.assertTrue(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.adapter1.weight.requires_grad) + 
self.assertFalse(peft_model.model.lin1.modules_to_save.adapter1.bias.requires_grad) + + # set config1 as active, should lead to adapter1 requiring grad + peft_model.set_adapter("adapter1") + self.assertFalse(peft_model.model.lin1.modules_to_save.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.modules_to_save.default.bias.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.adapter1.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.modules_to_save.adapter1.bias.requires_grad) + + def test_requires_grad_lora_different_targets(self): + # test two different LoRA adapters that target different modules + config0 = LoraConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = LoraConfig(target_modules=["lin1"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.assertTrue(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_A.adapter1.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_B.adapter1.weight.requires_grad) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + self.assertTrue(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_A.adapter1.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_B.adapter1.weight.requires_grad) + + # change activate adapter to adapter1 + peft_model.set_adapter("adapter1") + self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.lora_A.adapter1.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.lora_B.adapter1.weight.requires_grad) + + # disable all adapters + with peft_model.disable_adapter(): + self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_A.adapter1.weight.requires_grad) + self.assertFalse(peft_model.model.lin1.lora_B.adapter1.weight.requires_grad) + + # after context is exited, return to the previous state + peft_model.set_adapter("adapter1") + self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.lora_A.adapter1.weight.requires_grad) + self.assertTrue(peft_model.model.lin1.lora_B.adapter1.weight.requires_grad) + + def test_requires_grad_lora_same_targets(self): + # same as previous test, except that LoRA adapters target the same layer + config0 = LoraConfig(target_modules=["lin0"]) + peft_model = get_peft_model(MLP(), config0) + + config1 = LoraConfig(target_modules=["lin0"]) + peft_model.add_adapter("adapter1", config1) + + # active adapter is still "default" + self.assertTrue(peft_model.model.lin0.lora_A.default.weight.requires_grad) + self.assertTrue(peft_model.model.lin0.lora_B.default.weight.requires_grad) + self.assertFalse(peft_model.model.lin0.lora_A.adapter1.weight.requires_grad) + self.assertFalse(peft_model.model.lin0.lora_B.adapter1.weight.requires_grad) + + # set config0 as active, should not change anything + peft_model.set_adapter("default") + 
self.assertTrue(peft_model.model.lin0.lora_A.default.weight.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.default.weight.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_A.adapter1.weight.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.adapter1.weight.requires_grad)
+
+        # change activate adapter to adapter1
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_A.adapter1.weight.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.adapter1.weight.requires_grad)
+
+        # disable all adapters
+        with peft_model.disable_adapter():
+            self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_A.adapter1.weight.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_B.adapter1.weight.requires_grad)
+
+        # after context is exited, return to the previous state
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.weight.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.weight.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_A.adapter1.weight.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.adapter1.weight.requires_grad)
+
+    def test_requires_grad_ia3_different_targets(self):
+        # test two different IA3 adapters that target different modules
+        config0 = IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"])
+        peft_model = get_peft_model(MLP(), config0)
+
+        config1 = IA3Config(target_modules=["lin1"], feedforward_modules=["lin1"])
+        peft_model.add_adapter("adapter1", config1)
+
+        # active adapter is still "default"
+        self.assertTrue(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertFalse(peft_model.model.lin1.ia3_l.adapter1.requires_grad)
+
+        # set config0 as active, should not change anything
+        peft_model.set_adapter("default")
+        self.assertTrue(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertFalse(peft_model.model.lin1.ia3_l.adapter1.requires_grad)
+
+        # change activate adapter to adapter1
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertTrue(peft_model.model.lin1.ia3_l.adapter1.requires_grad)
+
+        # disable all adapters
+        with peft_model.disable_adapter():
+            self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+            self.assertFalse(peft_model.model.lin1.ia3_l.adapter1.requires_grad)
+
+        # after context is exited, return to the previous state
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertTrue(peft_model.model.lin1.ia3_l.adapter1.requires_grad)
+
+    def test_requires_grad_ia3_same_targets(self):
+        # same as previous test, except that IA3 adapters target the same layer
+        config0 = IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"])
+        peft_model = get_peft_model(MLP(), config0)
+
+        config1 = IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"])
+        peft_model.add_adapter("adapter1", config1)
+
+        # active adapter is still "default"
+        self.assertTrue(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.ia3_l.adapter1.requires_grad)
+
+        # set config0 as active, should not change anything
+        peft_model.set_adapter("default")
+        self.assertTrue(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.ia3_l.adapter1.requires_grad)
+
+        # change activate adapter to adapter1
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.ia3_l.adapter1.requires_grad)
+
+        # disable all adapters
+        with peft_model.disable_adapter():
+            self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.ia3_l.adapter1.requires_grad)
+
+        # after context is exited, return to the previous state
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.ia3_l.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.ia3_l.adapter1.requires_grad)
+
+    def test_requires_grad_adalora_different_targets(self):
+        # test two different AdaLora adapters that target different modules
+        config0 = AdaLoraConfig(target_modules=["lin0"])
+        peft_model = get_peft_model(MLP(), config0)
+
+        config1 = AdaLoraConfig(target_modules=["lin1"], inference_mode=True)
+        peft_model.add_adapter("adapter1", config1)
+
+        # active adapter is still "default"
+        self.assertTrue(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_A.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_B.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_E.adapter1.requires_grad)
+
+        # set config0 as active, should not change anything
+        peft_model.set_adapter("default")
+        self.assertTrue(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_A.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_B.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin1.lora_E.adapter1.requires_grad)
+
+        # change activate adapter to adapter1
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_A.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_B.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_E.adapter1.requires_grad)
+
+        # disable all adapters
+        with peft_model.disable_adapter():
+            self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+            self.assertFalse(peft_model.model.lin1.lora_A.adapter1.requires_grad)
+            self.assertFalse(peft_model.model.lin1.lora_B.adapter1.requires_grad)
+            self.assertFalse(peft_model.model.lin1.lora_E.adapter1.requires_grad)
+
+        # after context is exited, return to the previous state
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_A.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_B.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin1.lora_E.adapter1.requires_grad)
+
+    def test_requires_grad_adalora_same_targets(self):
+        # same as previous test, except that AdaLora adapters target the same layer
+        config0 = AdaLoraConfig(target_modules=["lin0"])
+        peft_model = get_peft_model(MLP(), config0)
+
+        config1 = AdaLoraConfig(target_modules=["lin0"], inference_mode=True)
+        peft_model.add_adapter("adapter1", config1)
+
+        # active adapter is still "default"
+        self.assertTrue(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_A.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.adapter1.requires_grad)
+
+        # set config0 as active, should not change anything
+        peft_model.set_adapter("default")
+        self.assertTrue(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_A.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.adapter1.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.adapter1.requires_grad)
+
+        # change activate adapter to adapter1
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_A.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.adapter1.requires_grad)
+
+        # disable all adapters
+        with peft_model.disable_adapter():
+            self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_A.adapter1.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_B.adapter1.requires_grad)
+            self.assertFalse(peft_model.model.lin0.lora_E.adapter1.requires_grad)
+
+        # after context is exited, return to the previous state
+        peft_model.set_adapter("adapter1")
+        self.assertFalse(peft_model.model.lin0.lora_A.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_B.default.requires_grad)
+        self.assertFalse(peft_model.model.lin0.lora_E.default.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_A.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_B.adapter1.requires_grad)
+        self.assertTrue(peft_model.model.lin0.lora_E.adapter1.requires_grad)
diff --git a/tests/testing_common.py b/tests/testing_common.py
index e3741d9505..1dcb1a8fe8 100644
--- a/tests/testing_common.py
+++ b/tests/testing_common.py
@@ -207,14 +207,14 @@ def _test_prepare_for_training(self, model_id, config_cls, config_kwargs):
         dummy_input = self.prepare_inputs_for_testing()
         dummy_output = model.get_input_embeddings()(dummy_input["input_ids"])

-        self.assertTrue(not dummy_output.requires_grad)
+        self.assertFalse(dummy_output.requires_grad)

         # load with `prepare_model_for_int8_training`
         model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
         model = prepare_model_for_int8_training(model)

         for param in model.parameters():
-            self.assertTrue(not param.requires_grad)
+            self.assertFalse(param.requires_grad)

         config = config_cls(
             base_model_name_or_path=model_id,
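Taken together, the new tests pin down the `requires_grad` semantics introduced by this patch: a newly added adapter stays frozen until it is activated, `set_adapter` moves `requires_grad` over to the newly active adapter, and the `disable_adapter` context freezes all adapter weights. A minimal sketch of that behavior, assuming a toy model with a `lin0`/`lin1` layout similar to the `MLP` helper used in `tests/test_custom_models.py` (the class below is an illustrative stand-in, not the test helper itself):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    # illustrative stand-in for the MLP helper in tests/test_custom_models.py (layer names assumed)
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


peft_model = get_peft_model(MLP(), LoraConfig(target_modules=["lin0"]))

# adding a second adapter does not make it trainable
peft_model.add_adapter("adapter1", LoraConfig(target_modules=["lin0"]))
assert peft_model.model.lin0.lora_A.default.weight.requires_grad
assert not peft_model.model.lin0.lora_A.adapter1.weight.requires_grad

# activating the new adapter moves requires_grad over to it
peft_model.set_adapter("adapter1")
assert not peft_model.model.lin0.lora_A.default.weight.requires_grad
assert peft_model.model.lin0.lora_A.adapter1.weight.requires_grad

# while adapters are disabled, no adapter weight requires grad
with peft_model.disable_adapter():
    assert not peft_model.model.lin0.lora_A.adapter1.weight.requires_grad
```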
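The same rules apply to the trainable copies kept by `ModulesToSaveWrapper`: enabling or disabling the adapter layers swaps `requires_grad` between the copy stored in `modules_to_save` and the frozen `original_module`. A short sketch, reusing the toy `MLP` class from the sketch above:

```python
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
peft_model = get_peft_model(MLP(), config)

# the trainable copy receives gradients, the wrapped original does not
assert peft_model.model.lin1.modules_to_save.default.weight.requires_grad
assert not peft_model.model.lin1.original_module.weight.requires_grad

# disabling the adapter layers swaps the two ...
peft_model.disable_adapter_layers()
assert not peft_model.model.lin1.modules_to_save.default.weight.requires_grad
assert peft_model.model.lin1.original_module.weight.requires_grad

# ... and re-enabling them restores the previous state
peft_model.enable_adapter_layers()
assert peft_model.model.lin1.modules_to_save.default.weight.requires_grad
assert not peft_model.model.lin1.original_module.weight.requires_grad
```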