diff --git a/docs/source/conceptual_guides/lora.mdx b/docs/source/conceptual_guides/lora.mdx index ff028ca4bd..4f3027241c 100644 --- a/docs/source/conceptual_guides/lora.mdx +++ b/docs/source/conceptual_guides/lora.mdx @@ -77,6 +77,8 @@ As with other methods supported by PEFT, to fine-tune a model using LoRA, you ne - `modules_to_save`: List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task. - `layers_to_transform`: List of layers to be transformed by LoRA. If not specified, all layers in `target_modules` are transformed. - `layers_pattern`: Pattern to match layer names in `target_modules`, if `layers_to_transform` is specified. By default `PeftModel` will look at common layer pattern (`layers`, `h`, `blocks`, etc.), use it for exotic and custom models. +- `rank_pattern`: The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. +- `alpha_pattern`: The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. ## LoRA examples diff --git a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py index cba8a6d6f8..5b3099fdff 100644 --- a/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py +++ b/examples/lora_dreambooth/convert_kohya_ss_sd_lora_to_peft.py @@ -1,11 +1,10 @@ import argparse import os -import re -from typing import Callable, List, Optional, Union +from dataclasses import dataclass +from typing import Dict, Optional import safetensors import torch -import torch.nn as nn from diffusers import UNet2DConditionModel from transformers import CLIPTextModel @@ -21,44 +20,66 @@ LORA_PREFIX_TEXT_ENCODER = "lora_te" -def get_modules_names( - root_module: nn.Module, - target_replace_modules_linear: Optional[List[str]] = [], - target_replace_modules_conv2d: Optional[List[str]] = [], -): - # Combine replacement modules - target_replace_modules = target_replace_modules_linear + target_replace_modules_conv2d - - # Store result - modules_names = set() - # https://github.com/kohya-ss/sd-scripts/blob/c924c47f374ac1b6e33e71f82948eb1853e2243f/networks/lora.py#L720 - for name, module in root_module.named_modules(): - if module.__class__.__name__ in target_replace_modules: - if len(name) == 0: - continue - for child_name, child_module in module.named_modules(): - if len(child_name) == 0: - continue - is_linear = isinstance(child_module, nn.Linear) - is_conv2d = isinstance(child_module, nn.Conv2d) - - if (is_linear and module.__class__.__name__ in target_replace_modules_linear) or ( - is_conv2d and module.__class__.__name__ in target_replace_modules_conv2d - ): - modules_names.add(f"{name}.{child_name}") - - return sorted(modules_names) - - -def get_rank_alpha( - layer_names: List[str], - value_getter: Callable[[str], Union[int, float]], - filter_string: str, -) -> Union[int, float]: - values = [value_getter(p) for p in filter(lambda x: bool(re.search(filter_string, x)), layer_names)] - value = values[0] - assert all(v == value for v in values), f"All LoRA ranks and alphas must be same, found: {values}" - return value +@dataclass +class LoRAInfo: + kohya_key: str + peft_key: str + alpha: Optional[float] = None + rank: Optional[int] = None + lora_A: Optional[torch.Tensor] = None + lora_B: Optional[torch.Tensor] = None + + def peft_state_dict(self) 
-> Dict[str, torch.Tensor]:
+        if self.lora_A is None or self.lora_B is None:
+            raise ValueError("At least one of lora_A or lora_B is None, they must both be provided")
+        return {f"{self.peft_key}.lora_A.weight": self.lora_A, f"{self.peft_key}.lora_B.weight": self.lora_B}
+
+
+def construct_peft_loraconfig(info: Dict[str, LoRAInfo]) -> LoraConfig:
+    """Constructs LoraConfig from data extracted from kohya checkpoint
+
+    Args:
+        info (Dict[str, LoRAInfo]): Information extracted from kohya checkpoint
+
+    Returns:
+        LoraConfig: config for constructing LoRA
+    """
+
+    # Unpack all ranks and alphas
+    ranks = {x[0]: x[1].rank for x in info.items()}
+    alphas = {x[0]: x[1].alpha or x[1].rank for x in info.items()}
+
+    # Determine which modules need to be transformed
+    target_modules = list(info.keys())
+
+    # Determine most common rank and alpha
+    r = max(set(ranks.values()), key=list(ranks.values()).count)
+    lora_alpha = max(set(alphas.values()), key=list(alphas.values()).count)
+
+    # Determine which modules have different rank and alpha
+    rank_pattern = dict(filter(lambda x: x[1] != r, ranks.items()))
+    alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, alphas.items()))
+
+    config = LoraConfig(
+        r=r,
+        lora_alpha=lora_alpha,
+        target_modules=target_modules,
+        lora_dropout=0.0,
+        bias="none",
+        init_lora_weights=False,
+        rank_pattern=rank_pattern,
+        alpha_pattern=alpha_pattern,
+    )
+
+    return config
+
+
+def combine_peft_state_dict(info: Dict[str, LoRAInfo]) -> Dict[str, torch.Tensor]:
+    result = {}
+    for key_name, key_info in info.items():
+        result[f"base_model.model.{key_name}.lora_A.weight"] = key_info.lora_A
+        result[f"base_model.model.{key_name}.lora_B.weight"] = key_info.lora_B
+    return result


 if __name__ == "__main__":
@@ -75,93 +96,79 @@ def get_rank_alpha(
     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
     args = parser.parse_args()

-    # Find text encoder modules to add LoRA to
+    # Load all models that we need to add adapters to
     text_encoder = CLIPTextModel.from_pretrained(args.sd_checkpoint, subfolder="text_encoder")
-    text_encoder_modules_names = get_modules_names(
-        text_encoder, target_replace_modules_linear=TEXT_ENCODER_TARGET_REPLACE_MODULE
-    )
-
-    # Find unet2d modules to add LoRA to
     unet = UNet2DConditionModel.from_pretrained(args.sd_checkpoint, subfolder="unet")
-    unet_modules_names = get_modules_names(
-        unet,
-        target_replace_modules_linear=UNET_TARGET_REPLACE_MODULE,
-        target_replace_modules_conv2d=UNET_TARGET_REPLACE_MODULE,
-    )
+
+    # Construct possible mapping from kohya keys to peft keys
+    models_keys = {}
+    for model, model_key, model_name in [
+        (text_encoder, LORA_PREFIX_TEXT_ENCODER, "text_encoder"),
+        (unet, LORA_PREFIX_UNET, "unet"),
+    ]:
+        models_keys.update(
+            {
+                f"{model_key}.{peft_key}".replace(".", "_"): peft_key
+                for peft_key in (x[0] for x in model.named_modules())
+            }
+        )
+
+    # Store conversion info (model_type -> peft_key -> LoRAInfo)
+    lora_info: Dict[str, Dict[str, LoRAInfo]] = {
+        "text_encoder": {},
+        "unet": {},
+    }

     # Open kohya_ss checkpoint
     with safetensors.safe_open(args.kohya_lora_path, framework="pt", device="cpu") as f:
         # Extract information about LoRA structure
         metadata = f.metadata()
-        if (metadata is not None) and ("ss_network_dim" in metadata) and ("ss_network_alpha" in metadata):
-            # LoRA rank and alpha are in safetensors metadata, just get it
-            lora_r = lora_text_encoder_r = int(metadata["ss_network_dim"])
-            lora_alpha = lora_text_encoder_alpha = float(metadata["ss_network_alpha"])
-        else:
-            # LoRA rank and
alpha are not present, so infer them - lora_r = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_UNET}\w+\.lora_down\.weight$" - ) - lora_text_encoder_r = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).size(0), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.lora_down\.weight$" - ) - lora_alpha = get_rank_alpha(f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_UNET}\w+\.alpha$") - lora_text_encoder_alpha = get_rank_alpha( - f.keys(), lambda n: f.get_tensor(n).item(), f"^{LORA_PREFIX_TEXT_ENCODER}\w+\.alpha$" - ) - - # Create LoRA for text encoder - text_encoder_config = LoraConfig( - r=lora_text_encoder_r, - lora_alpha=lora_text_encoder_alpha, - target_modules=text_encoder_modules_names, - lora_dropout=0.0, - bias="none", - ) - text_encoder = get_peft_model(text_encoder, text_encoder_config) - text_encoder_lora_state_dict = {x: None for x in get_peft_model_state_dict(text_encoder).keys()} - - # Load text encoder values from kohya_ss LoRA - for peft_te_key in text_encoder_lora_state_dict.keys(): - kohya_ss_te_key = peft_te_key.replace("base_model.model", LORA_PREFIX_TEXT_ENCODER) - kohya_ss_te_key = kohya_ss_te_key.replace("lora_A", "lora_down") - kohya_ss_te_key = kohya_ss_te_key.replace("lora_B", "lora_up") - kohya_ss_te_key = kohya_ss_te_key.replace(".", "_", kohya_ss_te_key.count(".") - 2) - text_encoder_lora_state_dict[peft_te_key] = f.get_tensor(kohya_ss_te_key).to(text_encoder.dtype) - # Load converted kohya_ss text encoder LoRA back to PEFT - set_peft_model_state_dict(text_encoder, text_encoder_lora_state_dict) + # Iterate through available info and unpack all the values + for key in f.keys(): + kohya_key, kohya_type = key.split(".")[:2] + + # Find which model this key belongs to + if kohya_key.startswith(LORA_PREFIX_TEXT_ENCODER): + model_type = "text_encoder" + elif kohya_key.startswith(LORA_PREFIX_UNET): + model_type = "unet" + else: + raise ValueError(f"Cannot determine model for key: {key}") + + # Find corresponding peft key + if kohya_key not in models_keys: + raise ValueError(f"Cannot find corresponding key for diffusers/transformers model: {kohya_key}") + peft_key = models_keys[kohya_key] + + if peft_key not in lora_info[model_type]: + lora_info[model_type][peft_key] = LoRAInfo(kohya_key=kohya_key, peft_key=peft_key) + + if kohya_type == "alpha": + lora_info[model_type][peft_key].alpha = f.get_tensor(key).item() + elif kohya_type == "lora_down": + tensor = f.get_tensor(key) + lora_info[model_type][peft_key].lora_A = tensor + lora_info[model_type][peft_key].rank = tensor.shape[0] + elif kohya_type == "lora_up": + tensor = f.get_tensor(key) + lora_info[model_type][peft_key].lora_B = f.get_tensor(key) + lora_info[model_type][peft_key].rank = tensor.shape[1] + else: + raise ValueError(f"Unknown weight name in key: {key} - {kohya_type}") + + # Process each model + for model, model_name in [(text_encoder, "text_encoder"), (unet, "unet")]: + config = construct_peft_loraconfig(lora_info[model_name]) + model = get_peft_model(model, config) + + keys_peft = list(get_peft_model_state_dict(model).keys()) + keys_new = list(combine_peft_state_dict(lora_info[model_name]).keys()) + + set_peft_model_state_dict(model, combine_peft_state_dict(lora_info[model_name])) if args.half: - text_encoder.to(torch.float16) - - # Save text encoder result - text_encoder.save_pretrained( - os.path.join(args.dump_path, "text_encoder"), - ) + model.to(torch.float16) - # Create LoRA for unet2d - unet_config = LoraConfig( - r=lora_r, lora_alpha=lora_alpha, 
target_modules=unet_modules_names, lora_dropout=0.0, bias="none" - ) - unet = get_peft_model(unet, unet_config) - unet_lora_state_dict = {x: None for x in get_peft_model_state_dict(unet).keys()} - - # Load unet2d values from kohya_ss LoRA - for peft_unet_key in unet_lora_state_dict.keys(): - kohya_ss_unet_key = peft_unet_key.replace("base_model.model", LORA_PREFIX_UNET) - kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_A", "lora_down") - kohya_ss_unet_key = kohya_ss_unet_key.replace("lora_B", "lora_up") - kohya_ss_unet_key = kohya_ss_unet_key.replace(".", "_", kohya_ss_unet_key.count(".") - 2) - unet_lora_state_dict[peft_unet_key] = f.get_tensor(kohya_ss_unet_key).to(unet.dtype) - - # Load converted kohya_ss unet LoRA back to PEFT - set_peft_model_state_dict(unet, unet_lora_state_dict) - - if args.half: - unet.to(torch.float16) - - # Save text encoder result - unet.save_pretrained( - os.path.join(args.dump_path, "unet"), - ) + # Save model to disk + model.save_pretrained(os.path.join(args.dump_path, model_name)) diff --git a/src/peft/tuners/adalora/bnb.py b/src/peft/tuners/adalora/bnb.py index d57c1b0142..cd61ad75dd 100644 --- a/src/peft/tuners/adalora/bnb.py +++ b/src/peft/tuners/adalora/bnb.py @@ -56,31 +56,30 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: result = super().forward(x) - if ( - self.disable_adapters - or (self.active_adapter not in self.lora_A.keys()) - or (self.r[self.active_adapter] == 0) - ): + if self.disable_adapters: return result - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - if x.dtype != torch.float32: - x = x.float() - - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - lora_E = self.lora_E[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - ranknum = self.ranknum[self.active_adapter] + 1e-5 - - output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T - if requires_conversion: - output = output.to(expected_dtype) - output = output * scaling / ranknum - result = result + output + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling / ranknum + result += output return result @@ -118,11 +117,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: result = super().forward(x) - if ( - self.disable_adapters - or (self.active_adapter not in self.lora_A.keys()) - or (self.r[self.active_adapter] == 0) - ): + if self.disable_adapters: return result # As per Tim Dettmers, for 4bit, we need to defensively clone here. @@ -132,23 +127,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # sure. 
result = result.clone() - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - lora_E = self.lora_E[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - ranknum = self.ranknum[self.active_adapter] + 1e-5 - - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - compute_dtype = lora_A.weight.dtype - if x.dtype != compute_dtype: - x = x.to(compute_dtype) - - output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T - if requires_conversion: - output = output.to(expected_dtype) - output = output * scaling / ranknum - result = result + output + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling / ranknum + result += output return result diff --git a/src/peft/tuners/adalora/gptq.py b/src/peft/tuners/adalora/gptq.py index dc19d436a3..0ea9bc82fc 100644 --- a/src/peft/tuners/adalora/gptq.py +++ b/src/peft/tuners/adalora/gptq.py @@ -40,31 +40,30 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: result = self.quant_linear_module(x) - if ( - self.disable_adapters - or (self.active_adapter not in self.lora_A.keys()) - or (self.r[self.active_adapter] == 0) - ): + if self.disable_adapters: return result - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - if x.dtype != torch.float32: - x = x.float() + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - lora_E = self.lora_E[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - ranknum = self.ranknum[self.active_adapter] + 1e-5 + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + if x.dtype != torch.float32: + x = x.float() - output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum - # TODO: here, the dtype conversion is applied on the *whole expression*, - # not the intermediate result, unlike for SVDLinear8bitLT and - # SVDLinear4bit, is that correct? - if requires_conversion: - output = output.to(expected_dtype) - result = result + output + output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum + # TODO: here, the dtype conversion is applied on the *whole expression*, + # not the intermediate result, unlike for SVDLinear8bitLT and + # SVDLinear4bit, is that correct? 
+            if requires_conversion:
+                output = output.to(expected_dtype)
+            result += output

         return result
diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py
index 86c2d36f28..610d5403d8 100644
--- a/src/peft/tuners/adalora/layer.py
+++ b/src/peft/tuners/adalora/layer.py
@@ -95,64 +95,58 @@ def __init__(
         self.active_adapter = adapter_name

     def merge(self) -> None:
-        if self.active_adapter not in self.lora_A.keys():
-            return
         if self.merged:
-            warnings.warn("Already merged. Nothing to do.")
-            return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data += (
-                transpose(
-                    self.lora_B[self.active_adapter]
-                    @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter]),
-                    self.fan_in_fan_out,
-                )
-                * self.scaling[self.active_adapter]
-                / (self.ranknum[self.active_adapter] + 1e-5)
+            warnings.warn(
+                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
+                f"You are now additionally merging {','.join(self.active_adapters)}."
             )
-            self.merged = True
+        for active_adapter in self.active_adapters:
+            if active_adapter in self.lora_A.keys():
+                self.weight.data += self.get_delta_weight(active_adapter)
+                self.merged_adapters.append(active_adapter)
+        self.merged = True

     def unmerge(self) -> None:
-        if self.active_adapter not in self.lora_A.keys():
-            return
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            self.weight.data -= (
-                transpose(
-                    self.lora_B[self.active_adapter]
-                    @ (self.lora_A[self.active_adapter] * self.lora_E[self.active_adapter])
-                )
-                * self.scaling[self.active_adapter]
-                / (self.ranknum[self.active_adapter] + 1e-5)
-            )
-        self.merged = False
+        while len(self.merged_adapters) > 0:
+            active_adapter = self.merged_adapters.pop()
+            if active_adapter in self.lora_A.keys():
+                self.weight.data -= self.get_delta_weight(active_adapter)
+        self.merged = False
+
+    def get_delta_weight(self, adapter) -> torch.Tensor:
+        return (
+            transpose(self.lora_B[adapter] @ (self.lora_A[adapter] * self.lora_E[adapter]), self.fan_in_fan_out)
+            * self.scaling[adapter]
+            / (self.ranknum[adapter] + 1e-5)
+        )

     def _linear(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.active_adapter not in self.lora_A.keys():
-            return self._linear(x)
-        # TODO: SVDLinear does not convert dtype, unlike lora linear, is that correct?
if self.disable_adapters: - if self.r[self.active_adapter] > 0 and self.merged: + if self.merged: self.unmerge() result = self._linear(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = self._linear(x) else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - lora_E = self.lora_E[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - ranknum = self.ranknum[self.active_adapter] + 1e-5 - result = self._linear(x) - result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + lora_E = self.lora_E[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + ranknum = self.ranknum[active_adapter] + 1e-5 + + result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum return result diff --git a/src/peft/tuners/adalora/model.py b/src/peft/tuners/adalora/model.py index a0a003a73b..bc9932b381 100644 --- a/src/peft/tuners/adalora/model.py +++ b/src/peft/tuners/adalora/model.py @@ -115,10 +115,10 @@ def _create_and_replace( target, target_name, parent, - **optionnal_kwargs, + **optional_kwargs, ): - loaded_in_8bit = optionnal_kwargs.get("loaded_in_8bit", False) - loaded_in_4bit = optionnal_kwargs.get("loaded_in_4bit", False) + loaded_in_8bit = optional_kwargs.get("loaded_in_8bit", False) + loaded_in_4bit = optional_kwargs.get("loaded_in_4bit", False) if (loaded_in_8bit or loaded_in_4bit) and not is_bnb_available(): raise ImportError( "To use Lora with 8-bit quantization, please install the `bitsandbytes` package. " diff --git a/src/peft/tuners/ia3/bnb.py b/src/peft/tuners/ia3/bnb.py index c88829af0f..8f5868b8e9 100644 --- a/src/peft/tuners/ia3/bnb.py +++ b/src/peft/tuners/ia3/bnb.py @@ -50,14 +50,18 @@ def __init__( self.is_feedforward = is_feedforward def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.disable_adapters or (self.active_adapter not in self.ia3_l.keys()): + if self.disable_adapters: return super().forward(x) + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + ia3_scaling *= self.ia3_l[active_adapter].flatten() + requires_conversion = (not torch.is_autocast_enabled()) and (x.dtype != torch.float32) if requires_conversion: x = x.float() - - ia3_scaling = self.ia3_l[self.active_adapter].flatten() if self.is_feedforward: result = super().forward(x * ia3_scaling) expected_dtype = result.dtype diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index 3da8383cac..19527cd5d2 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -34,6 +34,7 @@ def __init__( self.ia3_l = nn.ParameterDict({}) # Mark the weight as unmerged self.merged = False + self.merged_adapters = [] self.disable_adapters = False self.in_features = in_features self.out_features = out_features @@ -85,40 +86,37 @@ def __init__( self.is_feedforward = is_feedforward def merge(self) -> None: - if self.active_adapter not in self.ia3_l.keys(): - return if self.merged: - warnings.warn("Already merged. 
Nothing to do.") - return - - self.weight = transpose(self.weight, self.fan_in_fan_out) - self.weight.data = torch.mul(self.weight.data, self.ia3_l[self.active_adapter].data) - self.weight = transpose(self.weight, self.fan_in_fan_out) - - self.merged = True + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + + for active_adapter in self.active_adapters: + if active_adapter in self.ia3_l.keys(): + self.weight = transpose(self.weight, self.fan_in_fan_out) + self.weight.data = torch.mul(self.weight.data, self.ia3_l[active_adapter].data) + self.weight = transpose(self.weight, self.fan_in_fan_out) + self.merged = True def unmerge(self) -> None: - if self.active_adapter not in self.ia3_l.keys(): - return if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return warnings.warn("Unmerge result can be inaccurate for (IA)^3.") - self.weight = transpose(self.weight, self.fan_in_fan_out) - # divide by (IA)^3 vector. Add tolerace to avoid division by zero - self.weight.data = torch.div(self.weight.data, self.ia3_l[self.active_adapter].data + 1e-8) - self.weight = transpose(self.weight, self.fan_in_fan_out) - - self.merged = False + for active_adapter in self.active_adapters: + if active_adapter in self.ia3_l.keys(): + self.weight = transpose(self.weight, self.fan_in_fan_out) + # divide by (IA)^3 vector. Add tolerace to avoid division by zero + self.weight.data = torch.div(self.weight.data, self.ia3_l[active_adapter].data + 1e-8) + self.weight = transpose(self.weight, self.fan_in_fan_out) + self.merged = False def _linear(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.ia3_l.keys(): - return self._linear(x) - previous_dtype = x.dtype if self.disable_adapters: @@ -128,11 +126,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: elif self.merged: result = self._linear(x) else: - dtype = self.ia3_l[self.active_adapter].dtype - ia3_scaling = self.ia3_l[self.active_adapter].flatten() + ia3_scaling = 1 + for active_adapter in self.active_adapters: + if active_adapter not in self.ia3_l.keys(): + continue + dtype = self.ia3_l[active_adapter].dtype + ia3_scaling *= self.ia3_l[active_adapter].flatten() + if self.is_feedforward: x = x.to(dtype) - # TODO: self.weight.dtype can be != self.ia3_l[self.active_adapter].dtype + # TODO: self.weight.dtype can be != self.ia3_l[self.active_adapters].dtype # e.g. bf16 vs fp32. Is that okay? 
interm = (x * ia3_scaling).to(self.weight.dtype) result = self._linear(interm) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 06a8d19684..106631058a 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -153,10 +153,10 @@ def _create_and_replace( target, target_name, parent, - **optionnal_kwargs, + **optional_kwargs, ): - loaded_in_8bit = optionnal_kwargs["loaded_in_8bit"] - current_key = optionnal_kwargs["current_key"] + loaded_in_8bit = optional_kwargs["loaded_in_8bit"] + current_key = optional_kwargs["current_key"] # check if target module is in feedforward_modules if isinstance(ia3_config.feedforward_modules, str): diff --git a/src/peft/tuners/lora/bnb.py b/src/peft/tuners/lora/bnb.py index 445f00ca21..7fdb5140a5 100644 --- a/src/peft/tuners/lora/bnb.py +++ b/src/peft/tuners/lora/bnb.py @@ -57,16 +57,18 @@ def __init__( self.active_adapter = adapter_name def merge(self): - if self.active_adapter not in self.lora_A.keys(): - return if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue warnings.warn( "Merge lora module to 8-bit linear may get different generations due to rounding errors." ) - lora_data = self.get_delta_weight(self.active_adapter) + lora_data = self.get_delta_weight(active_adapter) if self.state.SCB is None: self.state.SCB = self.weight.SCB @@ -87,19 +89,21 @@ def merge(self): w_data.to("cpu"), requires_grad=False, has_fp16_weights=self.weight.has_fp16_weights ).to(self.weight.device) self.state.reset_grads() + self.merged_adapters.append(active_adapter) self.merged = True def unmerge(self): - if self.active_adapter not in self.lora_A.keys(): - return if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return - if self.r[self.active_adapter] > 0: + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue warnings.warn( "Unmerge lora module to 8-bit linear may get different generations due to rounding errors." 
) - lora_data = self.get_delta_weight(self.active_adapter) + lora_data = self.get_delta_weight(active_adapter) if self.state.SCB is None: self.state.SCB = self.weight.SCB @@ -130,35 +134,33 @@ def get_delta_weight(self, adapter): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.lora_A.keys(): - return super().forward(x) - if self.disable_adapters: - if (self.r[self.active_adapter] > 0) and self.merged: + if self.merged: self.unmerge() result = super().forward(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = super().forward(x) else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - result = super().forward(x) - - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - compute_dtype = lora_A.weight.dtype - if x.dtype != compute_dtype: - x = x.to(compute_dtype) - - output = lora_B(lora_A(dropout(x))) - if requires_conversion: - output = output.to(expected_dtype) - output = output * scaling - result += output + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output return result @@ -196,34 +198,37 @@ def __init__( self.active_adapter = adapter_name def merge(self): - if self.active_adapter not in self.lora_A.keys(): - return if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue warnings.warn( "Merge lora module to 4-bit linear may get different generations due to rounding errors." ) # Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930 kwargs = self.weight.__dict__ - lora_data = self.get_delta_weight(self.active_adapter) + lora_data = self.get_delta_weight(active_adapter) w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) + lora_data self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device) self.merged = True def unmerge(self): - if self.active_adapter not in self.lora_A.keys(): - return if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return - if self.r[self.active_adapter] > 0: + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue warnings.warn( "Unmerge lora module to 4-bit linear may get different generations due to rounding errors." 
) kwargs = self.weight.__dict__ - lora_data = self.get_delta_weight(self.active_adapter) + lora_data = self.get_delta_weight(active_adapter) w_data = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state) - lora_data self.weight = bnb.nn.Params4bit(w_data.to("cpu"), requires_grad=False, **kwargs).to(self.weight.device) self.merged = False @@ -238,21 +243,13 @@ def get_delta_weight(self, adapter): ) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.lora_A.keys(): - return super().forward(x) - if self.disable_adapters: - if (self.r[self.active_adapter] > 0) and self.merged: + if self.merged: self.unmerge() result = super().forward(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = super().forward(x) else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - result = super().forward(x) # As per Tim Dettmers, for 4bit, we need to defensively clone here. # The reason is that in some cases, an error can occur that backprop @@ -261,15 +258,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # sure. result = result.clone() - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - x = x.to(lora_A.weight.dtype) - - output = lora_B(lora_A(dropout(x))) - if requires_conversion: - output = output.to(expected_dtype) - output = output * scaling - result += output + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output return result diff --git a/src/peft/tuners/lora/config.py b/src/peft/tuners/lora/config.py index 30ccdd2111..d85aa79239 100644 --- a/src/peft/tuners/lora/config.py +++ b/src/peft/tuners/lora/config.py @@ -45,6 +45,12 @@ class LoraConfig(PeftConfig): layers_pattern (`str`): The layer pattern name, used only if `layers_to_transform` is different from `None` and if the layer pattern is not in the common layers pattern. + rank_pattern (`dict`): + The mapping from layer names or regexp expression to ranks which are different from the default rank + specified by `r`. + alpha_pattern (`dict`): + The mapping from layer names or regexp expression to alphas which are different from the default alpha + specified by `lora_alpha`. """ r: int = field(default=8, metadata={"help": "Lora attention dimension"}) @@ -91,6 +97,24 @@ class LoraConfig(PeftConfig): "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern." }, ) + rank_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. 
" + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" + ) + }, + ) + alpha_pattern: Optional[dict] = field( + default_factory=dict, + metadata={ + "help": ( + "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. " + "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" + ) + }, + ) def __post_init__(self): self.peft_type = PeftType.LORA diff --git a/src/peft/tuners/lora/gptq.py b/src/peft/tuners/lora/gptq.py index c0963194e1..d0b9315fe1 100644 --- a/src/peft/tuners/lora/gptq.py +++ b/src/peft/tuners/lora/gptq.py @@ -42,28 +42,27 @@ def forward(self, x: torch.Tensor): # note: logic differs from default Linear because merging is not supported result = self.quant_linear_module(x) - if ( - self.disable_adapters - or (self.active_adapter not in self.lora_A.keys()) - or (self.r[self.active_adapter] == 0) - ): + if self.disable_adapters: return result - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] - requires_conversion = not torch.is_autocast_enabled() - if requires_conversion: - expected_dtype = result.dtype - x = x.to(lora_A.weight.dtype) + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) - output = lora_B(lora_A(dropout(x))) - if requires_conversion: - output = output.to(expected_dtype) - output = output * scaling - result += output + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output return result # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102 diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index ca07055c71..a8682eb637 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -38,6 +38,7 @@ def __init__(self, in_features: int, out_features: int, **kwargs): self.lora_embedding_B = nn.ParameterDict({}) # Mark the weight as unmerged self.merged = False + self.merged_adapters = [] self.disable_adapters = False self.in_features = in_features self.out_features = out_features @@ -56,6 +57,8 @@ def _init_empty_weights(self, cls, *args, **kwargs) -> None: self.to_empty(device=final_device) def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha if lora_dropout > 0.0: @@ -81,6 +84,8 @@ def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weig self.to(weight.device) def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha if lora_dropout > 0.0: @@ -106,6 +111,8 @@ def update_layer_conv2d(self, adapter_name, r, lora_alpha, lora_dropout, init_lo self.to(self.weight.device, 
dtype=weight.dtype) def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") self.r[adapter_name] = r self.lora_alpha[adapter_name] = lora_alpha if lora_dropout > 0.0: @@ -139,6 +146,19 @@ def reset_lora_parameters(self, adapter_name): nn.init.zeros_(self.lora_embedding_A[adapter_name]) nn.init.normal_(self.lora_embedding_B[adapter_name]) + def scale_layer(self, scale_factor: float) -> None: + if scale_factor != 1: + for active_adapter in self.active_adapters: + alpha = self.lora_alpha[active_adapter] + r = self.r[active_adapter] + self.scaling[active_adapter] = (alpha / r) * scale_factor + + def unscale_layer(self) -> None: + for active_adapter in self.active_adapters: + alpha = self.lora_alpha[active_adapter] + r = self.r[active_adapter] + self.scaling[active_adapter] = alpha / r + # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # and modified to work with PyTorch FSDP @@ -180,33 +200,27 @@ def __init__( self.active_adapter = adapter_name self.is_target_conv_1d_layer = is_target_conv_1d_layer - def scale_layer(self, scale_factor: float) -> None: - if scale_factor != 0 and scale_factor != 1: - self.scaling[self.active_adapter] *= scale_factor - - def unscale_layer(self, scale_factor: float) -> None: - if scale_factor != 0 and scale_factor != 1: - self.scaling[self.active_adapter] /= scale_factor - def merge(self) -> None: - if self.active_adapter not in self.lora_A.keys(): - return if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data += self.get_delta_weight(self.active_adapter) - self.merged = True + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + for active_adapter in self.active_adapters: + if active_adapter in self.lora_A.keys(): + self.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + self.merged = True def unmerge(self) -> None: - if self.active_adapter not in self.lora_A.keys(): - return if not self.merged: warnings.warn("Already unmerged. 
Nothing to do.") return - if self.r[self.active_adapter] > 0: - self.weight.data -= self.get_delta_weight(self.active_adapter) - self.merged = False + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + self.weight.data -= self.get_delta_weight(active_adapter) + self.merged = False def get_delta_weight(self, adapter) -> torch.Tensor: return ( @@ -221,26 +235,25 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor: return F.linear(input, transpose(self.weight, self.fan_in_fan_out), bias=self.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.lora_A.keys(): - return self._linear(x) - previous_dtype = x.dtype if self.disable_adapters: - if (self.r[self.active_adapter] > 0) and self.merged: + if self.merged: self.unmerge() result = self._linear(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = self._linear(x) else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - result = self._linear(x) - x = x.to(lora_A.weight.dtype) - result += lora_B(lora_A(dropout(x))) * scaling + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + result += lora_B(lora_A(dropout(x))) * scaling result = result.to(previous_dtype) return result @@ -264,21 +277,27 @@ def __init__( self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights) self.active_adapter = adapter_name + def merge(self) -> None: + if self.merged: + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + for active_adapter in self.active_adapters: + if active_adapter in self.lora_embedding_A.keys(): + self.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + self.merged = True + def unmerge(self) -> None: if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return - if self.r[self.active_adapter] > 0: - self.weight.data -= self.get_delta_weight(self.active_adapter) - self.merged = False - - def merge(self) -> None: - if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data += self.get_delta_weight(self.active_adapter) - self.merged = True + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_embedding_A.keys(): + self.weight.data -= self.get_delta_weight(active_adapter) + self.merged = False def get_delta_weight(self, adapter) -> torch.Tensor: return transpose(self.lora_embedding_B[adapter] @ self.lora_embedding_A[adapter], True) * self.scaling[adapter] @@ -296,24 +315,23 @@ def _embed(self, input: torch.Tensor, weight: Optional[torch.Tensor] = None) -> ) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.lora_embedding_A.keys(): - return self._embed(x) - # TODO: no dtype conversion here, unlike in Linear, is that correct? 
if self.disable_adapters: - if (self.r[self.active_adapter] > 0) and self.merged: + if self.merged: self.unmerge() result = self._embed(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = self._embed(x) else: - embedding_A = self.lora_embedding_A[self.active_adapter].T - embedding_B = self.lora_embedding_B[self.active_adapter].T - scaling = self.scaling[self.active_adapter] - result = self._embed(x) - after_A = self._embed(x, embedding_A) - result += (after_A @ embedding_B) * scaling + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_embedding_A: + continue + embedding_A = self.lora_embedding_A[active_adapter].T + embedding_B = self.lora_embedding_B[active_adapter].T + scaling = self.scaling[active_adapter] + after_A = self._embed(x, embedding_A) + result += (after_A @ embedding_B) * scaling return result @@ -349,24 +367,26 @@ def __init__( self.active_adapter = adapter_name def merge(self) -> None: - if self.active_adapter not in self.lora_A.keys(): - return if self.merged: - warnings.warn("Already merged. Nothing to do.") - return - if self.r[self.active_adapter] > 0: - self.weight.data += self.get_delta_weight(self.active_adapter) - self.merged = True + warnings.warn( + f"Already following adapters were merged {','.join(self.merged_adapters)}. " + f"You are now additionally merging {','.join(self.active_adapters)}." + ) + for active_adapter in self.active_adapters: + if active_adapter in self.lora_A.keys(): + self.weight.data += self.get_delta_weight(active_adapter) + self.merged_adapters.append(active_adapter) + self.merged = True def unmerge(self) -> None: - if self.active_adapter not in self.lora_A.keys(): - return if not self.merged: warnings.warn("Already unmerged. Nothing to do.") return - if self.r[self.active_adapter] > 0: - self.weight.data -= self.get_delta_weight(self.active_adapter) - self.merged = False + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter in self.lora_A.keys(): + self.weight.data -= self.get_delta_weight(active_adapter) + self.merged = False def get_delta_weight(self, adapter) -> torch.Tensor: # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 @@ -397,26 +417,25 @@ def _conv2d(self, input: torch.Tensor) -> torch.Tensor: ) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self.active_adapter not in self.lora_A.keys(): - return self._conv2d(x) - previous_dtype = x.dtype if self.disable_adapters: - if self.r[self.active_adapter] > 0 and self.merged: + if self.merged: self.unmerge() result = self._conv2d(x) - elif (self.r[self.active_adapter] == 0) or self.merged: + elif self.merged: result = self._conv2d(x) else: - lora_A = self.lora_A[self.active_adapter] - lora_B = self.lora_B[self.active_adapter] - dropout = self.lora_dropout[self.active_adapter] - scaling = self.scaling[self.active_adapter] - result = self._conv2d(x) - x = x.to(lora_A.weight.dtype) - result += lora_B(lora_A(dropout(x))) * scaling + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + result += lora_B(lora_A(dropout(x))) * scaling result = result.to(previous_dtype) return result diff --git a/src/peft/tuners/lora/model.py 
b/src/peft/tuners/lora/model.py index e5db47983f..8b2c40dffe 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -16,6 +16,8 @@ import warnings from dataclasses import asdict, replace from enum import Enum +from itertools import chain +from typing import List import torch from torch import nn @@ -159,18 +161,27 @@ def _create_and_replace( target, target_name, parent, - **optionnal_kwargs, + current_key, + **optional_kwargs, ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + # Regexp matching - Find key which matches current target_name in patterns provided + pattern_keys = list(chain(lora_config.rank_pattern.keys(), lora_config.alpha_pattern.keys())) + target_name_key = next(filter(lambda key: re.match(f".*\.{key}$", current_key), pattern_keys), target_name) + + r = lora_config.rank_pattern.get(target_name_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(target_name_key, lora_config.lora_alpha) bias = hasattr(target, "bias") and target.bias is not None kwargs = { - "r": lora_config.r, - "lora_alpha": lora_config.lora_alpha, + "r": r, + "lora_alpha": alpha, "lora_dropout": lora_config.lora_dropout, "fan_in_fan_out": lora_config.fan_in_fan_out, "init_lora_weights": lora_config.init_lora_weights, } - kwargs["loaded_in_8bit"] = optionnal_kwargs.pop("loaded_in_8bit", False) - kwargs["loaded_in_4bit"] = optionnal_kwargs.pop("loaded_in_4bit", False) + kwargs["loaded_in_8bit"] = optional_kwargs.pop("loaded_in_8bit", False) + kwargs["loaded_in_4bit"] = optional_kwargs.pop("loaded_in_4bit", False) kwargs["bias"] = bias quantization_config = get_quantization_config(self.model, method="gptq") @@ -181,16 +192,16 @@ def _create_and_replace( if isinstance(target, LoraLayer) and isinstance(target, torch.nn.Conv2d): target.update_layer_conv2d( adapter_name, - lora_config.r, - lora_config.lora_alpha, + r, + alpha, lora_config.lora_dropout, lora_config.init_lora_weights, ) elif isinstance(target, LoraLayer) and isinstance(target, torch.nn.Embedding): target.update_layer_embedding( adapter_name, - lora_config.r, - lora_config.lora_alpha, + r, + alpha, lora_config.lora_dropout, lora_config.init_lora_weights, ) @@ -198,8 +209,8 @@ def _create_and_replace( elif isinstance(target, LoraLayer): target.update_layer( adapter_name, - lora_config.r, - lora_config.lora_alpha, + r, + alpha, lora_config.lora_dropout, lora_config.init_lora_weights, ) @@ -228,24 +239,24 @@ def _replace_module(parent, child_name, new_module, child): module.to(child.weight.device) def _mark_only_adapters_as_trainable(self) -> None: - active_adapter = self._get_active_adapter() - bias = self.peft_config[active_adapter].bias + for active_adapter in self._get_active_adapters(): + bias = self.peft_config[active_adapter].bias - for n, p in self.model.named_parameters(): - if "lora_" not in n: - p.requires_grad = False - if bias == "none": - return - elif bias == "all": for n, p in self.model.named_parameters(): - if "bias" in n: - p.requires_grad = True - elif bias == "lora_only": - for m in self.model.modules(): - if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: - m.bias.requires_grad = True - else: - raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + if "lora_" not in n: + p.requires_grad = False + if bias == "none": + return + elif bias == "all": + for n, p in self.model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in self.model.modules(): + if isinstance(m, 
LoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") @staticmethod def _create_new_module(lora_config, adapter_name, target, **kwargs): @@ -348,27 +359,27 @@ def _set_adapter_layers(self, enabled=True): def enable_adapter_layers(self): self._set_adapter_layers(enabled=True) - def _get_active_adapter(self) -> str: - active_adapter = None + def _get_active_adapters(self) -> List[str]: + active_adapters = None for module in self.model.modules(): if isinstance(module, LoraLayer): - active_adapter = module.active_adapter + active_adapters = module.active_adapters - if active_adapter is None: + if active_adapters is None: raise ValueError( "Something went wrong, no active adapter could be found, please report the issue on GitHub" ) - return active_adapter + return active_adapters def disable_adapter_layers(self): - active_adapter = self._get_active_adapter() - val = self.peft_config[active_adapter].bias - if val != "none": - msg = ( - f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " - "output as the the base model would without adaption." - ) - warnings.warn(msg) + for active_adapter in self._get_active_adapters(): + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) self._set_adapter_layers(enabled=False) def set_adapter(self, adapter_name): @@ -641,8 +652,10 @@ def delete_adapter(self, adapter_name): ]: if adapter_name in getattr(target, attr): getattr(target, attr).pop(adapter_name) - if target.active_adapter == adapter_name: - resetting_active_adapter = list(self.peft_config.keys())[0] + if adapter_name in target.active_adapters: + resetting_active_adapter = ( + list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default" + ) warnings.warn( f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. " ) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 2a63f973eb..edabb7b6c6 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -133,7 +133,7 @@ def _create_and_replace( target: nn.Module, target_name: str, parent: nn.Module, - **optionnal_kwargs: Any, + **optional_kwargs: Any, ) -> None: r""" Inplace replacement of the target module with the adapter layer. This method needs to be overriden by all the @@ -152,7 +152,7 @@ def _create_and_replace( The target module's name. parent (`nn.Module`): The parent module. - **optionnal_kwargs (`dict`): + **optional_kwargs (`dict`): The optional keyword arguments to pass to deal with particular cases (e.g. 8bit, 4bit quantization) """ ... 
@@ -211,12 +211,12 @@ def inject_adapter(self, model: nn.Module, adapter_name: str): is_target_modules_in_base_model = True parent, target, target_name = _get_submodules(model, key) - optionnal_kwargs = { + optional_kwargs = { "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False), "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False), "current_key": key, } - self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optionnal_kwargs) + self._create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs) if not is_target_modules_in_base_model: raise ValueError( @@ -255,7 +255,7 @@ class BaseTunerLayer(ABC): Args: is_plugable (`bool`, *optional*): Whether the adapter layer can be plugged to any pytorch module - active_adapter (`str`, *optional*): + active_adapters (Union[List[`str`], `str`], *optional*): The name of the active adapter. """ active_adapter = None @@ -265,3 +265,10 @@ def merge(self) -> None: def unmerge(self) -> None: raise NotImplementedError + + @property + def active_adapters(self): + if isinstance(self.active_adapter, str): + return [self.active_adapter] + # is already a list of str + return self.active_adapter diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index cafae294bb..61062a9cdb 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -366,6 +366,38 @@ def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, co self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) +class TestMultiRankAdapter(unittest.TestCase): + """Tests related to multirank LoRA adapters""" + + def test_multirank(self): + config_1 = LoraConfig( + r=8, + lora_alpha=8, + init_lora_weights=False, + target_modules=["lin0", "lin1"], + ) + config_2 = LoraConfig( + r=8, + lora_alpha=8, + init_lora_weights=False, + target_modules=["lin0", "lin1"], + rank_pattern={"lin0": 4}, + alpha_pattern={"lin0": 4}, + ) + + # Add first adapter + model = get_peft_model(MLP(), config_1, adapter_name="first") + + # Add second adapter + model.add_adapter("second", config_2) + + # Extract current and expected ranks + rank_current = model.lin0.lora_A["second"].weight.shape[0] + rank_expected = config_2.rank_pattern["lin0"] + + self.assertTrue(rank_current == rank_expected, f"Rank {rank_current} is not equal to expected {rank_expected}") + + class TestRepr(unittest.TestCase): """Tests related to the repr of adapted models"""
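The `rank_pattern` / `alpha_pattern` options and the per-module rank resolution in `_create_and_replace` can be exercised end to end. Below is a minimal sketch modelled on the new `TestMultiRankAdapter` test; the `MLP` module here is a stand-in for the helper defined in `tests/test_custom_models.py` (not shown in this diff), and `"default"` is the adapter name that `get_peft_model` assigns when none is passed.

```python
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    # Stand-in for the MLP test model with two linear layers named lin0 and lin1
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


config = LoraConfig(
    r=8,  # default rank for all targeted modules
    lora_alpha=8,  # default alpha
    target_modules=["lin0", "lin1"],
    rank_pattern={"lin0": 4},  # lin0 gets rank 4 instead of 8
    alpha_pattern={"lin0": 16},  # lin0 gets alpha 16 instead of 8
    init_lora_weights=False,  # non-zero init, as in the test
)

model = get_peft_model(MLP(), config)

# lin0 uses the overridden rank, lin1 falls back to the default `r`
assert model.lin0.lora_A["default"].weight.shape[0] == 4
assert model.lin1.lora_A["default"].weight.shape[0] == 8

# scaling is alpha / r per adapter, so the overridden alpha is visible here too
assert model.lin0.scaling["default"] == 16 / 4
```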