PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters #5147
Merged
Changes from all commits · 63 commits
ba24f2a

- more fixes (younesbelkada) c17634c
- up (younesbelkada) 2a6e535
- up (younesbelkada) 01f6d1d
- style (younesbelkada) 5a150b2
- add in setup (younesbelkada) 961e776
- oops (younesbelkada) cdbe739
- more changes (younesbelkada) 691368b
- v1 rzfactor CI (younesbelkada) 7918851
- Apply suggestions from code review (younesbelkada) 14db139
- few todos (younesbelkada) c06c40b
- Merge branch 'main' into peftpart-1 (younesbelkada) d56a14d
- protect torch import (younesbelkada) ec87c19
- style (younesbelkada) 40a6028
- fix fuse text encoder (younesbelkada) 0c62ef3
- Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada) c4295c9
- Update src/diffusers/loaders.py (younesbelkada) 4162ddf
- replace with `recurse_replace_peft_layers` (younesbelkada) 1d13f40
- keep old modules for BC (younesbelkada) 78a860d
- adjustments on `adjust_lora_scale_text_encoder` (younesbelkada) 78a01d5
- Merge branch 'main' into peftpart-1 (younesbelkada) ecbc714
- Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada) 9d650c9
- Merge branch 'peftpart-1' of https://github.com/younesbelkada/diffuse… (younesbelkada) 6f1adcd
- nit (younesbelkada) f890906
- move tests (younesbelkada) f8e87f6
- add conversion utils (younesbelkada) 3ba2d4e
- Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada) dc83fa0
- remove unneeded methods (younesbelkada) b83fcba
- use class method instead (younesbelkada) 74e33a9
- oops (younesbelkada) 9cb8563
- use `base_version` (younesbelkada) c90f85d
- fix examples (younesbelkada) 40a4894
- fix CI (younesbelkada) ea05959
- fix weird error with python 3.8 (younesbelkada) 27e3da6
- fix (younesbelkada) 3d7c567
- better fix (younesbelkada) d01a292
- style (younesbelkada) e836b14
- Apply suggestions from code review (younesbelkada) cb48405
- Apply suggestions from code review (younesbelkada) 325462d
- add comment (younesbelkada) b412adc
- Apply suggestions from code review (younesbelkada) b72ef23
- conv2d support for recurse remove (younesbelkada) e072655
- added docstrings (younesbelkada) bd46ae9
- more docstring (younesbelkada) 724b52b
- add deprecate (younesbelkada) 5e6f343
- revert (younesbelkada) 71650d4
- try to fix merge conflicts (younesbelkada) 920333f
- Merge remote-tracking branch 'upstream/main' into peftpart-1 (younesbelkada) 0985d17
- peft integration features for text encoder (pacman100) ece3b02
- Merge branch 'main' into peftpart-1 (pacman100) 01a15cc
- fix bug (pacman100) 080db75
- Merge branch 'main' into smangrul/peft-integration (pacman100) ffbac30
- fix code quality (pacman100) 916c31a
- Apply suggestions from code review (pacman100) 5de0f1b
- fix bugs (pacman100) c32872e
- Merge branch 'smangrul/peft-integration' of https://github.com/huggin… (pacman100) 0acb58c
- Apply suggestions from code review (pacman100) 1ca4c62
- address comments (pacman100) 7c37788
- fix code quality (pacman100) 2fcf174
- address comments (pacman100) a1f0128
- address comments (pacman100) 7b2ccff
- Merge branch 'main' into smangrul/peft-integration (patrickvonplaten) fd9bcfe
- Apply suggestions from code review (patrickvonplaten) 9916ac6
- find and replace (patrickvonplaten)
src/diffusers/loaders.py

```diff
@@ -35,18 +35,23 @@
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
     is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
 )
 from .utils.import_utils import BACKENDS_MAPPING
 
 if is_transformers_available():
-    from transformers import CLIPTextModel, CLIPTextModelWithProjection
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
```
```diff
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
     num_fused_loras = 0
     use_peft_backend = USE_PEFT_BACKEND
 
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
```
```diff
@@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
```
```diff
@@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
```

> **Member** (on the `adapter_name=adapter_name,` line): Let's keep the private variable at the end.
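For illustration, here is a minimal usage sketch of the new `adapter_name` argument at the pipeline level; the model ID, file paths, and adapter names are placeholders, not from this PR:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Each checkpoint is registered under its own adapter name; if `adapter_name`
# is omitted, the adapters are auto-named `default_0`, `default_1`, ...
pipe.load_lora_weights("path/to/pixel-lora.safetensors", adapter_name="pixel")
pipe.load_lora_weights("path/to/toy-lora.safetensors", adapter_name="toy")
```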
```diff
@@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
+        adapter_name=None,
         _pipeline=None,
     ):
         """
@@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder(
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
```
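The `default_{i}` fallback described in this docstring comes from the `get_adapter_name` helper imported above. A rough sketch of that logic, assuming PEFT's `BaseTunerLayer` (an illustration, not the exact diffusers implementation):

```python
from peft.tuners.tuners_utils import BaseTunerLayer


def get_adapter_name(model):
    # Each injected LoRA layer keeps one rank entry per loaded adapter, so
    # the size of that dict tells us how many adapters already exist.
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return f"default_{len(module.r)}"
    # No PEFT layers injected yet: this is the first adapter.
    return "default_0"
```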
```diff
@@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder(
         if cls.use_peft_backend:
             from peft import LoraConfig
 
-            lora_rank = list(rank.values())[0]
-            # By definition, the scale should be alpha divided by rank.
-            # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
-            alpha = lora_scale * lora_rank
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
 
-            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-            if patch_mlp:
-                target_modules += ["fc1", "fc2"]
+            lora_config = LoraConfig(**lora_config_kwargs)
 
-            # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
-            lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
 
-            text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
+            # inject LoRA layers and load the state dict
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
 
             is_model_cpu_offload = False
             is_sequential_cpu_offload = False
```
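This hunk is the core of the multi-rank/multi-alpha support: instead of collapsing everything to a single `r` and `lora_alpha`, `get_peft_kwargs` derives per-module values, relying on the `rank_pattern`/`alpha_pattern` fields that huggingface/peft#873 added to `LoraConfig`. A simplified sketch of the idea (illustrative, not the exact implementation):

```python
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
    # The most common rank becomes the global `r`; modules that deviate go
    # into `rank_pattern` so one LoraConfig can express mixed ranks.
    ranks = list(rank_dict.values())
    r = max(set(ranks), key=ranks.count)
    rank_pattern = {k.split(".lora")[0]: v for k, v in rank_dict.items() if v != r}

    lora_alpha = r
    alpha_pattern = {}
    if network_alpha_dict:
        alphas = list(network_alpha_dict.values())
        lora_alpha = max(set(alphas), key=alphas.count)
        alpha_pattern = {
            k.split(".alpha")[0]: v for k, v in network_alpha_dict.items() if v != lora_alpha
        }

    # Target every module that actually has LoRA weights in the checkpoint,
    # instead of a hard-coded q/k/v/out list.
    target_modules = sorted({name.split(".lora")[0].split(".")[-1] for name in peft_state_dict})

    return {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
    }
```

With this, a checkpoint that uses rank 8 for most projections but a different rank or alpha for a few no longer has to be collapsed to a single value, which is exactly the limitation the removed `# TODO` pointed at.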
```diff
@@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder):
 
         self.num_fused_loras -= 1
 
+    def set_adapter_for_text_encoder(
+        self,
+        adapter_names: Union[List[str], str],
+        text_encoder: Optional[PreTrainedModel] = None,
+        text_encoder_weights: List[float] = None,
+    ):
+        """
+        Sets the adapter layers for the text encoder.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to set the adapter layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+            text_encoder_weights (`List[float]`, *optional*):
+                The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the
+                adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError(
+                "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
+            )
+        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+    def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Disables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=False)
+
+    def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Enables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(self.text_encoder, enabled=True)
 
 
 class FromSingleFileMixin:
     """
```
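A short usage sketch of the three new methods, assuming a pipeline that already has two text-encoder adapters loaded under the illustrative names "pixel" and "toy" (see the `load_lora_weights` sketch above):

```python
# Activate both adapters on the text encoder with per-adapter weights.
pipe.set_adapter_for_text_encoder(["pixel", "toy"], text_encoder_weights=[1.0, 0.5])

# Temporarily fall back to the base text encoder, then re-enable LoRA.
pipe.disable_lora_for_text_encoder()
pipe.enable_lora_for_text_encoder()
```

For pipelines with a second text encoder, the module can be passed explicitly (e.g. `pipe.text_encoder_2`) since all three methods accept a `text_encoder` argument.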