4 changes: 2 additions & 2 deletions src/peft/tuners/adalora/bnb.py
@@ -51,7 +51,7 @@ def __init__(

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
@@ -112,7 +112,7 @@ def __init__(

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/gptq.py
@@ -35,7 +35,7 @@ def __init__(
self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
result = self.quant_linear_module(x)
7 changes: 6 additions & 1 deletion src/peft/tuners/adalora/layer.py
@@ -24,6 +24,10 @@


class AdaLoraLayer(LoraLayer):
# List all names of layers that may contain adapter weights
# Note: ranknum doesn't need to be included as it is not an nn.Module
adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]

def __init__(
self,
in_features: int,
@@ -59,6 +63,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
if init_lora_weights:
self.reset_lora_parameters(adapter_name)
self.to(self.weight.device)
self.set_adapter(self.active_adapters)

def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
@@ -92,7 +97,7 @@ def __init__(

nn.Linear.reset_parameters(self)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def merge(self) -> None:
if self.merged:
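For orientation: the new `adapter_layer_names` class attribute, together with the `set_adapter(self.active_adapters)` call added at the end of `update_layer`, points at a shared `set_adapter` helper on the base tuner layer that makes only the requested adapters trainable. That helper is not part of this diff; the sketch below is an assumption about roughly how it could behave, not the actual implementation.

def set_adapter(self, adapter_names):
    """Activate the given adapter(s) and enable gradients only for them."""
    if isinstance(adapter_names, str):
        adapter_names = [adapter_names]

    for layer_name in self.adapter_layer_names:
        # e.g. self.lora_A / self.lora_E / self.ia3_l (a ModuleDict or ParameterDict)
        module_dict = getattr(self, layer_name)
        for key, layer in module_dict.items():
            # requires_grad_ exists on both nn.Module and nn.Parameter entries
            layer.requires_grad_(key in adapter_names)

    self._active_adapter = adapter_names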
3 changes: 3 additions & 0 deletions src/peft/tuners/adalora/model.py
@@ -141,6 +141,9 @@ def _create_and_replace(
# If it is not a LoraLayer, create a new module, else update it with new adapters
if not isinstance(target, AdaLoraLayer):
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
else:
target.update_layer(
4 changes: 2 additions & 2 deletions src/peft/tuners/ia3/bnb.py
@@ -40,14 +40,14 @@ def __init__(
index=kwargs.get("index", None),
)
IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward

# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

init_ia3_weights = kwargs.pop("init_ia3_weights", True)
self.update_layer(adapter_name, init_ia3_weights)
self.active_adapter = adapter_name
self.is_feedforward = is_feedforward
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.disable_adapters:
11 changes: 7 additions & 4 deletions src/peft/tuners/ia3/layer.py
@@ -24,6 +24,9 @@


class IA3Layer(BaseTunerLayer):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["ia3_l"]

def __init__(
self,
in_features: int,
@@ -34,8 +37,8 @@ def __init__(
self.ia3_l = nn.ParameterDict({})
# Mark the weight as unmerged
self.merged = False
self._disable_adapters = False
self.merged_adapters = []
self.disable_adapters = False
self.in_features = in_features
self.out_features = out_features
self.is_feedforward = is_feedforward
@@ -50,6 +53,7 @@ def update_layer(self, adapter_name, init_ia3_weights):
if init_ia3_weights:
self.reset_ia3_parameters(adapter_name)
self.to(self.weight.device)
self.set_adapter(self.active_adapters)

def reset_ia3_parameters(self, adapter_name):
if adapter_name in self.ia3_l.keys():
@@ -72,6 +76,7 @@ def __init__(

nn.Linear.__init__(self, in_features, out_features, **kwargs)
IA3Layer.__init__(self, in_features=in_features, out_features=out_features, is_feedforward=is_feedforward)
self.is_feedforward = is_feedforward
# Freezing the pre-trained weight matrix
self.weight.requires_grad = False

@@ -81,9 +86,7 @@ def __init__(

nn.Linear.reset_parameters(self)
self.update_layer(adapter_name, init_ia3_weights)
self.active_adapter = adapter_name

self.is_feedforward = is_feedforward
self.set_adapter(adapter_name)

def merge(self) -> None:
if self.merged:
11 changes: 6 additions & 5 deletions src/peft/tuners/ia3/model.py
@@ -178,6 +178,9 @@ def _create_and_replace(
)
else:
new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

@staticmethod
@@ -213,10 +216,8 @@ def get_peft_config_as_dict(self, inference: bool = False):

def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, IA3Layer):
module.disable_adapters = False if enabled else True
elif isinstance(module, ModulesToSaveWrapper):
module.disable_adapters = False if enabled else True
if isinstance(module, (IA3Layer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)

def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
@@ -230,7 +231,7 @@ def set_adapter(self, adapter_name):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge()
module.active_adapter = adapter_name
module.set_adapter(adapter_name)

def _prepare_adapter_config(self, peft_config, model_config):
if peft_config.target_modules is None:
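The `_set_adapter_layers` rework above calls `module.enable_adapters(enabled)` instead of writing to `disable_adapters` directly, matching the layer-level switch from a public `disable_adapters` attribute to a private `_disable_adapters` flag. A plausible shape for that helper, again assumed here rather than shown in the diff:

def enable_adapters(self, enabled: bool) -> None:
    """Toggle the adapter weights on or off without deleting them."""
    if enabled:
        # Re-enable gradients, but only for the adapters that are currently active.
        self.set_adapter(self.active_adapters)
        self._disable_adapters = False
    else:
        # Freeze every adapter weight and mark the layer as disabled.
        for layer_name in self.adapter_layer_names:
            getattr(self, layer_name).requires_grad_(False)
        self._disable_adapters = True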
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/bnb.py
@@ -54,7 +54,7 @@ def __init__(
self.weight.requires_grad = False
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def merge(self):
if self.merged:
@@ -195,7 +195,7 @@ def __init__(

init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def merge(self):
if self.merged:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/gptq.py
@@ -36,7 +36,7 @@ def __init__(
self.weight = quant_linear_module.qweight
init_lora_weights = kwargs.pop("init_lora_weights", True)
self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
12 changes: 8 additions & 4 deletions src/peft/tuners/lora/layer.py
@@ -26,6 +26,9 @@


class LoraLayer(BaseTunerLayer):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]

def __init__(self, in_features: int, out_features: int, **kwargs):
self.r = {}
self.lora_alpha = {}
@@ -38,8 +41,8 @@ def __init__(self, in_features: int, out_features: int, **kwargs):
self.lora_embedding_B = nn.ParameterDict({})
# Mark the weight as unmerged
self.merged = False
self._disable_adapters = False
self.merged_adapters = []
self.disable_adapters = False
self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs
@@ -82,6 +85,7 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
self.set_adapter(self.active_adapters)

def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
if r <= 0:
@@ -197,8 +201,8 @@ def __init__(
self.fan_in_fan_out = fan_in_fan_out

self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.is_target_conv_1d_layer = is_target_conv_1d_layer
self.set_adapter(adapter_name)

def merge(self) -> None:
if self.merged:
@@ -275,7 +279,7 @@ def __init__(
self._init_empty_weights(nn.Embedding, num_embeddings, embedding_dim, **kwargs)
LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)
self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def merge(self) -> None:
if self.merged:
@@ -364,7 +368,7 @@ def __init__(
)

self.update_layer_conv2d(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
self.active_adapter = adapter_name
self.set_adapter(adapter_name)

def merge(self) -> None:
if self.merged:
48 changes: 19 additions & 29 deletions src/peft/tuners/lora/model.py
@@ -17,15 +17,14 @@
from dataclasses import asdict, replace
from enum import Enum
from itertools import chain
from typing import List

import torch
from torch import nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D

from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTuner
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import (
COMMON_LAYERS_PATTERN,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
@@ -216,6 +215,9 @@ def _create_and_replace(
)
else:
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)

@staticmethod
@@ -239,15 +241,16 @@ def _replace_module(parent, child_name, new_module, child):
module.to(child.weight.device)

def _mark_only_adapters_as_trainable(self) -> None:
for active_adapter in self._get_active_adapters():
bias = self.peft_config[active_adapter].bias
for n, p in self.model.named_parameters():
if "lora_" not in n:
p.requires_grad = False

for n, p in self.model.named_parameters():
if "lora_" not in n:
p.requires_grad = False
for active_adapter in self.active_adapters:
bias = self.peft_config[active_adapter].bias
if bias == "none":
return
elif bias == "all":
continue

if bias == "all":
for n, p in self.model.named_parameters():
if "bias" in n:
p.requires_grad = True
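The reworked `_mark_only_adapters_as_trainable` first freezes everything that is not a LoRA weight and then walks the active adapters to honor each one's `bias` setting. A quick way to sanity-check which parameters end up trainable; the model name and the `bias` value below are placeholders, not part of this PR:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any small model works
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, bias="lora_only"))

# Expect only "lora_" parameters plus, for bias="lora_only", the biases of the
# modules that carry LoRA weights; with bias="none" no biases should be trainable.
trainable = [n for n, p in peft_model.named_parameters() if p.requires_grad]
print(trainable[:10])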
@@ -351,28 +354,14 @@ def get_peft_config_as_dict(self, inference: bool = False):

def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, LoraLayer):
module.disable_adapters = False if enabled else True
elif isinstance(module, ModulesToSaveWrapper):
module.disable_adapters = False if enabled else True
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)

def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)

def _get_active_adapters(self) -> List[str]:
active_adapters = None
for module in self.model.modules():
if isinstance(module, LoraLayer):
active_adapters = module.active_adapters

if active_adapters is None:
raise ValueError(
"Something went wrong, no active adapter could be found, please report the issue on GitHub"
)
return active_adapters

def disable_adapter_layers(self):
for active_adapter in self._get_active_adapters():
for active_adapter in self.active_adapters:
val = self.peft_config[active_adapter].bias
if val != "none":
msg = (
@@ -388,7 +377,7 @@ def set_adapter(self, adapter_name):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge()
module.active_adapter = adapter_name
module.set_adapter(adapter_name)

@staticmethod
def _prepare_adapter_config(peft_config, model_config):
@@ -626,7 +615,7 @@ def _svd_weighted_adapter(
Vh = Vh.reshape(target_lora_A.data.shape)
return Vh, U

def delete_adapter(self, adapter_name):
def delete_adapter(self, adapter_name: str):
"""
Deletes an existing adapter.

@@ -636,6 +625,7 @@ def delete_adapter(self, adapter_name):
if adapter_name not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
@@ -659,7 +649,7 @@ def delete_adapter(self, adapter_name):
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.active_adapter = resetting_active_adapter
target.set_adapter(resetting_active_adapter)

def merge_and_unload(self, progressbar: bool = False):
r"""
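Taken together, the thread running through these files is that direct `self.active_adapter = adapter_name` assignments become `set_adapter` calls, and that an adapter added on top of an existing one starts out frozen. A minimal usage sketch of that intended behavior; the model and adapter names below are placeholders:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model = get_peft_model(base, LoraConfig(r=8), adapter_name="first")

# Adding a second adapter while "first" is active: its weights are created
# with requires_grad=False, so it does not silently join the training run.
model.add_adapter("second", LoraConfig(r=8))
assert all(
    not p.requires_grad for n, p in model.named_parameters() if ".second." in n
)

# Activating it makes its LoRA weights trainable and freezes "first".
model.set_adapter("second")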