From ba24f2a5ce7c7b5035f77f64e36ad495b113931e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 14 Sep 2023 11:32:10 +0000 Subject: [PATCH 01/52] more fixes --- src/diffusers/loaders.py | 3 + src/diffusers/utils/__init__.py | 5 + src/diffusers/utils/peft_utils.py | 147 ++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 src/diffusers/utils/peft_utils.py diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 45c866c1aa16..105c8b8fb4d5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -37,6 +37,9 @@ is_omegaconf_available, is_transformers_available, logging, + convert_old_state_dict_to_peft, + convert_diffusers_state_dict_to_peft, + convert_unet_state_dict_to_peft, ) from .utils.import_utils import BACKENDS_MAPPING diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7390a2f69d23..ee2ff7f4ba37 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -84,6 +84,11 @@ from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil +from .torch_utils import is_compiled_module, randn_tensor +from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft, convert_unet_state_dict_to_peft + +from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video + logger = get_logger(__name__) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py new file mode 100644 index 000000000000..995e4b9357ef --- /dev/null +++ b/src/diffusers/utils/peft_utils.py @@ -0,0 +1,147 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +PEFT utilities: Utilities related to peft library +""" + +def convert_old_state_dict_to_peft(attention_modules, state_dict): + # Convert from the old naming convention to the new naming convention. + # + # Previously, the old LoRA layers were stored on the state dict at the + # same level as the attention block i.e. + # `text_model.encoder.layers.11.self_attn.to_out_lora.lora_A.weight`. + # + # This is no actual module at that point, they were monkey patched on to the + # existing module. We want to be able to load them via their actual state dict. + # They're in `PatchedLoraProjection.lora_linear_layer` now. 
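+    # Illustrative mapping performed below (the module path is a made-up example
+    # of an `attention_modules` entry, not a key taken from a real checkpoint):
+    #   "text_model.encoder.layers.0.self_attn.to_q_lora.up.weight"
+    #       -> "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight"
+    #   "text_model.encoder.layers.0.self_attn.to_q_lora.down.weight"
+    #       -> "text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight"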
+ converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_q_lora.up.weight") + converted_state_dict[ + f"{name}.k_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_k_lora.up.weight") + converted_state_dict[ + f"{name}.v_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_v_lora.up.weight") + converted_state_dict[ + f"{name}.out_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.to_out_lora.up.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_q_lora.down.weight") + converted_state_dict[ + f"{name}.k_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_k_lora.down.weight") + converted_state_dict[ + f"{name}.v_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_v_lora.down.weight") + converted_state_dict[ + f"{name}.out_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.to_out_lora.down.weight") + + return converted_state_dict + + +def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): + # Convert from the new naming convention to the diffusers naming convention. + converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.k_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.v_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_B.{adapter_name}.weight") + converted_state_dict[ + f"{name}.out_proj.lora_linear_layer.up.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_B.{adapter_name}.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.k_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.v_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_A.{adapter_name}.weight") + converted_state_dict[ + f"{name}.out_proj.lora_linear_layer.down.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_A.{adapter_name}.weight") + + return converted_state_dict + + +def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): + # Convert from the diffusers naming convention to the new naming convention. 
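+    # e.g. (hypothetical attention module path, for illustration only):
+    #   "text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight"
+    #       -> "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight"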
+ converted_state_dict = {} + + for name, _ in attention_modules: + converted_state_dict[ + f"{name}.q_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.k_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.v_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.up.weight") + converted_state_dict[ + f"{name}.out_proj.lora_B.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.up.weight") + + converted_state_dict[ + f"{name}.q_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.k_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.v_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.down.weight") + converted_state_dict[ + f"{name}.out_proj.lora_A.weight" + ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.down.weight") + + return converted_state_dict + + +def convert_unet_state_dict_to_peft(state_dict): + converted_state_dict = {} + + patterns = { + ".to_out_lora": ".to_o", + ".down": ".lora_A", + ".up": ".lora_B", + ".to_q_lora": ".to_q", + ".to_k_lora": ".to_k", + ".to_v_lora": ".to_v", + } + + for k, v in state_dict.items(): + if any(pattern in k for pattern in patterns.keys()): + for old, new in patterns.items(): + k = k.replace(old, new) + + converted_state_dict[k] = v + + return converted_state_dict \ No newline at end of file From c17634c39e5e39017042a36469a57ad14e2fd06c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 13:03:17 +0000 Subject: [PATCH 02/52] up --- src/diffusers/loaders.py | 94 +++++++++-------------------- src/diffusers/utils/import_utils.py | 11 ++++ 2 files changed, 40 insertions(+), 65 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 105c8b8fb4d5..b0be3cd42f18 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -36,6 +36,7 @@ is_accelerate_version, is_omegaconf_available, is_transformers_available, + is_peft_available, logging, convert_old_state_dict_to_peft, convert_diffusers_state_dict_to_peft, @@ -51,6 +52,9 @@ from accelerate import init_empty_weights from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module +if is_peft_available(): + from peft import LoraConfig + logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" @@ -1417,7 +1421,6 @@ def load_lora_into_text_encoder( argument to `True` will raise an error. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1436,55 +1439,33 @@ def load_lora_into_text_encoder( logger.info(f"Loading {prefix}.") rank = {} + # Old diffusers to PEFT if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. 
We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. - for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_old_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) + # New diffusers format to PEFT + elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()): + attention_modules = text_encoder_attn_modules(text_encoder) + text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft( + attention_modules, text_encoder_lora_state_dict + ) for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" + rank_key = f"{name}.out_proj.lora_B.weight" rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) if patch_mlp: for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight" - rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight" + rank_key_fc1 = f"{name}.fc1.lora_B.weight" + rank_key_fc2 = f"{name}.fc2.lora_B.weight" rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) + # for diffusers format you always get the same rank everywhere + # is it possible to load with PEFT if network_alphas is not None: alpha_keys = [ k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix @@ -1493,37 +1474,20 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - cls._modify_text_encoder( - text_encoder, - lora_scale, - network_alphas, - rank=rank, - patch_mlp=patch_mlp, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + lora_rank = list(rank.values())[0] + alpha = lora_scale * lora_rank - # set correct dtype & device - text_encoder_lora_state_dict = { - k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) - for k, v in text_encoder_lora_state_dict.items() - } - if low_cpu_mem_usage: - device = next(iter(text_encoder_lora_state_dict.values())).device - dtype = next(iter(text_encoder_lora_state_dict.values())).dtype - unexpected_keys = load_model_dict_into_meta( - text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype - ) - else: - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) - unexpected_keys = load_state_dict_results.unexpected_keys + target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + if patch_mlp: + target_modules += ["fc1", "fc2"] - if len(unexpected_keys) != 0: - raise ValueError( - f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" - ) + lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + + text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 587949ab0c52..883a5b0adb2a 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -267,6 +267,13 @@ _invisible_watermark_available = False +_peft_available = importlib.util.find_spec("peft") is not None +try: + _accelerate_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported accelerate version {_accelerate_version}") +except importlib_metadata.PackageNotFoundError: + _peft_available = False + def is_torch_available(): return _torch_available @@ -351,6 +358,10 @@ def is_invisible_watermark_available(): return _invisible_watermark_available +def is_peft_available(): + return _peft_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. 
Checkout the instructions on the From 2a6e5358a0bd0018273b33a9377688e23ffd6432 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 13:11:30 +0000 Subject: [PATCH 03/52] up --- src/diffusers/dependency_versions_table.py | 1 + src/diffusers/loaders.py | 18 ++- src/diffusers/utils/__init__.py | 12 +- src/diffusers/utils/import_utils.py | 1 + src/diffusers/utils/peft_utils.py | 147 +++++++++------------ 5 files changed, 89 insertions(+), 90 deletions(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4ed..42adc6444f53 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -41,4 +41,5 @@ "torchvision": "torchvision", "transformers": "transformers>=4.25.1", "urllib3": "urllib3<=2.0.0", + "peft": "peft>=0.5.0", } diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b0be3cd42f18..e215103d6085 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import os import re import warnings @@ -24,6 +25,7 @@ import safetensors import torch from huggingface_hub import hf_hub_download, model_info +from packaging import version from torch import nn from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta @@ -31,16 +33,15 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, deprecate, is_accelerate_available, is_accelerate_version, is_omegaconf_available, - is_transformers_available, is_peft_available, + is_transformers_available, logging, - convert_old_state_dict_to_peft, - convert_diffusers_state_dict_to_peft, - convert_unet_state_dict_to_peft, ) from .utils.import_utils import BACKENDS_MAPPING @@ -1427,6 +1428,14 @@ def load_lora_into_text_encoder( keys = list(state_dict.keys()) prefix = cls.text_encoder_name if prefix is None else prefix + is_peft_lora_compatible = version.parse(importlib.metadata.version("transformers")) > version.parse("4.33.1") + + if not is_peft_lora_compatible: + raise ValueError( + "You are using an older version of transformers. Please upgrade to transformers>=4.33.1 to load LoRA weights into" + " text encoder." + ) + # Safe prefix to check with. if any(cls.text_encoder_name in key for key in keys): # Load the layers corresponding to text encoder and make necessary adjustments. @@ -1487,7 +1496,6 @@ def load_lora_into_text_encoder( text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. 
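For reference, a minimal sketch of the text-encoder loading path these patches converge on (assumptions: peft>=0.5.0 and transformers>4.33.1 as gated above; `text_encoder` stands for any transformers CLIP text model and `peft_state_dict` for a state dict already converted with the patch-01 helpers -- both are placeholders, not names introduced by this PR):

    from peft import LoraConfig

    rank = 4                            # the patch infers this from an out_proj lora_B weight shape
    lora_scale = 1.0
    lora_config = LoraConfig(
        r=rank,
        lora_alpha=lora_scale * rank,   # the patch sets alpha = lora_scale * rank
        target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
    )
    # transformers' PEFT integration injects the adapter layers and loads the converted weights
    text_encoder.load_adapter(adapter_state_dict=peft_state_dict, peft_config=lora_config)
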
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index ee2ff7f4ba37..31de66ab315b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -77,17 +77,21 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_peft_available, requires_backends, ) from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput +from .peft_utils import ( + convert_diffusers_state_dict_to_peft, + convert_old_state_dict_to_peft, + convert_peft_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, +) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil - -from .torch_utils import is_compiled_module, randn_tensor -from .peft_utils import convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_diffusers_state_dict_to_peft, convert_unet_state_dict_to_peft - from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video +from .torch_utils import is_compiled_module, randn_tensor logger = get_logger(__name__) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 883a5b0adb2a..0ffa2727e54d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -274,6 +274,7 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False + def is_torch_available(): return _torch_available diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 995e4b9357ef..fbeac9bad642 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -15,6 +15,7 @@ PEFT utilities: Utilities related to peft library """ + def convert_old_state_dict_to_peft(attention_modules, state_dict): # Convert from the old naming convention to the new naming convention. # @@ -26,102 +27,86 @@ def convert_old_state_dict_to_peft(attention_modules, state_dict): # existing module. We want to be able to load them via their actual state dict. # They're in `PatchedLoraProjection.lora_linear_layer` now. 
converted_state_dict = {} - + for name, _ in attention_modules: - converted_state_dict[ - f"{name}.q_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_q_lora.up.weight") - converted_state_dict[ - f"{name}.k_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_k_lora.up.weight") - converted_state_dict[ - f"{name}.v_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_v_lora.up.weight") - converted_state_dict[ - f"{name}.out_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.to_out_lora.up.weight") - - converted_state_dict[ - f"{name}.q_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_q_lora.down.weight") - converted_state_dict[ - f"{name}.k_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_k_lora.down.weight") - converted_state_dict[ - f"{name}.v_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_v_lora.down.weight") - converted_state_dict[ - f"{name}.out_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.to_out_lora.down.weight") - + converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_q_lora.up.weight") + converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_k_lora.up.weight") + converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_v_lora.up.weight") + converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_out_lora.up.weight") + + converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_q_lora.down.weight") + converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_k_lora.down.weight") + converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_v_lora.down.weight") + converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_out_lora.down.weight") + return converted_state_dict def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): # Convert from the new naming convention to the diffusers naming convention. converted_state_dict = {} - + + for name, _ in attention_modules: + converted_state_dict[f"{name}.q_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.q_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.k_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.v_proj.lora_B.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.up.weight"] = state_dict.pop( + f"{name}.out_proj.lora_B.{adapter_name}.weight" + ) + + converted_state_dict[f"{name}.q_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.q_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.k_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.k_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.v_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.v_proj.lora_A.{adapter_name}.weight" + ) + converted_state_dict[f"{name}.out_proj.lora_linear_layer.down.weight"] = state_dict.pop( + f"{name}.out_proj.lora_A.{adapter_name}.weight" + ) + + return converted_state_dict + + +def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): + # Convert from the diffusers naming convention to the new naming convention. 
+ converted_state_dict = {} + for name, _ in attention_modules: - converted_state_dict[ + converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop( f"{name}.q_proj.lora_linear_layer.up.weight" - ] = state_dict.pop(f"{name}.q_proj.lora_B.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop( f"{name}.k_proj.lora_linear_layer.up.weight" - ] = state_dict.pop(f"{name}.k_proj.lora_B.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop( f"{name}.v_proj.lora_linear_layer.up.weight" - ] = state_dict.pop(f"{name}.v_proj.lora_B.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop( f"{name}.out_proj.lora_linear_layer.up.weight" - ] = state_dict.pop(f"{name}.out_proj.lora_B.{adapter_name}.weight") + ) - converted_state_dict[ + converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop( f"{name}.q_proj.lora_linear_layer.down.weight" - ] = state_dict.pop(f"{name}.q_proj.lora_A.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop( f"{name}.k_proj.lora_linear_layer.down.weight" - ] = state_dict.pop(f"{name}.k_proj.lora_A.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop( f"{name}.v_proj.lora_linear_layer.down.weight" - ] = state_dict.pop(f"{name}.v_proj.lora_A.{adapter_name}.weight") - converted_state_dict[ + ) + converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop( f"{name}.out_proj.lora_linear_layer.down.weight" - ] = state_dict.pop(f"{name}.out_proj.lora_A.{adapter_name}.weight") - - return converted_state_dict - + ) -def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): - # Convert from the diffusers naming convention to the new naming convention. 
- converted_state_dict = {} - - for name, _ in attention_modules: - converted_state_dict[ - f"{name}.q_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.up.weight") - converted_state_dict[ - f"{name}.k_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.up.weight") - converted_state_dict[ - f"{name}.v_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.up.weight") - converted_state_dict[ - f"{name}.out_proj.lora_B.weight" - ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.up.weight") - - converted_state_dict[ - f"{name}.q_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.q_proj.lora_linear_layer.down.weight") - converted_state_dict[ - f"{name}.k_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.k_proj.lora_linear_layer.down.weight") - converted_state_dict[ - f"{name}.v_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.v_proj.lora_linear_layer.down.weight") - converted_state_dict[ - f"{name}.out_proj.lora_A.weight" - ] = state_dict.pop(f"{name}.out_proj.lora_linear_layer.down.weight") - return converted_state_dict @@ -141,7 +126,7 @@ def convert_unet_state_dict_to_peft(state_dict): if any(pattern in k for pattern in patterns.keys()): for old, new in patterns.items(): k = k.replace(old, new) - + converted_state_dict[k] = v - - return converted_state_dict \ No newline at end of file + + return converted_state_dict From 01f6d1d88cfb5452a7eb15f9c7673abe88b46a96 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 13:13:59 +0000 Subject: [PATCH 04/52] style --- src/diffusers/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 31de66ab315b..061904d3d54d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,7 @@ is_note_seq_available, is_omegaconf_available, is_onnx_available, + is_peft_available, is_scipy_available, is_tensorboard_available, is_torch_available, @@ -77,7 +78,6 @@ is_unidecode_available, is_wandb_available, is_xformers_available, - is_peft_available, requires_backends, ) from .loading_utils import load_image From 5a150b205917c0c32f179cde7a5b779f6caf5ef2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 13:14:12 +0000 Subject: [PATCH 05/52] add in setup --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index ca8928b3223c..83ecff1b98a1 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,7 @@ "torchvision", "transformers>=4.25.1", "urllib3<=2.0.0", + "peft>=0.5.0" ] # this is a lookup table with items like: From 961e77629847d3e312f30d45280c44f54495a696 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 13:16:14 +0000 Subject: [PATCH 06/52] oops --- src/diffusers/utils/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 061904d3d54d..1a0500b1b0bd 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -90,8 +90,6 @@ convert_unet_state_dict_to_peft, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil -from .testing_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video -from .torch_utils import is_compiled_module, randn_tensor logger = get_logger(__name__) From cdbe7391a85f4b763a6f892902f2a18d07cd2699 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 15 Sep 2023 14:48:49 +0000 Subject: [PATCH 07/52] more changes --- 
src/diffusers/loaders.py | 157 ++++++------------------------ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/peft_utils.py | 29 +++++- tests/models/test_lora_layers.py | 5 + 4 files changed, 62 insertions(+), 130 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e215103d6085..c7616a677e9d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -35,9 +35,9 @@ _get_model_file, convert_diffusers_state_dict_to_peft, convert_old_state_dict_to_peft, + recurse_replace_peft_layers, deprecate, is_accelerate_available, - is_accelerate_version, is_omegaconf_available, is_peft_available, is_transformers_available, @@ -616,19 +616,20 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") def fuse_lora(self, lora_scale=1.0): + from peft.tuners.tuners_utils import BaseTunerLayer + self.lora_scale = lora_scale - self.apply(self._fuse_lora_apply) - def _fuse_lora_apply(self, module): - if hasattr(module, "_fuse_lora"): - module._fuse_lora(self.lora_scale) + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + module.merge() def unfuse_lora(self): - self.apply(self._unfuse_lora_apply) + from peft.tuners.tuners_utils import BaseTunerLayer - def _unfuse_lora_apply(self, module): - if hasattr(module, "_unfuse_lora"): - module._unfuse_lora() + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() class TextualInversionLoaderMixin: @@ -1492,7 +1493,7 @@ def load_lora_into_text_encoder( lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) - text_encoder.load_adapter(text_encoder_lora_state_dict, peft_config=lora_config) + text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) @@ -1502,99 +1503,6 @@ def lora_scale(self) -> float: # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - - @classmethod - def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj.lora_linear_layer = None - attn_module.k_proj.lora_linear_layer = None - attn_module.v_proj.lora_linear_layer = None - attn_module.out_proj.lora_linear_layer = None - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1.lora_linear_layer = None - mlp_module.fc2.lora_linear_layer = None - - @classmethod - def _modify_text_encoder( - cls, - text_encoder, - lora_scale=1, - network_alphas=None, - rank: Union[Dict[str, int], int] = 4, - dtype=None, - patch_mlp=False, - low_cpu_mem_usage=False, - ): - r""" - Monkey-patches the forward passes of attention modules of the text encoder. 
- """ - - def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): - linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model - ctx = init_empty_weights if low_cpu_mem_usage else nullcontext - with ctx(): - model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype) - - lora_parameters.extend(model.lora_linear_layer.parameters()) - return model - - # First, remove any monkey-patch that might have been applied before - cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) - - lora_parameters = [] - network_alphas = {} if network_alphas is None else network_alphas - is_network_alphas_populated = len(network_alphas) > 0 - - for name, attn_module in text_encoder_attn_modules(text_encoder): - query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) - key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) - value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) - out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) - - if isinstance(rank, dict): - current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") - else: - current_rank = rank - - attn_module.q_proj = create_patched_linear_lora( - attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters - ) - attn_module.k_proj = create_patched_linear_lora( - attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters - ) - attn_module.v_proj = create_patched_linear_lora( - attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters - ) - attn_module.out_proj = create_patched_linear_lora( - attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters - ) - - if patch_mlp: - for name, mlp_module in text_encoder_mlp_modules(text_encoder): - fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) - fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) - - current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") - current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") - - mlp_module.fc1 = create_patched_linear_lora( - mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters - ) - mlp_module.fc2 = create_patched_linear_lora( - mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters - ) - - if is_network_alphas_populated and len(network_alphas) > 0: - raise ValueError( - f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" - ) - - return lora_parameters @classmethod def save_lora_weights( @@ -1870,7 +1778,10 @@ def unload_lora_weights(self): module.set_lora_layer(None) # Safe to call the following regardless of LoRA. - self._remove_text_encoder_monkey_patch() + recurse_replace_peft_layers(self.text_encoder) + + def _remove_text_encoder_monkey_patch(self): + recurse_replace_peft_layers(self.text_encoder) def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0): r""" @@ -1890,6 +1801,8 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora lora_scale (`float`, defaults to 1.0): Controls how much to influence the outputs with the LoRA parameters. 
""" + from peft.tuners.tuners_utils import BaseTunerLayer + if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 if self.num_fused_loras > 1: @@ -1901,17 +1814,9 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora self.unet.fuse_lora(lora_scale) def fuse_text_encoder_lora(text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._fuse_lora(lora_scale) - attn_module.k_proj._fuse_lora(lora_scale) - attn_module.v_proj._fuse_lora(lora_scale) - attn_module.out_proj._fuse_lora(lora_scale) - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._fuse_lora(lora_scale) - mlp_module.fc2._fuse_lora(lora_scale) + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.merge() if fuse_text_encoder: if hasattr(self, "text_encoder"): @@ -1936,21 +1841,15 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. """ + from peft.tuners.tuners_utils import BaseTunerLayer + if unfuse_unet: self.unet.unfuse_lora() def unfuse_text_encoder_lora(text_encoder): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj._unfuse_lora() - attn_module.k_proj._unfuse_lora() - attn_module.v_proj._unfuse_lora() - attn_module.out_proj._unfuse_lora() - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1._unfuse_lora() - mlp_module.fc2._unfuse_lora() + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() if unfuse_text_encoder: if hasattr(self, "text_encoder"): @@ -2566,7 +2465,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # pipeline. # Remove any existing hooks. 
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + if is_accelerate_available() and version.parse(importlib.metadata.version("accelerate")) >= version.parse("0.17.0"): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") @@ -2680,5 +2579,5 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) + recurse_replace_peft_layers(self.text_encoder) + recurse_replace_peft_layers(self.text_encoder_2) \ No newline at end of file diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 1a0500b1b0bd..f8018c0c3a43 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -88,6 +88,7 @@ convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_unet_state_dict_to_peft, + recurse_replace_peft_layers, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index fbeac9bad642..63bfe5a0da66 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -14,7 +14,34 @@ """ PEFT utilities: Utilities related to peft library """ - +import torch + + +def recurse_replace_peft_layers(model): + r""" + Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. + """ + from peft.tuners.lora import LoraLayer + + for name, module in model.named_children(): + if len(list(module.children())) > 0: + ## compound module, go inside it + recurse_replace_peft_layers(module) + + if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(module.weight.device) + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + setattr(model, name, new_module) + del module + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + # TODO: do it for Conv2d + + return model def convert_old_state_dict_to_peft(attention_modules, state_dict): # Convert from the old naming convention to the new naming convention. diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ef6ade9af5c1..622cfc2afd3e 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -266,6 +266,7 @@ def test_lora_save_load_no_safe_serialization(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + @unittest.skip("this is an old test") def test_text_encoder_lora_monkey_patch(self): pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) @@ -305,6 +306,10 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" + def test_text_encoder_lora_unload(self): + # TODO @younesbelkada ... 
+ pass + def test_text_encoder_lora_remove_monkey_patch(self): pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) From 691368b060daf87aa1784414879fc27e6b900646 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Sep 2023 13:51:40 +0000 Subject: [PATCH 08/52] v1 rzfactor CI --- src/diffusers/loaders.py | 37 ++++--- src/diffusers/models/lora.py | 16 +-- src/diffusers/utils/peft_utils.py | 1 + tests/models/test_lora_layers.py | 163 +++++++++++++++++------------- 4 files changed, 123 insertions(+), 94 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c7616a677e9d..5736635294af 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -616,20 +616,20 @@ def save_function(weights, filename): logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") def fuse_lora(self, lora_scale=1.0): - from peft.tuners.tuners_utils import BaseTunerLayer - self.lora_scale = lora_scale + self.apply(self._fuse_lora_apply) - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - module.merge() + def _fuse_lora_apply(self, module): + if hasattr(module, "_fuse_lora"): + module._fuse_lora(self.lora_scale) def unfuse_lora(self): - from peft.tuners.tuners_utils import BaseTunerLayer + self.apply(self._unfuse_lora_apply) + + def _unfuse_lora_apply(self, module): + if hasattr(module, "_unfuse_lora"): + module._unfuse_lora() - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() class TextualInversionLoaderMixin: @@ -1398,7 +1398,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage @classmethod def load_lora_into_text_encoder( - cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None + cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None, adapter_name="default" ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1421,6 +1421,9 @@ def load_lora_into_text_encoder( tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to load the LoRA layers into, useful in the case of using multiple adapters + with the same model. Default to the default name used in PEFT library - `"default"`. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1778,10 +1781,18 @@ def unload_lora_weights(self): module.set_lora_layer(None) # Safe to call the following regardless of LoRA. 
- recurse_replace_peft_layers(self.text_encoder) + if hasattr(self, "text_encoder"): + recurse_replace_peft_layers(self.text_encoder) + if hasattr(self, "text_encoder_2"): + recurse_replace_peft_layers(self.text_encoder_2) + + # import pdb; pdb.set_trace() def _remove_text_encoder_monkey_patch(self): - recurse_replace_peft_layers(self.text_encoder) + if hasattr(self, "text_encoder"): + recurse_replace_peft_layers(self.text_encoder) + if hasattr(self, "text_encoder_2"): + recurse_replace_peft_layers(self.text_encoder_2) def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0): r""" @@ -1819,6 +1830,8 @@ def fuse_text_encoder_lora(text_encoder): module.merge() if fuse_text_encoder: + # import pdb; pdb.set_trace() + if hasattr(self, "text_encoder"): fuse_text_encoder_lora(self.text_encoder) if hasattr(self, "text_encoder_2"): diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index cc8e3e231e2b..c2bee253bfcf 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -26,17 +26,11 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): - for _, attn_module in text_encoder_attn_modules(text_encoder): - if isinstance(attn_module.q_proj, PatchedLoraProjection): - attn_module.q_proj.lora_scale = lora_scale - attn_module.k_proj.lora_scale = lora_scale - attn_module.v_proj.lora_scale = lora_scale - attn_module.out_proj.lora_scale = lora_scale - - for _, mlp_module in text_encoder_mlp_modules(text_encoder): - if isinstance(mlp_module.fc1, PatchedLoraProjection): - mlp_module.fc1.lora_scale = lora_scale - mlp_module.fc2.lora_scale = lora_scale + from peft.tuners.lora import LoraLayer + + for module in text_encoder.modules(): + if isinstance(module, LoraLayer): + module.scaling[module.active_adapter] = lora_scale class LoRALinearLayer(nn.Module): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 63bfe5a0da66..524096aa21bd 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -17,6 +17,7 @@ import torch + def recurse_replace_peft_layers(model): r""" Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 622cfc2afd3e..976fa8d54d3a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -266,8 +266,10 @@ def test_lora_save_load_no_safe_serialization(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - @unittest.skip("this is an old test") - def test_text_encoder_lora_monkey_patch(self): + + def test_text_encoder_lora(self): + from peft import LoraConfig + pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) @@ -275,42 +277,43 @@ def test_text_encoder_lora_monkey_patch(self): # inference without lora outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 32) + self.assertTrue(outputs_without_lora.shape == (1, 77, 32)) - # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + lora_config = LoraConfig( + r=4, + target_modules=["k_proj", "q_proj", "v_proj"] + ) - set_lora_weights(params, randn_weight=False) + # 0-init the lora weights + pipe.text_encoder.add_adapter(lora_config, adapter_name="default_O_init") # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 32) + self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) - assert torch.allclose( - outputs_without_lora, outputs_with_lora - ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - - # create lora_attn_procs with randn up.weights - create_text_encoder_lora_attn_procs(pipe.text_encoder) + self.assertTrue(torch.allclose(outputs_without_lora, outputs_with_lora), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs") - # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + lora_config = LoraConfig( + r=4, + target_modules=["k_proj", "q_proj", "v_proj"], + init_lora_weights=False + ) - set_lora_weights(params, randn_weight=True) + # LoRA with no init + pipe.text_encoder.add_adapter(lora_config, adapter_name="default_no_init") + # Make it use that adapter + pipe.text_encoder.set_adapter("default_no_init") # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 32) + self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) - assert not torch.allclose( - outputs_without_lora, outputs_with_lora - ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" + self.assertFalse(torch.allclose(outputs_without_lora, outputs_with_lora), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs") - def test_text_encoder_lora_unload(self): - # TODO @younesbelkada ... 
- pass def test_text_encoder_lora_remove_monkey_patch(self): + from peft import LoraConfig + pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) @@ -320,10 +323,15 @@ def test_text_encoder_lora_remove_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - # monkey patch - params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) + lora_config = LoraConfig( + r=4, + target_modules=["k_proj", "q_proj", "v_proj"], + # To randomly init LoRA weights + init_lora_weights=False + ) - set_lora_weights(params, randn_weight=True) + # Inject adapters + pipe.text_encoder.add_adapter(lora_config) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -783,6 +791,8 @@ def test_lora_fusion(self): self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3)) def test_unfuse_lora(self): + from peft import LoraConfig + pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -795,15 +805,20 @@ def test_unfuse_lora(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + + lora_config = LoraConfig( + r=8, + target_modules=["q_proj", "k_proj", "v_proj"], + init_lora_weights=False + ) + + sd_pipe.text_encoder.add_adapter(lora_config) + sd_pipe.text_encoder_2.add_adapter(lora_config) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -825,10 +840,12 @@ def test_unfuse_lora(self): orig_image_slice_two, lora_image_slice ), "Fusion of LoRAs should lead to a different image slice." assert np.allclose( - orig_image_slice, orig_image_slice_two, atol=1e-3 + orig_image_slice, orig_image_slice_two, atol=4e-2 ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters." def test_lora_fusion_is_not_affected_by_unloading(self): + from peft import LoraConfig + pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -840,19 +857,15 @@ def test_lora_fusion_is_not_affected_by_unloading(self): # Emulate training. 
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + lora_config = LoraConfig( + r=8, + target_modules=["q_proj", "k_proj", "v_proj"], + init_lora_weights=False + ) + + sd_pipe.text_encoder.add_adapter(lora_config) + sd_pipe.text_encoder_2.add_adapter(lora_config) sd_pipe.fuse_lora() lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images @@ -863,11 +876,11 @@ def test_lora_fusion_is_not_affected_by_unloading(self): images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1] - assert np.allclose( - lora_image_slice, images_with_unloaded_lora_slice - ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." + self.assertTrue(np.allclose(lora_image_slice, images_with_unloaded_lora_slice), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused.") def test_fuse_lora_with_different_scales(self): + from peft import LoraConfig + pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -879,15 +892,21 @@ def test_fuse_lora_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + lora_config = LoraConfig( + r=8, + target_modules=["q_proj", "k_proj", "v_proj"], + init_lora_weights=False + ) + + sd_pipe.text_encoder.add_adapter(lora_config) + sd_pipe.text_encoder_2.add_adapter(lora_config) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + # text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + # text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -900,17 +919,6 @@ def test_fuse_lora_with_different_scales(self): # Reverse LoRA fusion. 
sd_pipe.unfuse_lora() - with tempfile.TemporaryDirectory() as tmpdirname: - StableDiffusionXLPipeline.save_lora_weights( - save_directory=tmpdirname, - unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], - safe_serialization=True, - ) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) - sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - sd_pipe.fuse_lora(lora_scale=0.5) lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] @@ -920,6 +928,8 @@ def test_fuse_lora_with_different_scales(self): ), "Different LoRA scales should influence the outputs accordingly." def test_with_different_scales(self): + from peft import LoraConfig + pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -931,15 +941,19 @@ def test_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) + lora_config = LoraConfig( + r=8, + target_modules=["q_proj", "k_proj", "v_proj"], + init_lora_weights=False + ) + + sd_pipe.text_encoder.add_adapter(lora_config) + sd_pipe.text_encoder_2.add_adapter(lora_config) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -959,14 +973,16 @@ def test_with_different_scales(self): lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1] assert not np.allclose( - lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 + lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-3 ), "Different LoRA scales should influence the outputs accordingly." assert np.allclose( - original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03 + original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-3 ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA." def test_with_different_scales_fusion_equivalence(self): + from peft import LoraConfig + pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -979,15 +995,20 @@ def test_with_different_scales_fusion_equivalence(self): # Emulate training. 
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) - set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) + + lora_config = LoraConfig( + r=8, + target_modules=["q_proj", "k_proj", "v_proj"], + init_lora_weights=False + ) + + sd_pipe.text_encoder.add_adapter(lora_config) + sd_pipe.text_encoder_2.add_adapter(lora_config) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) From 79188516402d61570efef3332fe7c6857cc4c18e Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:00:24 +0200 Subject: [PATCH 09/52] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/loaders.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5736635294af..13fa45982178 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1786,8 +1786,6 @@ def unload_lora_weights(self): if hasattr(self, "text_encoder_2"): recurse_replace_peft_layers(self.text_encoder_2) - # import pdb; pdb.set_trace() - def _remove_text_encoder_monkey_patch(self): if hasattr(self, "text_encoder"): recurse_replace_peft_layers(self.text_encoder) @@ -1830,8 +1828,6 @@ def fuse_text_encoder_lora(text_encoder): module.merge() if fuse_text_encoder: - # import pdb; pdb.set_trace() - if hasattr(self, "text_encoder"): fuse_text_encoder_lora(self.text_encoder) if hasattr(self, "text_encoder_2"): From 14db139116098268ae465ffaa0bf3a169ce241ae Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Sep 2023 14:03:26 +0000 Subject: [PATCH 10/52] few todos --- tests/models/test_lora_layers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 976fa8d54d3a..7dfd243f4f47 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -927,6 +927,9 @@ def test_fuse_lora_with_different_scales(self): lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 ), "Different LoRA scales should influence the outputs accordingly." 
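The hunks above replace the old `set_lora_weights` emulation with real PEFT adapters attached directly to both SDXL text encoders. A minimal, self-contained sketch of that pattern, assuming `peft` and a recent `transformers` with the PEFT integration are installed (the tiny config below is illustrative and only serves to make the snippet run without downloading weights):

    import torch
    from peft import LoraConfig
    from transformers import CLIPTextConfig, CLIPTextModel

    config = CLIPTextConfig(
        vocab_size=1000, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=4,
        max_position_embeddings=77, bos_token_id=0, eos_token_id=2,
    )
    text_encoder = CLIPTextModel(config)

    # init_lora_weights=False leaves lora_B randomly initialized, so the adapter
    # perturbs the outputs immediately -- this is what "Emulate training." means here.
    lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False)
    text_encoder.add_adapter(lora_config)

    dummy_ids = torch.randint(0, config.vocab_size, (1, 77))
    with torch.no_grad():
        print(text_encoder(dummy_ids)[0].shape)  # torch.Size([1, 77, 32])

In the tests the same `lora_config` is added to both `sd_pipe.text_encoder` and `sd_pipe.text_encoder_2`, which is why the `text_encoder_*_lora_layers` arguments to `save_lora_weights` are dropped or commented out for now (see the TODOs a few hunks below).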
+ # TODO: @younesbelkada add save / load tests with text encoder + # TODO: @younesbelkada add public method to attach adapters in text encoder + def test_with_different_scales(self): from peft import LoraConfig From d56a14db7b55d5a410259f4ce9ab6f7cb8ea5f3b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Sep 2023 14:07:31 +0000 Subject: [PATCH 11/52] protect torch import --- src/diffusers/utils/peft_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 524096aa21bd..9479056f906f 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -14,7 +14,10 @@ """ PEFT utilities: Utilities related to peft library """ -import torch +from .import_utils import is_torch_available + +if is_torch_available(): + import torch From ec87c196f32b916e3da2bc3877fc1ed616eafd93 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Sep 2023 14:29:08 +0000 Subject: [PATCH 12/52] style --- src/diffusers/loaders.py | 20 +++++++--- src/diffusers/models/lora.py | 1 - src/diffusers/utils/peft_utils.py | 11 ++++-- tests/models/test_lora_layers.py | 63 +++++++++++-------------------- 4 files changed, 44 insertions(+), 51 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a7ecc82de318..047feb8902a9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -35,13 +35,13 @@ _get_model_file, convert_diffusers_state_dict_to_peft, convert_old_state_dict_to_peft, - recurse_replace_peft_layers, deprecate, is_accelerate_available, is_omegaconf_available, is_peft_available, is_transformers_available, logging, + recurse_replace_peft_layers, ) from .utils.import_utils import BACKENDS_MAPPING @@ -641,6 +641,7 @@ def _unfuse_lora_apply(self, module): if hasattr(module, "_unfuse_lora"): module._unfuse_lora() + def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs): cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -715,6 +716,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) return state_dicts + class TextualInversionLoaderMixin: r""" Load textual inversion tokens and embeddings to the tokenizer and text encoder. @@ -1457,7 +1459,14 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage @classmethod def load_lora_into_text_encoder( - cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None, adapter_name="default" + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + low_cpu_mem_usage=None, + adapter_name="default", ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1565,7 +1574,6 @@ def lora_scale(self) -> float: # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 - @classmethod def save_lora_weights( self, @@ -2533,7 +2541,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # pipeline. # Remove any existing hooks. 
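The `protect torch import` patch above guards the module-level `import torch` so `peft_utils` can be imported in environments where PyTorch is absent. A generic version of that guard using only the standard library (diffusers uses its own `is_torch_available()` helper; `_torch_available` and `empty_device_cache` below are illustrative names, not diffusers APIs):

    import importlib.util

    def _torch_available() -> bool:
        # A package counts as available if an import spec can be located for it.
        return importlib.util.find_spec("torch") is not None

    if _torch_available():
        import torch

    def empty_device_cache() -> None:
        # Only touch torch when it was actually importable.
        if _torch_available() and torch.cuda.is_available():
            torch.cuda.empty_cache()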
- if is_accelerate_available() and version.parse(importlib.metadata.version("accelerate")) >= version.parse("0.17.0"): + if is_accelerate_available() and version.parse(importlib.metadata.version("accelerate")) >= version.parse( + "0.17.0" + ): from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module else: raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") @@ -2647,4 +2657,4 @@ def pack_weights(layers, prefix): def _remove_text_encoder_monkey_patch(self): recurse_replace_peft_layers(self.text_encoder) - recurse_replace_peft_layers(self.text_encoder_2) \ No newline at end of file + recurse_replace_peft_layers(self.text_encoder_2) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index c2bee253bfcf..2677ae8dbb6d 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -18,7 +18,6 @@ import torch.nn.functional as F from torch import nn -from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules from ..utils import logging diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 9479056f906f..8629037e0afc 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -16,11 +16,11 @@ """ from .import_utils import is_torch_available + if is_torch_available(): import torch - def recurse_replace_peft_layers(model): r""" Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. @@ -31,9 +31,11 @@ def recurse_replace_peft_layers(model): if len(list(module.children())) > 0: ## compound module, go inside it recurse_replace_peft_layers(module) - + if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): - new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(module.weight.device) + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to( + module.weight.device + ) new_module.weight = module.weight if module.bias is not None: new_module.bias = module.bias @@ -44,9 +46,10 @@ def recurse_replace_peft_layers(model): if torch.cuda.is_available(): torch.cuda.empty_cache() # TODO: do it for Conv2d - + return model + def convert_old_state_dict_to_peft(attention_modules, state_dict): # Convert from the old naming convention to the new naming convention. # diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 7dfd243f4f47..2acc693e77b9 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -266,7 +266,6 @@ def test_lora_save_load_no_safe_serialization(self): # Outputs shouldn't match. 
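The `recurse_replace_peft_layers` reformat above is easier to read as a general recipe: walk `named_children()`, recurse into compound modules first, then swap any adapter-wrapped linear layer back to a plain `nn.Linear` that keeps the existing weight and bias. A self-contained sketch of that recipe; `WrappedLinear` is a hypothetical stand-in so the example runs without peft installed:

    import torch
    from torch import nn

    class WrappedLinear(nn.Linear):
        """Stand-in for a peft-style LoRA layer that subclasses nn.Linear."""

    def recurse_strip_wrappers(model: nn.Module) -> nn.Module:
        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                # Compound module: clean up its children first.
                recurse_strip_wrappers(module)
            if isinstance(module, WrappedLinear):
                # Rebuild a plain nn.Linear carrying over weight, bias and device.
                new_module = nn.Linear(
                    module.in_features, module.out_features, bias=module.bias is not None
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias
                setattr(model, name, new_module)
        return model

    model = nn.Sequential(WrappedLinear(4, 4), nn.Sequential(WrappedLinear(4, 2)))
    recurse_strip_wrappers(model)
    assert not any(isinstance(m, WrappedLinear) for m in model.modules())

The real function additionally requires the module to be a peft `LoraLayer`, frees the CUDA cache afterwards, and leaves Conv2d support as a TODO.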
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - def test_text_encoder_lora(self): from peft import LoraConfig @@ -279,10 +278,7 @@ def test_text_encoder_lora(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] self.assertTrue(outputs_without_lora.shape == (1, 77, 32)) - lora_config = LoraConfig( - r=4, - target_modules=["k_proj", "q_proj", "v_proj"] - ) + lora_config = LoraConfig(r=4, target_modules=["k_proj", "q_proj", "v_proj"]) # 0-init the lora weights pipe.text_encoder.add_adapter(lora_config, adapter_name="default_O_init") @@ -291,14 +287,13 @@ def test_text_encoder_lora(self): outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) - self.assertTrue(torch.allclose(outputs_without_lora, outputs_with_lora), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs") - - lora_config = LoraConfig( - r=4, - target_modules=["k_proj", "q_proj", "v_proj"], - init_lora_weights=False + self.assertTrue( + torch.allclose(outputs_without_lora, outputs_with_lora), + "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs", ) + lora_config = LoraConfig(r=4, target_modules=["k_proj", "q_proj", "v_proj"], init_lora_weights=False) + # LoRA with no init pipe.text_encoder.add_adapter(lora_config, adapter_name="default_no_init") # Make it use that adapter @@ -308,9 +303,11 @@ def test_text_encoder_lora(self): outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) - self.assertFalse(torch.allclose(outputs_without_lora, outputs_with_lora), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs") + self.assertFalse( + torch.allclose(outputs_without_lora, outputs_with_lora), + "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs", + ) - def test_text_encoder_lora_remove_monkey_patch(self): from peft import LoraConfig @@ -327,7 +324,7 @@ def test_text_encoder_lora_remove_monkey_patch(self): r=4, target_modules=["k_proj", "q_proj", "v_proj"], # To randomly init LoRA weights - init_lora_weights=False + init_lora_weights=False, ) # Inject adapters @@ -806,11 +803,7 @@ def test_unfuse_lora(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig( - r=8, - target_modules=["q_proj", "k_proj", "v_proj"], - init_lora_weights=False - ) + lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) sd_pipe.text_encoder.add_adapter(lora_config) sd_pipe.text_encoder_2.add_adapter(lora_config) @@ -844,7 +837,7 @@ def test_unfuse_lora(self): ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters." def test_lora_fusion_is_not_affected_by_unloading(self): - from peft import LoraConfig + from peft import LoraConfig pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) @@ -858,11 +851,7 @@ def test_lora_fusion_is_not_affected_by_unloading(self): # Emulate training. 
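The rewritten `test_text_encoder_lora` above also shows the named-adapter workflow: one adapter with the default zero initialization (a no-op until trained) and one with `init_lora_weights=False`, switched via `set_adapter`. A rough sketch of that flow, again on a tiny illustrative text encoder and assuming peft plus a recent transformers:

    import torch
    from peft import LoraConfig
    from transformers import CLIPTextConfig, CLIPTextModel

    config = CLIPTextConfig(
        vocab_size=1000, hidden_size=32, intermediate_size=64,
        num_hidden_layers=2, num_attention_heads=4,
        max_position_embeddings=77, bos_token_id=0, eos_token_id=2,
    )
    text_encoder = CLIPTextModel(config)
    dummy_ids = torch.randint(0, config.vocab_size, (1, 77))

    with torch.no_grad():
        base_out = text_encoder(dummy_ids)[0]

    targets = ["k_proj", "q_proj", "v_proj"]
    # Default init: lora_B starts at zero, so outputs are unchanged.
    text_encoder.add_adapter(LoraConfig(r=4, target_modules=targets), adapter_name="zero_init")
    # Random init: the adapter changes the outputs right away.
    text_encoder.add_adapter(
        LoraConfig(r=4, target_modules=targets, init_lora_weights=False), adapter_name="random_init"
    )

    text_encoder.set_adapter("zero_init")
    with torch.no_grad():
        assert torch.allclose(base_out, text_encoder(dummy_ids)[0])

    text_encoder.set_adapter("random_init")
    with torch.no_grad():
        assert not torch.allclose(base_out, text_encoder(dummy_ids)[0])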
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig( - r=8, - target_modules=["q_proj", "k_proj", "v_proj"], - init_lora_weights=False - ) + lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) sd_pipe.text_encoder.add_adapter(lora_config) sd_pipe.text_encoder_2.add_adapter(lora_config) @@ -876,7 +865,10 @@ def test_lora_fusion_is_not_affected_by_unloading(self): images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1] - self.assertTrue(np.allclose(lora_image_slice, images_with_unloaded_lora_slice), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused.") + self.assertTrue( + np.allclose(lora_image_slice, images_with_unloaded_lora_slice), + "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused.", + ) def test_fuse_lora_with_different_scales(self): from peft import LoraConfig @@ -892,11 +884,7 @@ def test_fuse_lora_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig( - r=8, - target_modules=["q_proj", "k_proj", "v_proj"], - init_lora_weights=False - ) + lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) sd_pipe.text_encoder.add_adapter(lora_config) sd_pipe.text_encoder_2.add_adapter(lora_config) @@ -944,11 +932,7 @@ def test_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig( - r=8, - target_modules=["q_proj", "k_proj", "v_proj"], - init_lora_weights=False - ) + lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) sd_pipe.text_encoder.add_adapter(lora_config) sd_pipe.text_encoder_2.add_adapter(lora_config) @@ -999,11 +983,7 @@ def test_with_different_scales_fusion_equivalence(self): # Emulate training. 
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) - lora_config = LoraConfig( - r=8, - target_modules=["q_proj", "k_proj", "v_proj"], - init_lora_weights=False - ) + lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) sd_pipe.text_encoder.add_adapter(lora_config) sd_pipe.text_encoder_2.add_adapter(lora_config) @@ -1494,6 +1474,7 @@ def test_sdxl_1_0_last_ben(self): def test_sdxl_1_0_fuse_unfuse_all(self): pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) + text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) unet_sd = copy.deepcopy(pipe.unet.state_dict()) From 40a60286b48ac95fbff5de91b17681cb9565fc53 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 18 Sep 2023 15:19:05 +0000 Subject: [PATCH 13/52] fix fuse text encoder --- src/diffusers/loaders.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 047feb8902a9..7d204ecd5344 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1889,16 +1889,19 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet: self.unet.fuse_lora(lora_scale) - def fuse_text_encoder_lora(text_encoder): + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): for module in text_encoder.modules(): if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + module.merge() if fuse_text_encoder: if hasattr(self, "text_encoder"): - fuse_text_encoder_lora(self.text_encoder) + fuse_text_encoder_lora(self.text_encoder, lora_scale) if hasattr(self, "text_encoder_2"): - fuse_text_encoder_lora(self.text_encoder_2) + fuse_text_encoder_lora(self.text_encoder_2, lora_scale) def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): r""" From c4295c9432cea81a43f106498aab39d129cdcaa8 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Tue, 19 Sep 2023 13:10:28 +0200 Subject: [PATCH 14/52] Update src/diffusers/loaders.py Co-authored-by: Sayak Paul --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7d204ecd5344..032554cbd97e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1491,7 +1491,7 @@ def load_lora_into_text_encoder( argument to `True` will raise an error. adapter_name (`str`, *optional*, defaults to `"default"`): The name of the adapter to load the LoRA layers into, useful in the case of using multiple adapters - with the same model. Default to the default name used in PEFT library - `"default"`. + with the same model. Defaults to the default name used in PEFT library - `"default"`. 
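The `fix fuse text encoder` hunk above is about honoring `lora_scale` during fusion: the PEFT layer is rescaled with `scale_layer(lora_scale)` before `merge()` folds the low-rank update into the base weight, whereas the previous text-encoder path called `merge()` without applying the scale. The arithmetic being folded is W' = W + lora_scale * (alpha / r) * B @ A, which this pure-torch toy check verifies (all names and sizes are illustrative):

    import torch

    torch.manual_seed(0)
    in_f, out_f, r, alpha, lora_scale = 16, 8, 4, 4, 0.5

    W = torch.randn(out_f, in_f)   # frozen base weight
    A = torch.randn(r, in_f)       # lora_A ("down")
    B = torch.randn(out_f, r)      # lora_B ("up")
    x = torch.randn(2, in_f)

    # Unfused: base output plus the scaled low-rank branch.
    unfused = x @ W.t() + lora_scale * (alpha / r) * (x @ A.t() @ B.t())

    # Fused: fold the same delta into the weight once, then do a single matmul.
    W_fused = W + lora_scale * (alpha / r) * (B @ A)
    fused = x @ W_fused.t()

    assert torch.allclose(unfused, fused, atol=1e-4)

The same `alpha / r` factor is why a later hunk in this series constructs the text-encoder `LoraConfig` with `lora_alpha = lora_scale * lora_rank` when loading through PEFT.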
""" low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), From 4162ddfdba97861d5a43161c20e81c79541bffa0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 11:12:43 +0000 Subject: [PATCH 15/52] replace with `recurse_replace_peft_layers` --- src/diffusers/loaders.py | 14 +++++++------- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/peft_utils.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 032554cbd97e..3e3197ac76e7 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -41,7 +41,7 @@ is_peft_available, is_transformers_available, logging, - recurse_replace_peft_layers, + recurse_remove_peft_layers, ) from .utils.import_utils import BACKENDS_MAPPING @@ -1849,15 +1849,15 @@ def unload_lora_weights(self): # Safe to call the following regardless of LoRA. if hasattr(self, "text_encoder"): - recurse_replace_peft_layers(self.text_encoder) + recurse_remove_peft_layers(self.text_encoder) if hasattr(self, "text_encoder_2"): - recurse_replace_peft_layers(self.text_encoder_2) + recurse_remove_peft_layers(self.text_encoder_2) def _remove_text_encoder_monkey_patch(self): if hasattr(self, "text_encoder"): - recurse_replace_peft_layers(self.text_encoder) + recurse_remove_peft_layers(self.text_encoder) if hasattr(self, "text_encoder_2"): - recurse_replace_peft_layers(self.text_encoder_2) + recurse_remove_peft_layers(self.text_encoder_2) def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0): r""" @@ -2659,5 +2659,5 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - recurse_replace_peft_layers(self.text_encoder) - recurse_replace_peft_layers(self.text_encoder_2) + recurse_remove_peft_layers(self.text_encoder) + recurse_remove_peft_layers(self.text_encoder_2) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f8018c0c3a43..6f7d8a413925 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -88,7 +88,7 @@ convert_old_state_dict_to_peft, convert_peft_state_dict_to_diffusers, convert_unet_state_dict_to_peft, - recurse_replace_peft_layers, + recurse_remove_peft_layers, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 8629037e0afc..fcdf7eb8e37c 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -21,7 +21,7 @@ import torch -def recurse_replace_peft_layers(model): +def recurse_remove_peft_layers(model): r""" Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`. 
""" @@ -30,7 +30,7 @@ def recurse_replace_peft_layers(model): for name, module in model.named_children(): if len(list(module.children())) > 0: ## compound module, go inside it - recurse_replace_peft_layers(module) + recurse_remove_peft_layers(module) if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to( From 1d13f4054857b698b6bdf1cd92ca833e83bd9cef Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 12:04:33 +0000 Subject: [PATCH 16/52] keep old modules for BC --- src/diffusers/loaders.py | 263 +++++++++++++++++++++++++++++++++++---- 1 file changed, 242 insertions(+), 21 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 3e3197ac76e7..2900d2ca81d4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1115,6 +1115,18 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di elif is_sequential_cpu_offload: self.enable_sequential_cpu_offload() + @property + def use_peft_backend(self): + """ + A property method that returns `True` if the current version of `peft` and `transformers` are compatible with + PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are available. + + For PEFT is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. + """ + correct_peft_version = is_peft_available() and version.parse(importlib.metadata.version("peft")) > version.parse("0.6.0") + correct_transformers_version = version.parse(importlib.metadata.version("transformers")) > version.parse("4.33.1") + return correct_peft_version and correct_transformers_version + @classmethod def lora_state_dict( cls, @@ -1467,6 +1479,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, low_cpu_mem_usage=None, adapter_name="default", + _pipeline=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1555,25 +1568,164 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - lora_rank = list(rank.values())[0] - alpha = lora_scale * lora_rank + if cls.use_peft_backend: + lora_rank = list(rank.values())[0] + alpha = lora_scale * lora_rank - target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] - if patch_mlp: - target_modules += ["fc1", "fc2"] + target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + if patch_mlp: + target_modules += ["fc1", "fc2"] + + lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + + text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + else: + # raise deprecation warning + warnings.warn( + "You are using an old version of LoRA backend. 
This will be deprecated in the next releases in favor of PEFT" + " make sure to install the latest PEFT and transformers packages in the future.", + FutureWarning, + ) + + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + is_pipeline_offloaded = _pipeline is not None and any( + isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values() + ) + if is_pipeline_offloaded and low_cpu_mem_usage: + low_cpu_mem_usage = True + logger.info( + f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced." + ) - lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + if low_cpu_mem_usage: + device = next(iter(text_encoder_lora_state_dict.values())).device + dtype = next(iter(text_encoder_lora_state_dict.values())).dtype + unexpected_keys = load_model_dict_into_meta( + text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype + ) + else: + load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) + unexpected_keys = load_state_dict_results.unexpected_keys + + if len(unexpected_keys) != 0: + raise ValueError( + f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" + ) - text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + # float: # property function that returns the lora scale which can be set at run time by the pipeline. # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + @classmethod + def _modify_text_encoder( + cls, + text_encoder, + lora_scale=1, + network_alphas=None, + rank: Union[Dict[str, int], int] = 4, + dtype=None, + patch_mlp=False, + low_cpu_mem_usage=False, + ): + r""" + Monkey-patches the forward passes of attention modules of the text encoder. 
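For the legacy (non-PEFT) branch, the monkey patching summarized in the docstring above replaces each attention projection with a `PatchedLoraProjection`: a module that keeps the original `nn.Linear` as `regular_linear_layer` and adds a low-rank `lora_linear_layer` on top, scaled by `lora_scale`. A stripped-down, hypothetical stand-in for that wrapper (not the diffusers class itself):

    import torch
    from torch import nn

    class ToyLoraLinearLayer(nn.Module):
        def __init__(self, in_features, out_features, rank=4):
            super().__init__()
            self.down = nn.Linear(in_features, rank, bias=False)
            self.up = nn.Linear(rank, out_features, bias=False)
            nn.init.normal_(self.down.weight, std=1 / rank)
            nn.init.zeros_(self.up.weight)  # zero init => no-op until trained

        def forward(self, x):
            return self.up(self.down(x))

    class ToyPatchedProjection(nn.Module):
        def __init__(self, regular_linear_layer, lora_scale=1.0, rank=4):
            super().__init__()
            self.regular_linear_layer = regular_linear_layer
            self.lora_scale = lora_scale
            self.lora_linear_layer = ToyLoraLinearLayer(
                regular_linear_layer.in_features, regular_linear_layer.out_features, rank=rank
            )

        def forward(self, x):
            if self.lora_linear_layer is None:
                # "Removing the monkey patch" only needs to null out the LoRA branch.
                return self.regular_linear_layer(x)
            return self.regular_linear_layer(x) + self.lora_scale * self.lora_linear_layer(x)

    proj = ToyPatchedProjection(nn.Linear(32, 32))
    out = proj(torch.randn(1, 32))

Seen this way, the `_remove_text_encoder_monkey_patch_classmethod` hunk further down only has to set `lora_linear_layer = None` on each patched projection to switch the old-style LoRA off.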
+ """ + + def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): + linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model + ctx = init_empty_weights if low_cpu_mem_usage else nullcontext + with ctx(): + model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype) + + lora_parameters.extend(model.lora_linear_layer.parameters()) + return model + + # First, remove any monkey-patch that might have been applied before + cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) + + lora_parameters = [] + network_alphas = {} if network_alphas is None else network_alphas + is_network_alphas_populated = len(network_alphas) > 0 + + for name, attn_module in text_encoder_attn_modules(text_encoder): + query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None) + key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None) + value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None) + out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None) + + if isinstance(rank, dict): + current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight") + else: + current_rank = rank + + attn_module.q_proj = create_patched_linear_lora( + attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters + ) + attn_module.k_proj = create_patched_linear_lora( + attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters + ) + attn_module.v_proj = create_patched_linear_lora( + attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters + ) + attn_module.out_proj = create_patched_linear_lora( + attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters + ) + + if patch_mlp: + for name, mlp_module in text_encoder_mlp_modules(text_encoder): + fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None) + fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None) + + current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight") + current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight") + + mlp_module.fc1 = create_patched_linear_lora( + mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters + ) + mlp_module.fc2 = create_patched_linear_lora( + mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters + ) + + if is_network_alphas_populated and len(network_alphas) > 0: + raise ValueError( + f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}" + ) + + return lora_parameters + @classmethod def save_lora_weights( self, @@ -1854,10 +2006,34 @@ def unload_lora_weights(self): recurse_remove_peft_layers(self.text_encoder_2) def _remove_text_encoder_monkey_patch(self): + if self.use_peft_backend: + remove_method = recurse_remove_peft_layers + else: + warnings.warn( + "You are using an old version of LoRA backend. 
This will be deprecated in the next releases in favor of PEFT" + " make sure to install the latest PEFT and transformers packages in the future.", + FutureWarning, + ) + remove_method = self._remove_text_encoder_monkey_patch_classmethod + if hasattr(self, "text_encoder"): - recurse_remove_peft_layers(self.text_encoder) + remove_method(self.text_encoder) if hasattr(self, "text_encoder_2"): - recurse_remove_peft_layers(self.text_encoder_2) + remove_method(self.text_encoder_2) + + @classmethod + def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_linear_layer = None + attn_module.k_proj.lora_linear_layer = None + attn_module.v_proj.lora_linear_layer = None + attn_module.out_proj.lora_linear_layer = None + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_linear_layer = None + mlp_module.fc2.lora_linear_layer = None def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0): r""" @@ -1889,13 +2065,34 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet: self.unet.fuse_lora(lora_scale) - def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - if lora_scale != 1.0: - module.scale_layer(lora_scale) + if self.use_peft_backend: + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + if lora_scale != 1.0: + module.scale_layer(lora_scale) + + module.merge() + else: + warnings.warn( + "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT" + " make sure to install the latest PEFT and transformers packages in the future.", + FutureWarning, + ) + + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._fuse_lora(lora_scale) + attn_module.k_proj._fuse_lora(lora_scale) + attn_module.v_proj._fuse_lora(lora_scale) + attn_module.out_proj._fuse_lora(lora_scale) + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._fuse_lora(lora_scale) + mlp_module.fc2._fuse_lora(lora_scale) - module.merge() if fuse_text_encoder: if hasattr(self, "text_encoder"): @@ -1925,10 +2122,30 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True if unfuse_unet: self.unet.unfuse_lora() - def unfuse_text_encoder_lora(text_encoder): - for module in text_encoder.modules(): - if isinstance(module, BaseTunerLayer): - module.unmerge() + if self.use_peft_backend: + def unfuse_text_encoder_lora(text_encoder): + for module in text_encoder.modules(): + if isinstance(module, BaseTunerLayer): + module.unmerge() + else: + warnings.warn( + "You are using an old version of LoRA backend. 
This will be deprecated in the next releases in favor of PEFT" + " make sure to install the latest PEFT and transformers packages in the future.", + FutureWarning, + ) + def unfuse_text_encoder_lora(text_encoder): + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj._unfuse_lora() + attn_module.k_proj._unfuse_lora() + attn_module.v_proj._unfuse_lora() + attn_module.out_proj._unfuse_lora() + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1._unfuse_lora() + mlp_module.fc2._unfuse_lora() + if unfuse_text_encoder: if hasattr(self, "text_encoder"): @@ -2659,5 +2876,9 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - recurse_remove_peft_layers(self.text_encoder) - recurse_remove_peft_layers(self.text_encoder_2) + if self.use_peft_backend: + recurse_remove_peft_layers(self.text_encoder) + recurse_remove_peft_layers(self.text_encoder_2) + else: + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) From 78a860d276fd854f8303c087373407a0704d8c0c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 12:10:23 +0000 Subject: [PATCH 17/52] adjustments on `adjust_lora_scale_text_encoder` --- src/diffusers/loaders.py | 2 +- src/diffusers/models/lora.py | 25 +++++++++++++++---- .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 2 +- .../controlnet/pipeline_controlnet.py | 2 +- .../controlnet/pipeline_controlnet_img2img.py | 2 +- .../controlnet/pipeline_controlnet_inpaint.py | 2 +- .../pipeline_controlnet_inpaint_sd_xl.py | 2 +- .../controlnet/pipeline_controlnet_sd_xl.py | 2 +- .../pipeline_controlnet_sd_xl_img2img.py | 2 +- .../pipeline_cycle_diffusion.py | 2 +- .../pipeline_stable_diffusion.py | 2 +- ...line_stable_diffusion_attend_and_excite.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 2 +- .../pipeline_stable_diffusion_diffedit.py | 2 +- .../pipeline_stable_diffusion_gligen.py | 2 +- ...line_stable_diffusion_gligen_text_image.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../pipeline_stable_diffusion_ldm3d.py | 2 +- ...pipeline_stable_diffusion_model_editing.py | 2 +- .../pipeline_stable_diffusion_panorama.py | 2 +- .../pipeline_stable_diffusion_paradigms.py | 2 +- .../pipeline_stable_diffusion_pix2pix_zero.py | 2 +- .../pipeline_stable_diffusion_sag.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 2 +- .../pipeline_stable_unclip.py | 2 +- .../pipeline_stable_unclip_img2img.py | 2 +- .../pipeline_stable_diffusion_xl.py | 2 +- .../pipeline_stable_diffusion_xl_img2img.py | 2 +- .../pipeline_stable_diffusion_xl_inpaint.py | 2 +- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 2 +- .../pipeline_stable_diffusion_adapter.py | 2 +- .../pipeline_stable_diffusion_xl_adapter.py | 2 +- .../pipeline_text_to_video_synth.py | 2 +- .../pipeline_text_to_video_synth_img2img.py | 2 +- 38 files changed, 57 insertions(+), 42 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 2900d2ca81d4..6608c9f51a26 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1116,7 +1116,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: 
Union[str, Di self.enable_sequential_cpu_offload() @property - def use_peft_backend(self): + def use_peft_backend(self) -> bool: """ A property method that returns `True` if the current version of `peft` and `transformers` are compatible with PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are available. diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 2677ae8dbb6d..f6c81559cbf7 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -18,18 +18,33 @@ import torch.nn.functional as F from torch import nn +from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules from ..utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0): - from peft.tuners.lora import LoraLayer +def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): + if use_peft_backend: + from peft.tuners.lora import LoraLayer + + for module in text_encoder.modules(): + if isinstance(module, LoraLayer): + module.scaling[module.active_adapter] = lora_scale + else: + for _, attn_module in text_encoder_attn_modules(text_encoder): + if isinstance(attn_module.q_proj, PatchedLoraProjection): + attn_module.q_proj.lora_scale = lora_scale + attn_module.k_proj.lora_scale = lora_scale + attn_module.v_proj.lora_scale = lora_scale + attn_module.out_proj.lora_scale = lora_scale + + for _, mlp_module in text_encoder_mlp_modules(text_encoder): + if isinstance(mlp_module.fc1, PatchedLoraProjection): + mlp_module.fc1.lora_scale = lora_scale + mlp_module.fc2.lora_scale = lora_scale - for module in text_encoder.modules(): - if isinstance(module, LoraLayer): - module.scaling[module.active_adapter] = lora_scale class LoRALinearLayer(nn.Module): diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 3e57c3a0fdab..1cddfe1a027f 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -297,7 +297,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 00a5cb452a90..3e1993137654 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -295,7 +295,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index c498f606d3d7..7f719dc917e2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -285,7 +285,7 @@ def 
encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 1126fa8b139e..1658b557e25e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -309,7 +309,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 7fdaf80d7bbe..1602afa5f44c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -436,7 +436,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index a89a13f70830..9cc2afec3031 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -311,7 +311,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index cbb78e509b84..5d00baa6ca8b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -284,7 +284,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 6fe3d0c641e5..934ce1ff04b7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -322,7 +322,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + 
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 7ed335ea8f7d..87b6ae37f214 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -302,7 +302,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index cb87bffd8f3e..ab263db6ddac 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -295,7 +295,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index 8bfcc7decb34..701fe131dda1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -326,7 +326,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index e8fcd39202dc..ebb601f4420c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -207,7 +207,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py index 40d53d384bfd..975e6789a3e2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -475,7 +475,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + 
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py index d3d15cf1e543..06d295deb400 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py @@ -272,7 +272,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py index 65cb9d284552..fbf6e85c9670 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py @@ -305,7 +305,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 6d01f2285af2..03abd11d9d80 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -296,7 +296,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 72dfa4289959..545321473069 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -369,7 +369,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index c25bf1c7be33..e929cad5d3aa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -291,7 +291,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, 
lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index dc00f9fd4378..f1f5912094b4 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -205,7 +205,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index 217b2bb43032..4f929bd808e4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -266,7 +266,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index f8d6296ea943..8a11b237dc62 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -238,7 +238,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index aec93b56b6f8..963749a928ff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -215,7 +215,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py index 5152209f21aa..271e4580fcc4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py @@ -250,7 +250,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, 
lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 47a16d6663cf..4bbca9fdfb16 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -440,7 +440,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index 84094f69b78c..942b187df062 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -238,7 +238,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 31982891cd01..6bf8713f4562 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -234,7 +234,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 917b9fef0ead..3dfe879ca4e2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -340,7 +340,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index be837564fddc..ae528e9d02e1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -290,7 +290,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if 
prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 25e95e0b3454..797fcfdf3fbf 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -260,7 +260,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 86f337fa2d51..b3f2eb9bc1b6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -267,7 +267,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 9363d7b2a3d3..af1e4c3a51e2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -416,7 +416,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index fdf02c536f78..187bd938edb8 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -272,7 +272,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 4120d5f9dfe6..9b6136b44f84 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -291,7 +291,7 @@ def 
encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index ca876440166e..55d7a4f45039 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -284,7 +284,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) if prompt is not None and isinstance(prompt, str): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index a8395a5e86c8..0f445e78f14e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -222,7 +222,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 1e1b30e18fcb..25daf461d276 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -284,7 +284,7 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend) if prompt is not None and isinstance(prompt, str): batch_size = 1 From 6f1adcd65d58f323311d7d86d0d136aef52fe44d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 12:52:56 +0000 Subject: [PATCH 18/52] nit --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 222f0dc246c0..70f6fd27e10f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1142,7 +1142,7 @@ def use_peft_backend(self) -> bool: """ correct_peft_version = is_peft_available() and version.parse( importlib.metadata.version("peft") - ) > version.parse("0.6.0") + ) > version.parse("0.5.0") correct_transformers_version = version.parse(importlib.metadata.version("transformers")) > version.parse( "4.33.1" ) From f8909061ee92d86c0846c21763b8bd766740f7b7 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 13:00:25 +0000 Subject: [PATCH 19/52] move tests --- ...ers.py => test_lora_layers_old_backend.py} | 144 ++++++++--------- tests/lora/test_lora_layers_peft.py | 147 ++++++++++++++++++ 2 files changed, 214 
insertions(+), 77 deletions(-) rename tests/lora/{test_lora_layers.py => test_lora_layers_old_backend.py} (95%) create mode 100644 tests/lora/test_lora_layers_peft.py diff --git a/tests/lora/test_lora_layers.py b/tests/lora/test_lora_layers_old_backend.py similarity index 95% rename from tests/lora/test_lora_layers.py rename to tests/lora/test_lora_layers_old_backend.py index 0b8db9c1a9e0..e54caeb9f0c2 100644 --- a/tests/lora/test_lora_layers.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -420,9 +420,7 @@ def test_lora_save_load_no_safe_serialization(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) - def test_text_encoder_lora(self): - from peft import LoraConfig - + def test_text_encoder_lora_monkey_patch(self): pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) @@ -430,41 +428,38 @@ def test_text_encoder_lora(self): # inference without lora outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - self.assertTrue(outputs_without_lora.shape == (1, 77, 32)) + assert outputs_without_lora.shape == (1, 77, 32) - lora_config = LoraConfig(r=4, target_modules=["k_proj", "q_proj", "v_proj"]) + # monkey patch + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # 0-init the lora weights - pipe.text_encoder.add_adapter(lora_config, adapter_name="default_O_init") + set_lora_weights(params, randn_weight=False) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) + assert outputs_with_lora.shape == (1, 77, 32) - self.assertTrue( - torch.allclose(outputs_without_lora, outputs_with_lora), - "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs", - ) + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # create lora_attn_procs with randn up.weights + create_text_encoder_lora_attn_procs(pipe.text_encoder) - lora_config = LoraConfig(r=4, target_modules=["k_proj", "q_proj", "v_proj"], init_lora_weights=False) + # monkey patch + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # LoRA with no init - pipe.text_encoder.add_adapter(lora_config, adapter_name="default_no_init") - # Make it use that adapter - pipe.text_encoder.set_adapter("default_no_init") + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - self.assertTrue(outputs_with_lora.shape == (1, 77, 32)) + assert outputs_with_lora.shape == (1, 77, 32) - self.assertFalse( - torch.allclose(outputs_without_lora, outputs_with_lora), - "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs", - ) + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" def test_text_encoder_lora_remove_monkey_patch(self): - from peft import LoraConfig - pipeline_components, _ = self.get_dummy_components() pipe = StableDiffusionPipeline(**pipeline_components) @@ -474,15 +469,10 @@ def test_text_encoder_lora_remove_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 32) - lora_config = LoraConfig( - r=4, - target_modules=["k_proj", "q_proj", 
"v_proj"], - # To randomly init LoRA weights - init_lora_weights=False, - ) + # monkey patch + params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale) - # Inject adapters - pipe.text_encoder.add_adapter(lora_config) + set_lora_weights(params, randn_weight=True) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -1062,8 +1052,6 @@ def test_lora_fusion(self): self.assertFalse(np.allclose(orig_image_slice, lora_image_slice, atol=1e-3)) def test_unfuse_lora(self): - from peft import LoraConfig - pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -1076,16 +1064,15 @@ def test_unfuse_lora(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - - lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) - - sd_pipe.text_encoder.add_adapter(lora_config) - sd_pipe.text_encoder_2.add_adapter(lora_config) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1107,12 +1094,10 @@ def test_unfuse_lora(self): orig_image_slice_two, lora_image_slice ), "Fusion of LoRAs should lead to a different image slice." assert np.allclose( - orig_image_slice, orig_image_slice_two, atol=4e-2 + orig_image_slice, orig_image_slice_two, atol=1e-3 ), "Reversing LoRA fusion should lead to results similar to what was obtained with the pipeline without any LoRA parameters." def test_lora_fusion_is_not_affected_by_unloading(self): - from peft import LoraConfig - pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -1124,11 +1109,19 @@ def test_lora_fusion_is_not_affected_by_unloading(self): # Emulate training. 
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) - - sd_pipe.text_encoder.add_adapter(lora_config) - sd_pipe.text_encoder_2.add_adapter(lora_config) + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) sd_pipe.fuse_lora() lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images @@ -1139,14 +1132,11 @@ def test_lora_fusion_is_not_affected_by_unloading(self): images_with_unloaded_lora = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images images_with_unloaded_lora_slice = images_with_unloaded_lora[0, -3:, -3:, -1] - self.assertTrue( - np.allclose(lora_image_slice, images_with_unloaded_lora_slice), - "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused.", - ) + assert np.allclose( + lora_image_slice, images_with_unloaded_lora_slice + ), "`unload_lora_weights()` should have not effect on the semantics of the results as the LoRA parameters were fused." def test_fuse_lora_with_different_scales(self): - from peft import LoraConfig - pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -1158,17 +1148,15 @@ def test_fuse_lora_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) - - sd_pipe.text_encoder.add_adapter(lora_config) - sd_pipe.text_encoder_2.add_adapter(lora_config) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], - # text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], - # text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1181,6 +1169,17 @@ def test_fuse_lora_with_different_scales(self): # Reverse LoRA fusion. 
sd_pipe.unfuse_lora() + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + sd_pipe.fuse_lora(lora_scale=0.5) lora_images_scale_0_5 = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] @@ -1189,12 +1188,7 @@ def test_fuse_lora_with_different_scales(self): lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 ), "Different LoRA scales should influence the outputs accordingly." - # TODO: @younesbelkada add save / load tests with text encoder - # TODO: @younesbelkada add public method to attach adapters in text encoder - def test_with_different_scales(self): - from peft import LoraConfig - pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -1206,15 +1200,15 @@ def test_with_different_scales(self): # Emulate training. set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True) - lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) - - sd_pipe.text_encoder.add_adapter(lora_config) - sd_pipe.text_encoder_2.add_adapter(lora_config) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -1234,16 +1228,14 @@ def test_with_different_scales(self): lora_image_slice_scale_0_0 = lora_images_scale_0_0[0, -3:, -3:, -1] assert not np.allclose( - lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-3 + lora_image_slice_scale_one, lora_image_slice_scale_0_5, atol=1e-03 ), "Different LoRA scales should influence the outputs accordingly." assert np.allclose( - original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-3 + original_imagee_slice, lora_image_slice_scale_0_0, atol=1e-03 ), "LoRA scale of 0.0 shouldn't be different from the results without LoRA." def test_with_different_scales_fusion_equivalence(self): - from peft import LoraConfig - pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) @@ -1256,16 +1248,15 @@ def test_with_different_scales_fusion_equivalence(self): # Emulate training. 
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) - - lora_config = LoraConfig(r=8, target_modules=["q_proj", "k_proj", "v_proj"], init_lora_weights=False) - - sd_pipe.text_encoder.add_adapter(lora_config) - sd_pipe.text_encoder_2.add_adapter(lora_config) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) with tempfile.TemporaryDirectory() as tmpdirname: StableDiffusionXLPipeline.save_lora_weights( save_directory=tmpdirname, unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], safe_serialization=True, ) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) @@ -2237,7 +2228,6 @@ def test_sdxl_1_0_last_ben(self): def test_sdxl_1_0_fuse_unfuse_all(self): pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) - text_encoder_1_sd = copy.deepcopy(pipe.text_encoder.state_dict()) text_encoder_2_sd = copy.deepcopy(pipe.text_encoder_2.state_dict()) unet_sd = copy.deepcopy(pipe.unet.state_dict()) diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py new file mode 100644 index 000000000000..95b1db72d7e9 --- /dev/null +++ b/tests/lora/test_lora_layers_peft.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
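As an illustrative aside (not part of the committed patch), the new `tests/lora/test_lora_layers_peft.py` module below only scaffolds dummy components so far; a minimal sketch of how a PEFT-backend test would exercise the text encoder, mirroring the `add_adapter` calls dropped from the old-backend tests above, assuming `peft>=0.6` and `transformers>=4.33` are installed:

import torch
from peft import LoraConfig
from transformers import CLIPTextConfig, CLIPTextModel

# Same tiny CLIP config as get_dummy_components() further down in this file.
config = CLIPTextConfig(
    bos_token_id=0, eos_token_id=2, hidden_size=32, intermediate_size=37,
    layer_norm_eps=1e-05, num_attention_heads=4, num_hidden_layers=5,
    pad_token_id=1, vocab_size=1000,
)
text_encoder = CLIPTextModel(config)
input_ids = torch.randint(2, 56, size=(1, 77), generator=torch.manual_seed(0))

outputs_without_lora = text_encoder(input_ids)[0]

# PEFT zero-initializes lora_B by default, so attaching the adapter must not change the outputs.
lora_config = LoraConfig(r=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
text_encoder.add_adapter(lora_config)
outputs_with_lora = text_encoder(input_ids)[0]
assert torch.allclose(outputs_without_lora, outputs_with_lora)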
+import unittest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import ( + LoRAAttnProcessor, + LoRAAttnProcessor2_0, +) +from diffusers.utils.testing_utils import floats_tensor + + +def create_unet_lora_layers(unet: nn.Module): + lora_attn_procs = {} + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) + unet_lora_layers = AttnProcsLayers(lora_attn_procs) + return lora_attn_procs, unet_lora_layers + + +class LoraLoaderMixinTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) + + pipeline_components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + lora_components = { + "unet_lora_layers": unet_lora_layers, + "unet_lora_attn_procs": unet_lora_attn_procs, + } + return pipeline_components, lora_components + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + } + if with_generator: + 
pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs From f8e87f6220a3ae2c63119e4f83f23f8cc321cbce Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 13:37:59 +0000 Subject: [PATCH 20/52] add conversion utils --- src/diffusers/loaders.py | 58 ++--------- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/state_dict_utils.py | 131 ++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 51 deletions(-) create mode 100644 src/diffusers/utils/state_dict_utils.py diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 70f6fd27e10f..adba0e909dca 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,8 +33,8 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file, - convert_diffusers_state_dict_to_peft, - convert_old_state_dict_to_peft, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, deprecate, is_accelerate_available, is_omegaconf_available, @@ -1544,19 +1544,11 @@ def load_lora_into_text_encoder( if len(text_encoder_lora_state_dict) > 0: logger.info(f"Loading {prefix}.") rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + if cls.use_peft_backend: - # Old diffusers to PEFT - if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - attention_modules = text_encoder_attn_modules(text_encoder) - text_encoder_lora_state_dict = convert_old_state_dict_to_peft( - attention_modules, text_encoder_lora_state_dict - ) - # New diffusers format to PEFT - elif any("lora_linear_layer" in k for k in text_encoder_lora_state_dict.keys()): - attention_modules = text_encoder_attn_modules(text_encoder) - text_encoder_lora_state_dict = convert_diffusers_state_dict_to_peft( - attention_modules, text_encoder_lora_state_dict - ) + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): rank_key = f"{name}.out_proj.lora_B.weight" @@ -1570,43 +1562,6 @@ def load_lora_into_text_encoder( rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]}) rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]}) else: - if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. 
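As a hedged illustration of what the conversion helpers introduced in this commit do: the key below is the hypothetical old-format example from the comment above, and the helpers only rename keys, leaving the tensors untouched (assumes the `diffusers.utils` exports added in this commit):

import torch
from diffusers.utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft

old_sd = {"text_model.encoder.layers.11.self_attn.to_out_lora.up.weight": torch.zeros(32, 4)}

# old monkey-patched naming -> new diffusers naming
new_sd = convert_state_dict_to_diffusers(old_sd)
# -> "text_model.encoder.layers.11.self_attn.out_proj.lora_linear_layer.up.weight"

# new diffusers naming -> PEFT naming, the extra step taken when the PEFT backend is enabled
peft_sd = convert_state_dict_to_peft(new_sd)
# -> "text_model.encoder.layers.11.self_attn.out_proj.lora_B.weight"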
- for name, _ in text_encoder_attn_modules(text_encoder): - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.up.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight") - - text_encoder_lora_state_dict[ - f"{name}.q_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.k_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.v_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight") - text_encoder_lora_state_dict[ - f"{name}.out_proj.lora_linear_layer.down.weight" - ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight") - for name, _ in text_encoder_attn_modules(text_encoder): rank_key = f"{name}.out_proj.lora_linear_layer.up.weight" rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]}) @@ -1640,6 +1595,7 @@ def load_lora_into_text_encoder( lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + is_model_cpu_offload = False is_sequential_cpu_offload = False else: diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 6f7d8a413925..eab96e54e1f7 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -91,6 +91,7 @@ recurse_remove_peft_layers, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil +from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft logger = get_logger(__name__) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py new file mode 100644 index 000000000000..a9b6de910f52 --- /dev/null +++ b/src/diffusers/utils/state_dict_utils.py @@ -0,0 +1,131 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +State dict utilities: utility methods for converting state dicts easily +""" +import enum + + +class StateDictType(enum.Enum): + """ + The mode to use when converting state dicts. 
+ """ + + DIFFUSERS_OLD = "diffusers_old" + # KOHYA_SS = "kohya_ss" # TODO: implement this + PEFT = "peft" + DIFFUSERS = "diffusers" + + +DIFFUSERS_TO_PEFT = { + ".q_proj.lora_linear_layer.up": ".q_proj.lora_B", + ".q_proj.lora_linear_layer.down": ".q_proj.lora_A", + ".k_proj.lora_linear_layer.up": ".k_proj.lora_B", + ".k_proj.lora_linear_layer.down": ".k_proj.lora_A", + ".v_proj.lora_linear_layer.up": ".v_proj.lora_B", + ".v_proj.lora_linear_layer.down": ".v_proj.lora_A", + ".out_proj.lora_linear_layer.up": ".out_proj.lora_B", + ".out_proj.lora_linear_layer.down": ".out_proj.lora_A", +} + +DIFFUSERS_OLD_TO_PEFT = { + ".to_q_lora.up": ".q_proj.lora_B", + ".to_q_lora.down": ".q_proj.lora_A", + ".to_k_lora.up": ".k_proj.lora_B", + ".to_k_lora.down": ".k_proj.lora_A", + ".to_v_lora.up": ".v_proj.lora_B", + ".to_v_lora.down": ".v_proj.lora_A", + ".to_out_lora.up": ".out_proj.lora_B", + ".to_out_lora.down": ".out_proj.lora_A", +} + +PEFT_TO_DIFFUSERS = { + ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", + ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", + ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", + ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", + ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", + ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", + ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", + ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", +} + +DIFFUSERS_OLD_TO_DIFFUSERS = { + ".to_q_lora.up": ".q_proj.lora_linear_layer.up", + ".to_q_lora.down": ".q_proj.lora_linear_layer.down", + ".to_k_lora.up": ".k_proj.lora_linear_layer.up", + ".to_k_lora.down": ".k_proj.lora_linear_layer.down", + ".to_v_lora.up": ".v_proj.lora_linear_layer.up", + ".to_v_lora.down": ".v_proj.lora_linear_layer.down", + ".to_out_lora.up": ".out_proj.lora_linear_layer.up", + ".to_out_lora.down": ".out_proj.lora_linear_layer.down", +} + +PEFT_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_PEFT, + StateDictType.DIFFUSERS: DIFFUSERS_TO_PEFT, +} + +DIFFUSERS_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, + StateDictType.PEFT: PEFT_TO_DIFFUSERS, +} + + +def convert_state_dict(state_dict, mapping): + converted_state_dict = {} + for k, v in state_dict.items(): + if any(pattern in k for pattern in mapping.keys()): + for old, new in mapping.items(): + k = k.replace(old, new) + + converted_state_dict[k] = v + return converted_state_dict + + +def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): + r""" + The method automatically infers in which direction the conversion should be done. + """ + if original_type is None: + # Old diffusers to PEFT + if any("to_out_lora" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS_OLD + elif any("lora_linear_layer" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS + else: + raise ValueError("Could not automatically infer state dict type") + + mapping = PEFT_STATE_DICT_MAPPINGS[original_type] + return convert_state_dict(state_dict, mapping) + + +def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): + r""" + The method automatically infers in which direction the conversion should be done. + """ + peft_adapter_name = kwargs.pop("adapter_name", "") + peft_adapter_name = "." 
+ peft_adapter_name + + if original_type is None: + # Old diffusers to PEFT + if any("to_out_lora" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS_OLD + elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): + original_type = StateDictType.PEFT + else: + raise ValueError("Could not automatically infer state dict type") + + mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] + return convert_state_dict(state_dict, mapping) From dc83fa0ec7721841b61bc9431273efaa4b91264e Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 13:43:10 +0000 Subject: [PATCH 21/52] remove unneeded methods --- src/diffusers/utils/__init__.py | 4 -- src/diffusers/utils/peft_utils.py | 116 ------------------------------ 2 files changed, 120 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index eab96e54e1f7..d812c65b9a4c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -84,10 +84,6 @@ from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( - convert_diffusers_state_dict_to_peft, - convert_old_state_dict_to_peft, - convert_peft_state_dict_to_diffusers, - convert_unet_state_dict_to_peft, recurse_remove_peft_layers, ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index fcdf7eb8e37c..b6d253b56412 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -48,119 +48,3 @@ def recurse_remove_peft_layers(model): # TODO: do it for Conv2d return model - - -def convert_old_state_dict_to_peft(attention_modules, state_dict): - # Convert from the old naming convention to the new naming convention. - # - # Previously, the old LoRA layers were stored on the state dict at the - # same level as the attention block i.e. - # `text_model.encoder.layers.11.self_attn.to_out_lora.lora_A.weight`. - # - # This is no actual module at that point, they were monkey patched on to the - # existing module. We want to be able to load them via their actual state dict. - # They're in `PatchedLoraProjection.lora_linear_layer` now. - converted_state_dict = {} - - for name, _ in attention_modules: - converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_q_lora.up.weight") - converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_k_lora.up.weight") - converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_v_lora.up.weight") - converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop(f"{name}.to_out_lora.up.weight") - - converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_q_lora.down.weight") - converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_k_lora.down.weight") - converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_v_lora.down.weight") - converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop(f"{name}.to_out_lora.down.weight") - - return converted_state_dict - - -def convert_peft_state_dict_to_diffusers(attention_modules, state_dict, adapter_name): - # Convert from the new naming convention to the diffusers naming convention. 
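The per-module conversion helpers removed in this commit are superseded by the generic mapping tables above; as a small hedged sketch (hypothetical key, import path as added in this series), the same renaming can also be requested explicitly via `original_type` instead of relying on auto-detection:

import torch
from diffusers.utils.state_dict_utils import StateDictType, convert_state_dict_to_diffusers

peft_sd = {"text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 32)}
diffusers_sd = convert_state_dict_to_diffusers(peft_sd, original_type=StateDictType.PEFT)
# -> "text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight"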
- converted_state_dict = {} - - for name, _ in attention_modules: - converted_state_dict[f"{name}.q_proj.lora_linear_layer.up.weight"] = state_dict.pop( - f"{name}.q_proj.lora_B.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.k_proj.lora_linear_layer.up.weight"] = state_dict.pop( - f"{name}.k_proj.lora_B.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.v_proj.lora_linear_layer.up.weight"] = state_dict.pop( - f"{name}.v_proj.lora_B.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.out_proj.lora_linear_layer.up.weight"] = state_dict.pop( - f"{name}.out_proj.lora_B.{adapter_name}.weight" - ) - - converted_state_dict[f"{name}.q_proj.lora_linear_layer.down.weight"] = state_dict.pop( - f"{name}.q_proj.lora_A.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.k_proj.lora_linear_layer.down.weight"] = state_dict.pop( - f"{name}.k_proj.lora_A.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.v_proj.lora_linear_layer.down.weight"] = state_dict.pop( - f"{name}.v_proj.lora_A.{adapter_name}.weight" - ) - converted_state_dict[f"{name}.out_proj.lora_linear_layer.down.weight"] = state_dict.pop( - f"{name}.out_proj.lora_A.{adapter_name}.weight" - ) - - return converted_state_dict - - -def convert_diffusers_state_dict_to_peft(attention_modules, state_dict): - # Convert from the diffusers naming convention to the new naming convention. - converted_state_dict = {} - - for name, _ in attention_modules: - converted_state_dict[f"{name}.q_proj.lora_B.weight"] = state_dict.pop( - f"{name}.q_proj.lora_linear_layer.up.weight" - ) - converted_state_dict[f"{name}.k_proj.lora_B.weight"] = state_dict.pop( - f"{name}.k_proj.lora_linear_layer.up.weight" - ) - converted_state_dict[f"{name}.v_proj.lora_B.weight"] = state_dict.pop( - f"{name}.v_proj.lora_linear_layer.up.weight" - ) - converted_state_dict[f"{name}.out_proj.lora_B.weight"] = state_dict.pop( - f"{name}.out_proj.lora_linear_layer.up.weight" - ) - - converted_state_dict[f"{name}.q_proj.lora_A.weight"] = state_dict.pop( - f"{name}.q_proj.lora_linear_layer.down.weight" - ) - converted_state_dict[f"{name}.k_proj.lora_A.weight"] = state_dict.pop( - f"{name}.k_proj.lora_linear_layer.down.weight" - ) - converted_state_dict[f"{name}.v_proj.lora_A.weight"] = state_dict.pop( - f"{name}.v_proj.lora_linear_layer.down.weight" - ) - converted_state_dict[f"{name}.out_proj.lora_A.weight"] = state_dict.pop( - f"{name}.out_proj.lora_linear_layer.down.weight" - ) - - return converted_state_dict - - -def convert_unet_state_dict_to_peft(state_dict): - converted_state_dict = {} - - patterns = { - ".to_out_lora": ".to_o", - ".down": ".lora_A", - ".up": ".lora_B", - ".to_q_lora": ".to_q", - ".to_k_lora": ".to_k", - ".to_v_lora": ".to_v", - } - - for k, v in state_dict.items(): - if any(pattern in k for pattern in patterns.keys()): - for old, new in patterns.items(): - k = k.replace(old, new) - - converted_state_dict[k] = v - - return converted_state_dict From b83fcbaf866f0b5e96d84b240392ae2a32fb310c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 13:53:57 +0000 Subject: [PATCH 22/52] use class method instead --- src/diffusers/loaders.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index adba0e909dca..5744c17edb4a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1131,8 +1131,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di _pipeline=self, ) + @classmethod @property - def 
use_peft_backend(self) -> bool: + def use_peft_backend(cls) -> bool: """ A property method that returns `True` if the current version of `peft` and `transformers` are compatible with PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are @@ -1582,6 +1583,10 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } + import pdb + + pdb.set_trace() + if cls.use_peft_backend: from peft import LoraConfig From 74e33a93768d7c21e92a29edde69625574476895 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 13:54:31 +0000 Subject: [PATCH 23/52] oops --- src/diffusers/loaders.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5744c17edb4a..e198b08a8f79 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1583,10 +1583,6 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - import pdb - - pdb.set_trace() - if cls.use_peft_backend: from peft import LoraConfig From 9cb8563b1d26472ab0da1d640e71c084c9894940 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 14:03:00 +0000 Subject: [PATCH 24/52] use `base_version` --- src/diffusers/loaders.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e198b08a8f79..d697658d3fab 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1142,11 +1142,11 @@ def use_peft_backend(cls) -> bool: For PEFT is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. """ correct_peft_version = is_peft_available() and version.parse( - importlib.metadata.version("peft") - ) > version.parse("0.5.0") - correct_transformers_version = version.parse(importlib.metadata.version("transformers")) > version.parse( - "4.33.1" - ) + version.parse(importlib.metadata.version("peft")).base_version + ) > version.parse("0.5") + correct_transformers_version = version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse("4.33") return correct_peft_version and correct_transformers_version @classmethod From c90f85d3f054339d99b1b5cb6c3e4a88ca04b704 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 14:31:53 +0000 Subject: [PATCH 25/52] fix examples --- src/diffusers/utils/state_dict_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index a9b6de910f52..d3053cd5d934 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -124,6 +124,9 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): original_type = StateDictType.DIFFUSERS_OLD elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): original_type = StateDictType.PEFT + elif any("lora_linear_layer" in k for k in state_dict.keys()): + # nothing to do + return state_dict else: raise ValueError("Could not automatically infer state dict type") From 40a489457d0c3e5c86bdbccbbe152c4626a3317c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 14:52:21 +0000 Subject: [PATCH 26/52] fix CI --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d697658d3fab..0561620a3461 
100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1547,7 +1547,7 @@ def load_lora_into_text_encoder( rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - if cls.use_peft_backend: + if cls.use_peft_backend is True: # convert state dict text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) @@ -1583,7 +1583,7 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - if cls.use_peft_backend: + if cls.use_peft_backend is True: from peft import LoraConfig lora_rank = list(rank.values())[0] From ea05959c6ac209bb41a80e83843a03b4b2b371ac Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 15:03:15 +0000 Subject: [PATCH 27/52] fix weird error with python 3.8 --- src/diffusers/loaders.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0561620a3461..1487b3a985e3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1670,7 +1670,7 @@ def lora_scale(self) -> float: return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 def _remove_text_encoder_monkey_patch(self): - if self.use_peft_backend: + if self.use_peft_backend is True: remove_method = recurse_remove_peft_layers else: warnings.warn( @@ -2080,7 +2080,7 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet: self.unet.fuse_lora(lora_scale) - if self.use_peft_backend: + if self.use_peft_backend is True: from peft.tuners.tuners_utils import BaseTunerLayer def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): @@ -2137,8 +2137,8 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True if unfuse_unet: self.unet.unfuse_lora() - if self.use_peft_backend: - from peft.tuners.layers.tuner_utils import BaseTunerLayer + if self.use_peft_backend is True: + from peft.tuners.tuner_utils import BaseTunerLayer def unfuse_text_encoder_lora(text_encoder): for module in text_encoder.modules(): @@ -2872,7 +2872,7 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - if self.use_peft_backend: + if self.use_peft_backend is True: recurse_remove_peft_layers(self.text_encoder) recurse_remove_peft_layers(self.text_encoder_2) else: From 27e3da69dc394eb023c8ebcab654756b144cdea0 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 15:03:40 +0000 Subject: [PATCH 28/52] fix --- src/diffusers/models/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 07eeae712f71..387a4c428266 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -26,7 +26,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): - if use_peft_backend: + if use_peft_backend is True: from peft.tuners.lora import LoraLayer for module in text_encoder.modules(): From 3d7c567f907b095c6856650066d165e85b610413 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 15:06:00 +0000 Subject: [PATCH 29/52] better fix --- src/diffusers/loaders.py | 44 ++++++++++++++++-------------------- src/diffusers/models/lora.py | 2 +- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 1487b3a985e3..d23ece28982b 100644 --- a/src/diffusers/loaders.py +++ 
b/src/diffusers/loaders.py @@ -68,6 +68,19 @@ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" +# Below should be `True` if the current version of `peft` and `transformers` are compatible with +# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are +# available. +# For PEFT is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. +_correct_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version +) > version.parse("0.5") +_correct_transformers_version = version.parse( + version.parse(importlib.metadata.version("transformers")).base_version +) > version.parse("4.33") + +USE_PEFT_BACKEND = _correct_peft_version and _correct_transformers_version + class PatchedLoraProjection(nn.Module): def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): super().__init__() @@ -1084,6 +1097,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME num_fused_loras = 0 + use_peft_backend = USE_PEFT_BACKEND def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): """ @@ -1131,24 +1145,6 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di _pipeline=self, ) - @classmethod - @property - def use_peft_backend(cls) -> bool: - """ - A property method that returns `True` if the current version of `peft` and `transformers` are compatible with - PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are - available. - - For PEFT is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. 
- """ - correct_peft_version = is_peft_available() and version.parse( - version.parse(importlib.metadata.version("peft")).base_version - ) > version.parse("0.5") - correct_transformers_version = version.parse( - version.parse(importlib.metadata.version("transformers")).base_version - ) > version.parse("4.33") - return correct_peft_version and correct_transformers_version - @classmethod def lora_state_dict( cls, @@ -1547,7 +1543,7 @@ def load_lora_into_text_encoder( rank = {} text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - if cls.use_peft_backend is True: + if cls.use_peft_backend: # convert state dict text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) @@ -1583,7 +1579,7 @@ def load_lora_into_text_encoder( k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys } - if cls.use_peft_backend is True: + if cls.use_peft_backend: from peft import LoraConfig lora_rank = list(rank.values())[0] @@ -1670,7 +1666,7 @@ def lora_scale(self) -> float: return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 def _remove_text_encoder_monkey_patch(self): - if self.use_peft_backend is True: + if self.use_peft_backend: remove_method = recurse_remove_peft_layers else: warnings.warn( @@ -2080,7 +2076,7 @@ def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora if fuse_unet: self.unet.fuse_lora(lora_scale) - if self.use_peft_backend is True: + if self.use_peft_backend: from peft.tuners.tuners_utils import BaseTunerLayer def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): @@ -2137,7 +2133,7 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True if unfuse_unet: self.unet.unfuse_lora() - if self.use_peft_backend is True: + if self.use_peft_backend: from peft.tuners.tuner_utils import BaseTunerLayer def unfuse_text_encoder_lora(text_encoder): @@ -2872,7 +2868,7 @@ def pack_weights(layers, prefix): ) def _remove_text_encoder_monkey_patch(self): - if self.use_peft_backend is True: + if self.use_peft_backend: recurse_remove_peft_layers(self.text_encoder) recurse_remove_peft_layers(self.text_encoder_2) else: diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 387a4c428266..07eeae712f71 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -26,7 +26,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): - if use_peft_backend is True: + if use_peft_backend: from peft.tuners.lora import LoraLayer for module in text_encoder.modules(): From d01a29273ead2682f0cd7353604061bd942fb831 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 19 Sep 2023 15:08:46 +0000 Subject: [PATCH 30/52] style --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d23ece28982b..f00e6a7e7007 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -81,6 +81,7 @@ USE_PEFT_BACKEND = _correct_peft_version and _correct_transformers_version + class PatchedLoraProjection(nn.Module): def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None): super().__init__() From e836b145e8a8e5b71cf288b2cf6412d9c7ac8f8c Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 20 Sep 2023 13:36:09 +0200 Subject: [PATCH 31/52] Apply suggestions from code review Co-authored-by: Sayak Paul Co-authored-by: 
Patrick von Platen --- src/diffusers/loaders.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f00e6a7e7007..ca3292011f51 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -71,15 +71,15 @@ # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # available. -# For PEFT is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. -_correct_peft_version = is_peft_available() and version.parse( +# For PEFT it has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1. +_required_peft_version = is_peft_available() and version.parse( version.parse(importlib.metadata.version("peft")).base_version ) > version.parse("0.5") -_correct_transformers_version = version.parse( +_required_transformers_version = version.parse( version.parse(importlib.metadata.version("transformers")).base_version ) > version.parse("4.33") -USE_PEFT_BACKEND = _correct_peft_version and _correct_transformers_version +USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version From cb484056c8513cbb2ec43355c4bffb92ff67a68a Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 20 Sep 2023 13:38:20 +0200 Subject: [PATCH 32/52] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/diffusers/utils/import_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 0ffa2727e54d..02c53bc66725 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -269,8 +269,8 @@ _peft_available = importlib.util.find_spec("peft") is not None try: - _accelerate_version = importlib_metadata.version("peft") - logger.debug(f"Successfully imported accelerate version {_accelerate_version}") + _peft_version = importlib_metadata.version("peft") + logger.debug(f"Successfully imported peft version {_peft_version}") except importlib_metadata.PackageNotFoundError: _peft_available = False From 325462dcd2f674ed0ae00aa87300a57546d2cfde Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 11:44:40 +0000 Subject: [PATCH 33/52] add comment --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ca3292011f51..75a83147cef5 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1584,6 +1584,8 @@ def load_lora_into_text_encoder( from peft import LoraConfig lora_rank = list(rank.values())[0] + # By definition, the scale should be alpha divided by rank.
+ # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71 alpha = lora_scale * lora_rank target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] From b412adc158e017ca205b82a79342bc668b311583 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 20 Sep 2023 13:45:09 +0200 Subject: [PATCH 34/52] Apply suggestions from code review Co-authored-by: Sayak Paul --- src/diffusers/utils/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d812c65b9a4c..3cd185e86325 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -83,9 +83,7 @@ from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput -from .peft_utils import ( - recurse_remove_peft_layers, -) +from .peft_utils import recurse_remove_peft_layers from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft From b72ef23dfcfb48a9c8d4fe0b617c064e5eb7c4b6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 11:48:51 +0000 Subject: [PATCH 35/52] conv2d support for recurse remove --- src/diffusers/utils/peft_utils.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index b6d253b56412..9b34183ffaac 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -32,6 +32,8 @@ def recurse_remove_peft_layers(model): ## compound module, go inside it recurse_remove_peft_layers(module) + module_replaced = False + if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear): new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to( module.weight.device @@ -40,11 +42,30 @@ def recurse_remove_peft_layers(model): if module.bias is not None: new_module.bias = module.bias + module_replaced = True + elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d): + new_module = torch.nn.Conv2d( + module.in_channels, + module.out_channels, + module.kernel_size, + module.stride, + module.padding, + module.dilation, + module.groups, + module.bias, + ).to(module.weight.device) + + new_module.weight = module.weight + if module.bias is not None: + new_module.bias = module.bias + + module_replaced = True + + if module_replaced: setattr(model, name, new_module) del module if torch.cuda.is_available(): torch.cuda.empty_cache() - # TODO: do it for Conv2d return model From e0726557653a51e96189cde3fc0f285bc0c729ad Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 11:54:52 +0000 Subject: [PATCH 36/52] added docstrings --- src/diffusers/utils/state_dict_utils.py | 42 +++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index d3053cd5d934..27973abbebc7 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -84,6 +84,21 @@ class StateDictType(enum.Enum): def convert_state_dict(state_dict, mapping): + r""" + Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. 
+ + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + mapping (`dict[str, str]`): + The mapping to use for conversion, the mapping should be a dictionary with the following structure: + - key: the pattern to replace + - value: the pattern to replace with + + Returns: + converted_state_dict (`dict`) + The converted state dict. + """ converted_state_dict = {} for k, v in state_dict.items(): if any(pattern in k for pattern in mapping.keys()): @@ -96,7 +111,14 @@ def convert_state_dict(state_dict, mapping): def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): r""" - The method automatically infers in which direction the conversion should be done. + Converts a state dict to the PEFT format The state dict can be from previous diffusers format (`OLD_DIFFUSERS`), or + new diffusers format (`DIFFUSERS`). The method only supports the conversion from diffusers old/new to PEFT for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. """ if original_type is None: # Old diffusers to PEFT @@ -107,13 +129,26 @@ def convert_state_dict_to_peft(state_dict, original_type=None, **kwargs): else: raise ValueError("Could not automatically infer state dict type") + if original_type not in PEFT_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + mapping = PEFT_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): r""" - The method automatically infers in which direction the conversion should be done. + Converts a state dict to new diffusers format. The state dict can be from previous diffusers format + (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will + return the state dict as is. + + The method only supports the conversion from diffusers old, PEFT to diffusers new for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. """ peft_adapter_name = kwargs.pop("adapter_name", "") peft_adapter_name = "." + peft_adapter_name @@ -130,5 +165,8 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): else: raise ValueError("Could not automatically infer state dict type") + if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] return convert_state_dict(state_dict, mapping) From bd46ae9db79f82b00ca60941bcc1cf5d6e9dbc54 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 11:57:43 +0000 Subject: [PATCH 37/52] more docstring --- src/diffusers/utils/state_dict_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 27973abbebc7..5e5c3a74a946 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -149,6 +149,14 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): The state dict to convert. 
original_type (`StateDictType`, *optional*): The original type of the state dict, if not provided, the method will try to infer it automatically. + kwargs (`dict`, *args*): + Additional arguments to pass to the method. + + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended + with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in + `get_peft_model_state_dict` method: + https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 + but we add it here in case we don't want to rely on that method. """ peft_adapter_name = kwargs.pop("adapter_name", "") peft_adapter_name = "." + peft_adapter_name From 724b52bc5600d66651eb22d936569e98be9b76d5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 12:12:54 +0000 Subject: [PATCH 38/52] add deprecate --- src/diffusers/loaders.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 75a83147cef5..d753047161a9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -14,7 +14,6 @@ import importlib import os import re -import warnings from collections import defaultdict from contextlib import nullcontext from io import BytesIO @@ -80,6 +79,7 @@ ) > version.parse("4.33") USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version +LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." class PatchedLoraProjection(nn.Module): @@ -1672,11 +1672,6 @@ def _remove_text_encoder_monkey_patch(self): if self.use_peft_backend: remove_method = recurse_remove_peft_layers else: - warnings.warn( - "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT" - " make sure to install the latest PEFT and transformers packages in the future.", - FutureWarning, - ) remove_method = self._remove_text_encoder_monkey_patch_classmethod if hasattr(self, "text_encoder"): @@ -1686,6 +1681,8 @@ def _remove_text_encoder_monkey_patch(self): @classmethod def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder): + deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE) + for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): attn_module.q_proj.lora_linear_layer = None @@ -1712,6 +1709,7 @@ def _modify_text_encoder( r""" Monkey-patches the forward passes of attention modules of the text encoder. """ + deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE) def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters): linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model @@ -2091,11 +2089,7 @@ def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): module.merge() else: - warnings.warn( - "You are using an old version of LoRA backend. 
This will be deprecated in the next releases in favor of PEFT" - " make sure to install the latest PEFT and transformers packages in the future.", - FutureWarning, - ) + deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE) def fuse_text_encoder_lora(text_encoder, lora_scale=1.0): for _, attn_module in text_encoder_attn_modules(text_encoder): @@ -2145,11 +2139,7 @@ def unfuse_text_encoder_lora(text_encoder): module.unmerge() else: - warnings.warn( - "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT" - " make sure to install the latest PEFT and transformers packages in the future.", - FutureWarning, - ) + deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE) def unfuse_text_encoder_lora(text_encoder): for _, attn_module in text_encoder_attn_modules(text_encoder): From 5e6f343d164a516a66582e7bad6a81a4e38177d2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 12:14:52 +0000 Subject: [PATCH 39/52] revert --- src/diffusers/loaders.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d753047161a9..a85e1ad89b09 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2160,6 +2160,8 @@ def unfuse_text_encoder_lora(text_encoder): if hasattr(self, "text_encoder_2"): unfuse_text_encoder_lora(self.text_encoder_2) + self.num_fused_loras -= 1 + class FromSingleFileMixin: """ From 71650d403d588e25b89d709ae6a886e3ebc7fd04 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 20 Sep 2023 12:16:32 +0000 Subject: [PATCH 40/52] try to fix merge conflicts --- tests/lora/test_lora_layers_old_backend.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/lora/test_lora_layers_old_backend.py b/tests/lora/test_lora_layers_old_backend.py index e54caeb9f0c2..20d44e0c07c4 100644 --- a/tests/lora/test_lora_layers_old_backend.py +++ b/tests/lora/test_lora_layers_old_backend.py @@ -1876,6 +1876,25 @@ def test_a1111(self): self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_lycoris(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/Amixx", safety_checker=None, use_safetensors=True, variant="fp16" + ).to(torch_device) + lora_model_id = "hf-internal-testing/edgLycorisMugler-light" + lora_filename = "edgLycorisMugler-light.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.6463, 0.658, 0.599, 0.6542, 0.6512, 0.6213, 0.658, 0.6485, 0.6017]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_a1111_with_model_cpu_offload(self): generator = torch.Generator().manual_seed(0) From 0985d17ea928465df225e38faf5dd8e8e29ff3ca Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 22 Sep 2023 12:35:01 +0530 Subject: [PATCH 41/52] peft integration features for text encoder 1. support multiple rank/alpha values 2. support multiple active adapters 3. 
support disabling and enabling adapters --- src/diffusers/loaders.py | 103 +++++++++++++++++++++++++++--- src/diffusers/models/lora.py | 8 +-- src/diffusers/utils/__init__.py | 9 ++- src/diffusers/utils/peft_utils.py | 67 +++++++++++++++++++ 4 files changed, 170 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index c3858a91d5c8..fff8e05cb0a9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -35,12 +35,17 @@ convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, + get_adapter_name, + get_rank_and_alpha_pattern, is_accelerate_available, is_omegaconf_available, is_peft_available, is_transformers_available, logging, recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, ) from .utils.import_utils import BACKENDS_MAPPING @@ -1100,7 +1105,9 @@ class LoraLoaderMixin: num_fused_loras = 0 use_peft_backend = USE_PEFT_BACKEND - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. @@ -1144,6 +1151,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=self, + adapter_name=adapter_name, ) @classmethod @@ -1500,6 +1508,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, low_cpu_mem_usage=None, _pipeline=None, + adapter_name=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1522,6 +1531,9 @@ def load_lora_into_text_encoder( tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to `True` will raise an error. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT @@ -1583,18 +1595,30 @@ def load_lora_into_text_encoder( if cls.use_peft_backend: from peft import LoraConfig - lora_rank = list(rank.values())[0] - # By definition, the scale should be alpha divided by rank. 
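With `adapter_name` threaded through `load_lora_weights` and `load_lora_into_text_encoder`, loading several LoRAs side by side could look like the sketch below; the model and checkpoint paths are placeholders, and the PEFT backend (recent peft and transformers) is assumed to be active:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("path/to/base-model")

    # Each call registers its LoRA under an explicit adapter name ("pixel" and "toy"
    # are arbitrary labels); omitting adapter_name falls back to "default_{i}".
    pipe.load_lora_weights("path/to/pixel-lora", adapter_name="pixel")
    pipe.load_lora_weights("path/to/toy-lora", adapter_name="toy")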
- # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71 - alpha = lora_scale * lora_rank + r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern( + rank, network_alphas, text_encoder_lora_state_dict + ) - target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] - if patch_mlp: - target_modules += ["fc1", "fc2"] + lora_config = LoraConfig( + r=r, + target_modules=target_modules, + lora_alpha=lora_alpha, + rank_pattern=rank_pattern, + alpha_pattern=alpha_pattern, + ) - lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha) + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) - text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config) + # inject LoRA layers and load the state dict + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, lora_weightage=lora_scale) is_model_cpu_offload = False is_sequential_cpu_offload = False @@ -2169,6 +2193,65 @@ def unfuse_text_encoder_lora(text_encoder): self.num_fused_loras -= 1 + def set_adapter( + self, + adapter_names: Union[List[str], str], + unet_weights: List[float] = None, + te_weights: List[float] = None, + te2_weights: List[float] = None, + ): + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + + def process_weights(adapter_names, weights): + if weights is None: + weights = [1.0] * len(adapter_names) + elif isinstance(weights, float): + weights = [weights] + + if len(adapter_names) != len(weights): + raise ValueError( + f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}" + ) + return weights + + adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + + # To Do + # Handle the UNET + + # Handle the Text Encoder + te_weights = process_weights(adapter_names, te_weights) + if hasattr(self, "text_encoder"): + set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights) + te2_weights = process_weights(adapter_names, te2_weights) + if hasattr(self, "text_encoder_2"): + set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights) + + def disable_lora(self): + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + # To Do + # Disbale unet adapters + + # Disbale text encoder adapters + if hasattr(self, "text_encoder"): + set_adapter_layers(self.text_encoder, enabled=False) + if hasattr(self, "text_encoder_2"): + set_adapter_layers(self.text_encoder_2, enabled=False) + + def enable_lora(self): + if not self.use_peft_backend: + raise ValueError("PEFT backend is required for this method.") + # To Do + # Enable unet adapters + + # Enable text encoder adapters + if hasattr(self, "text_encoder"): + set_adapter_layers(self.text_encoder, enabled=True) + if hasattr(self, "text_encoder_2"): + set_adapter_layers(self.text_encoder_2, enabled=True) + class FromSingleFileMixin: """ diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 07eeae712f71..d080afb72384 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -19,7 +19,7 @@ from torch import nn from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, 
text_encoder_mlp_modules -from ..utils import logging +from ..utils import logging, scale_lora_layers logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -27,11 +27,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): if use_peft_backend: - from peft.tuners.lora import LoraLayer - - for module in text_encoder.modules(): - if isinstance(module, LoraLayer): - module.scaling[module.active_adapter] = lora_scale + scale_lora_layers(text_encoder, lora_weightage=lora_scale) else: for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 3cd185e86325..004aa7d3405c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -83,7 +83,14 @@ from .loading_utils import load_image from .logging import get_logger from .outputs import BaseOutput -from .peft_utils import recurse_remove_peft_layers +from .peft_utils import ( + get_adapter_name, + get_rank_and_alpha_pattern, + recurse_remove_peft_layers, + scale_lora_layers, + set_adapter_layers, + set_weights_and_activate_adapters, +) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 9b34183ffaac..18d61fd038c7 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -69,3 +69,70 @@ def recurse_remove_peft_layers(model): torch.cuda.empty_cache() return model + + +def scale_lora_layers(model, lora_weightage): + from peft.tuners.tuner_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.scale_layer(lora_weightage) + + +def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): + rank_pattern = None + alpha_pattern = None + r = lora_alpha = list(rank_dict.values())[0] + if len(set(rank_dict.values())) > 1: + # get the rank occuring the most number of times + r = max(set(rank_dict.values()), key=list(rank_dict.values()).count) + + # for modules with rank different from the most occuring rank, add it to the `rank_pattern` + rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) + rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} + + if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1: + # get the alpha occuring the most number of times + lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count) + + # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` + alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) + alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} + + # layer names without the Diffusers specific + target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()} + + return r, lora_alpha, rank_pattern, alpha_pattern, target_modules + + +def get_adapter_name(model): + from peft.tuners.tuner_utils import BaseTunerLayer + + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + return f"default_{len(module.r)}" + return "default_0" + + +def set_adapter_layers(model, enabled=True): + from peft.tuners.tuner_utils import BaseTunerLayer 
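The rank/alpha pattern logic added above is easiest to see on a toy input. Below is a self-contained sketch of the same idea with made-up module names; it does not import peft and is not the function from the patch:

    import collections

    # Hypothetical per-module ranks extracted from a LoRA state dict.
    rank_dict = {
        "layers.0.q_proj.lora_B.weight": 4,
        "layers.0.k_proj.lora_B.weight": 4,
        "layers.1.q_proj.lora_B.weight": 8,  # the odd one out
    }

    # The most common rank becomes the global `r` ...
    r = collections.Counter(rank_dict.values()).most_common()[0][0]

    # ... and every module that deviates from it goes into `rank_pattern`,
    # keyed by the module path without the ".lora_B.weight" suffix.
    rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}

    print(r)             # 4
    print(rank_pattern)  # {'layers.1.q_proj': 8}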
+ + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.disable_adapters = False if enabled else True + + +def set_weights_and_activate_adapters(model, adapter_names, weights): + from peft.tuners.tuner_utils import BaseTunerLayer + + # iterate over each adapter, make it active and set the corresponding scaling weight + for adapter_name, weight in zip(adapter_names, weights): + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.active_adapter = adapter_name + module.scale_layer(weight) + + # set multiple active adapters + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + module.active_adapter = adapter_names From 01a15cc6ef25855fcfffcda0b1754da992c76580 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 22 Sep 2023 16:59:05 +0530 Subject: [PATCH 42/52] fix bug --- src/diffusers/utils/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 18d61fd038c7..25cb9b7e707e 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -100,7 +100,7 @@ def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} # layer names without the Diffusers specific - target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()} + target_modules = list(set([name.split(".lora")[0] for name in peft_state_dict.keys()])) return r, lora_alpha, rank_pattern, alpha_pattern, target_modules From ffbac30e4edfac4351d95cab544974d7a5f0d6e2 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Fri, 22 Sep 2023 17:09:09 +0530 Subject: [PATCH 43/52] fix code quality --- src/diffusers/utils/peft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 25cb9b7e707e..05145f2b17ac 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -100,7 +100,7 @@ def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()} # layer names without the Diffusers specific - target_modules = list(set([name.split(".lora")[0] for name in peft_state_dict.keys()])) + target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) return r, lora_alpha, rank_pattern, alpha_pattern, target_modules From 916c31ab0c1455515e0276c82dfca95542be959e Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:05:07 +0530 Subject: [PATCH 44/52] Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/diffusers/utils/peft_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 05145f2b17ac..42bc3f0010c9 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -72,7 +72,7 @@ def recurse_remove_peft_layers(model): def scale_lora_layers(model, lora_weightage): - from peft.tuners.tuner_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if 
isinstance(module, BaseTunerLayer): @@ -106,7 +106,7 @@ def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): def get_adapter_name(model): - from peft.tuners.tuner_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if isinstance(module, BaseTunerLayer): @@ -115,7 +115,7 @@ def get_adapter_name(model): def set_adapter_layers(model, enabled=True): - from peft.tuners.tuner_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): if isinstance(module, BaseTunerLayer): @@ -123,7 +123,7 @@ def set_adapter_layers(model, enabled=True): def set_weights_and_activate_adapters(model, adapter_names, weights): - from peft.tuners.tuner_utils import BaseTunerLayer + from peft.tuners.tuners_utils import BaseTunerLayer # iterate over each adapter, make it active and set the corresponding scaling weight for adapter_name, weight in zip(adapter_names, weights): From 5de0f1bfc9b04057f84c4959fdce3cba5c3f13ba Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:32:52 +0530 Subject: [PATCH 45/52] fix bugs --- src/diffusers/utils/peft_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 05145f2b17ac..3671613818ad 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -80,8 +80,8 @@ def scale_lora_layers(model, lora_weightage): def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): - rank_pattern = None - alpha_pattern = None + rank_pattern = {} + alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times From 0acb58cdc7aad04296e245e25b6c03cda06df29e Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:40:36 +0530 Subject: [PATCH 46/52] Apply suggestions from code review Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- setup.py | 1 - src/diffusers/dependency_versions_table.py | 1 - 2 files changed, 2 deletions(-) diff --git a/setup.py b/setup.py index 0d67f1b96b12..a2201ac5b3b1 100644 --- a/setup.py +++ b/setup.py @@ -128,7 +128,6 @@ "torchvision", "transformers>=4.25.1", "urllib3<=2.0.0", - "peft>=0.5.0" ] # this is a lookup table with items like: diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 42adc6444f53..d4b94ba6d4ed 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -41,5 +41,4 @@ "torchvision": "torchvision", "transformers": "transformers>=4.25.1", "urllib3": "urllib3<=2.0.0", - "peft": "peft>=0.5.0", } From 1ca4c62847e5826c9f1b8c32dcb9881b72291943 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:51:36 +0530 Subject: [PATCH 47/52] address comments Co-Authored-By: Benjamin Bossan Co-Authored-By: Patrick von Platen --- src/diffusers/loaders.py | 21 ++++++++----------- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/peft_utils.py | 34 ++++++++++++++++++++++++------- 3 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index eba85d7cb240..e555b5784fb6 100644 --- a/src/diffusers/loaders.py +++ 
b/src/diffusers/loaders.py @@ -36,7 +36,7 @@ convert_state_dict_to_peft, deprecate, get_adapter_name, - get_rank_and_alpha_pattern, + get_peft_kwargs, is_accelerate_available, is_omegaconf_available, is_peft_available, @@ -1127,6 +1127,9 @@ def load_lora_weights( See [`~loaders.LoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. """ # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -1508,8 +1511,8 @@ def load_lora_into_text_encoder( prefix=None, lora_scale=1.0, low_cpu_mem_usage=None, - _pipeline=None, adapter_name=None, + _pipeline=None, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1534,7 +1537,7 @@ def load_lora_into_text_encoder( argument to `True` will raise an error. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded + `default_{i}` where i is the total number of adapters being loaded. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT @@ -1596,17 +1599,9 @@ def load_lora_into_text_encoder( if cls.use_peft_backend: from peft import LoraConfig - r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern( - rank, network_alphas, text_encoder_lora_state_dict - ) + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict) - lora_config = LoraConfig( - r=r, - target_modules=target_modules, - lora_alpha=lora_alpha, - rank_pattern=rank_pattern, - alpha_pattern=alpha_pattern, - ) + lora_config = LoraConfig(**lora_config_kwargs) # adapter_name if adapter_name is None: diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 004aa7d3405c..5ad34f827967 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -85,7 +85,7 @@ from .outputs import BaseOutput from .peft_utils import ( get_adapter_name, - get_rank_and_alpha_pattern, + get_peft_kwargs, recurse_remove_peft_layers, scale_lora_layers, set_adapter_layers, diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index d8eed1cb8365..da948b209f31 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -14,6 +14,7 @@ """ PEFT utilities: Utilities related to peft library """ +import collections from .import_utils import is_torch_available @@ -79,13 +80,13 @@ def scale_lora_layers(model, lora_weightage): module.scale_layer(lora_weightage) -def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): +def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict): rank_pattern = {} alpha_pattern = {} r = lora_alpha = list(rank_dict.values())[0] if len(set(rank_dict.values())) > 1: # get the rank occuring the most number of times - r = max(set(rank_dict.values()), key=list(rank_dict.values()).count) + r = collections.Counter(rank_dict.values()).most_common()[0][0] # for modules with rank different from the most occuring rank, add it to the `rank_pattern` rank_pattern = dict(filter(lambda x: x[1] != r, 
rank_dict.items())) @@ -93,7 +94,7 @@ def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1: # get the alpha occuring the most number of times - lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count) + lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] # for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern` alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items())) @@ -102,7 +103,14 @@ def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict): # layer names without the Diffusers specific target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()}) - return r, lora_alpha, rank_pattern, alpha_pattern, target_modules + lora_config_kwargs = { + "r": r, + "lora_alpha": lora_alpha, + "rank_pattern": rank_pattern, + "alpha_pattern": alpha_pattern, + "target_modules": target_modules, + } + return lora_config_kwargs def get_adapter_name(model): @@ -119,7 +127,11 @@ def set_adapter_layers(model, enabled=True): for module in model.modules(): if isinstance(module, BaseTunerLayer): - module.disable_adapters = False if enabled else True + # The recent version of PEFT needs to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + module.disable_adapters = True def set_weights_and_activate_adapters(model, adapter_names, weights): @@ -129,10 +141,18 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): for adapter_name, weight in zip(adapter_names, weights): for module in model.modules(): if isinstance(module, BaseTunerLayer): - module.active_adapter = adapter_name + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + else: + module.active_adapter = adapter_name module.scale_layer(weight) # set multiple active adapters for module in model.modules(): if isinstance(module, BaseTunerLayer): - module.active_adapter = adapter_names + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_names) + else: + module.active_adapter = adapter_names From 7c377887911b75c31844edc6cd397666921c431f Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:57:30 +0530 Subject: [PATCH 48/52] fix code quality --- src/diffusers/utils/peft_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index da948b209f31..81b621cdd773 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -15,6 +15,7 @@ PEFT utilities: Utilities related to peft library """ import collections + from .import_utils import is_torch_available From 2fcf1741d74d8d60ad0d20f6976e881a3c299682 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Tue, 26 Sep 2023 22:03:27 +0530 Subject: [PATCH 49/52] address comments --- src/diffusers/loaders.py | 78 +++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 32 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e555b5784fb6..1cae1b518554 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -51,7 +51,7 @@ if 
is_transformers_available(): - from transformers import CLIPTextModel, CLIPTextModelWithProjection + from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel if is_accelerate_available(): from accelerate import init_empty_weights @@ -2196,13 +2196,24 @@ def unfuse_text_encoder_lora(text_encoder): self.num_fused_loras -= 1 - def set_adapter( + def set_adapter_for_text_encoder( self, adapter_names: Union[List[str], str], - unet_weights: List[float] = None, - te_weights: List[float] = None, - te2_weights: List[float] = None, + text_encoder: Optional[PreTrainedModel] = None, + text_encoder_weights: List[float] = None, ): + """ + Sets the adapter layers for the text encoder. + + Args: + adapter_names (`List[str]` or `str`): + The names of the adapters to use. + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. + text_encoder_weights (`List[float]`, *optional*): + The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. + """ if not self.use_peft_backend: raise ValueError("PEFT backend is required for this method.") @@ -2219,41 +2230,44 @@ def process_weights(adapter_names, weights): return weights adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names + text_encoder_weights = process_weights(adapter_names, text_encoder_weights) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - # To Do - # Handle the UNET - - # Handle the Text Encoder - te_weights = process_weights(adapter_names, te_weights) - if hasattr(self, "text_encoder"): - set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights) - te2_weights = process_weights(adapter_names, te2_weights) - if hasattr(self, "text_encoder_2"): - set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights) + def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): + """ + Disables the LoRA layers for the text encoder. - def disable_lora(self): + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to disable the LoRA layers for. If `None`, it will try to get the + `text_encoder` attribute. + """ if not self.use_peft_backend: raise ValueError("PEFT backend is required for this method.") - # To Do - # Disbale unet adapters - # Disbale text encoder adapters - if hasattr(self, "text_encoder"): - set_adapter_layers(self.text_encoder, enabled=False) - if hasattr(self, "text_encoder_2"): - set_adapter_layers(self.text_encoder_2, enabled=False) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(text_encoder, enabled=False) - def enable_lora(self): + def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): + """ + Enables the LoRA layers for the text encoder. + + Args: + text_encoder (`torch.nn.Module`, *optional*): + The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` + attribute. 
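Assuming a pipeline `pipe` that already holds two text-encoder adapters named "pixel" and "toy" (placeholder names, e.g. loaded via `load_lora_weights` with `adapter_name`), the three controls introduced here could be exercised roughly as follows:

    # Activate both adapters on the default text encoder with different weights.
    pipe.set_adapter_for_text_encoder(["pixel", "toy"], text_encoder_weights=[1.0, 0.5])

    # For pipelines with a second text encoder (such as SDXL), pass it explicitly.
    pipe.set_adapter_for_text_encoder("toy", text_encoder=pipe.text_encoder_2)

    # Temporarily switch all text-encoder LoRA layers off, then back on.
    pipe.disable_lora_for_text_encoder()
    pipe.enable_lora_for_text_encoder()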
+ """ if not self.use_peft_backend: raise ValueError("PEFT backend is required for this method.") - # To Do - # Enable unet adapters - - # Enable text encoder adapters - if hasattr(self, "text_encoder"): - set_adapter_layers(self.text_encoder, enabled=True) - if hasattr(self, "text_encoder_2"): - set_adapter_layers(self.text_encoder_2, enabled=True) + text_encoder = text_encoder or getattr(self, "text_encoder", None) + if text_encoder is None: + raise ValueError("Text Encoder not found.") + set_adapter_layers(self.text_encoder, enabled=True) class FromSingleFileMixin: From a1f012894759c01ef8012a07213525bca5f798ee Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Tue, 26 Sep 2023 22:13:19 +0530 Subject: [PATCH 50/52] address comments --- src/diffusers/utils/peft_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 81b621cdd773..b411dd3e9718 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -74,6 +74,15 @@ def recurse_remove_peft_layers(model): def scale_lora_layers(model, lora_weightage): + """ + Adjust the weightage given to the LoRA layers of the model. + + Args: + model (`torch.nn.Module`): + The model to scale. + lora_weightage (`float`): + The weightage to be given to the LoRA layers. + """ from peft.tuners.tuners_utils import BaseTunerLayer for module in model.modules(): From fd9bcfe53e411cff94db660e1b861679e5e684f8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 27 Sep 2023 12:11:56 +0200 Subject: [PATCH 51/52] Apply suggestions from code review --- src/diffusers/loaders.py | 6 +++--- src/diffusers/models/lora.py | 2 +- src/diffusers/utils/peft_utils.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index bca68d7439d6..fb44993577e7 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1153,8 +1153,8 @@ def load_lora_weights( text_encoder=self.text_encoder, lora_scale=self.lora_scale, low_cpu_mem_usage=low_cpu_mem_usage, - _pipeline=self, adapter_name=adapter_name, + _pipeline=self, ) @classmethod @@ -1614,7 +1614,7 @@ def load_lora_into_text_encoder( peft_config=lora_config, ) # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, lora_weightage=lora_scale) + scale_lora_layers(text_encoder, weight=lora_scale) is_model_cpu_offload = False is_sequential_cpu_offload = False @@ -2233,7 +2233,7 @@ def process_weights(adapter_names, weights): text_encoder_weights = process_weights(adapter_names, text_encoder_weights) text_encoder = text_encoder or getattr(self, "text_encoder", None) if text_encoder is None: - raise ValueError("Text Encoder not found.") + raise ValueError("The pipeline does not have a default `pipe.text_encoder` class. 
Please make sure to pass a `text_encoder` instead.") set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index d080afb72384..fa8258fedc86 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -27,7 +27,7 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False): if use_peft_backend: - scale_lora_layers(text_encoder, lora_weightage=lora_scale) + scale_lora_layers(text_encoder, weight=lora_scale) else: for _, attn_module in text_encoder_attn_modules(text_encoder): if isinstance(attn_module.q_proj, PatchedLoraProjection): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 599b941dfe53..70bc7adb6746 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -72,15 +72,15 @@ def recurse_remove_peft_layers(model): return model -def scale_lora_layers(model, lora_weightage): +def scale_lora_layers(model, weight): """ Adjust the weightage given to the LoRA layers of the model. Args: model (`torch.nn.Module`): The model to scale. - lora_weightage (`float`): - The weightage to be given to the LoRA layers. + weight (`float`): + The weight to be given to the LoRA layers. """ from peft.tuners.tuners_utils import BaseTunerLayer From 9916ac6ae3715f3c70d91551a2d40524af463d28 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 27 Sep 2023 12:14:31 +0200 Subject: [PATCH 52/52] find and replace --- src/diffusers/loaders.py | 4 +++- src/diffusers/utils/peft_utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index fb44993577e7..429b1b8c7190 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -2233,7 +2233,9 @@ def process_weights(adapter_names, weights): text_encoder_weights = process_weights(adapter_names, text_encoder_weights) text_encoder = text_encoder or getattr(self, "text_encoder", None) if text_encoder is None: - raise ValueError("The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead.") + raise ValueError( + "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead." + ) set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None): diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index 70bc7adb6746..253a57a2270e 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -86,7 +86,7 @@ def scale_lora_layers(model, weight): for module in model.modules(): if isinstance(module, BaseTunerLayer): - module.scale_layer(lora_weightage) + module.scale_layer(weight) def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
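The last two patches settle on `scale_lora_layers(model, weight)` as the single entry point for adjusting LoRA strength. Below is a toy sketch of the traversal it performs; `ToyLoraLayer` is a stand-in for peft's BaseTunerLayer so the snippet runs without peft installed (torch is still required), and its `scale_layer` semantics are deliberately simplified:

    import torch

    class ToyLoraLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.scaling = 1.0

        def scale_layer(self, weight):
            # Record the requested LoRA weight for this layer (simplified semantics).
            self.scaling = weight

    model = torch.nn.Sequential(ToyLoraLayer(), torch.nn.Linear(4, 4), ToyLoraLayer())

    weight = 0.5
    for module in model.modules():
        # The real helper performs the same isinstance check against BaseTunerLayer.
        if isinstance(module, ToyLoraLayer):
            module.scale_layer(weight)

    print([m.scaling for m in model.modules() if isinstance(m, ToyLoraLayer)])  # [0.5, 0.5]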