     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     deprecate,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
     is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
+    scale_lora_layers,
+    set_adapter_layers,
+    set_weights_and_activate_adapters,
 )
 from .utils.import_utils import BACKENDS_MAPPING
 
 
 if is_transformers_available():
-    from transformers import CLIPTextModel, CLIPTextModelWithProjection
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
     num_fused_loras = 0
     use_peft_backend = USE_PEFT_BACKEND
 
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -1120,6 +1127,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where `i` is the total number of adapters being loaded.
         """
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -1143,6 +1153,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
 
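
For reference, a minimal usage sketch of the new `adapter_name` argument. The base checkpoint and the LoRA repository names below are placeholders, not part of this diff:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Hypothetical LoRA repositories, used purely for illustration.
pipe.load_lora_weights("some-user/toy-style-lora", adapter_name="toy")
pipe.load_lora_weights("some-user/pixel-style-lora", adapter_name="pixel")

# When `adapter_name` is omitted, the adapter is registered as `default_{i}`,
# where i counts the adapters already loaded.
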
@@ -1500,6 +1511,7 @@ def load_lora_into_text_encoder(
         prefix=None,
         lora_scale=1.0,
         low_cpu_mem_usage=None,
+        adapter_name=None,
         _pipeline=None,
     ):
         """
@@ -1523,6 +1535,9 @@ def load_lora_into_text_encoder(
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where `i` is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
 
@@ -1584,19 +1599,22 @@ def load_lora_into_text_encoder(
         if cls.use_peft_backend:
             from peft import LoraConfig
 
-            lora_rank = list(rank.values())[0]
-            # By definition, the scale should be alpha divided by rank.
-            # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
-            alpha = lora_scale * lora_rank
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
 
-            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-            if patch_mlp:
-                target_modules += ["fc1", "fc2"]
+            lora_config = LoraConfig(**lora_config_kwargs)
 
-            # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
-            lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
+            # set a default adapter name if none was provided
+            if adapter_name is None:
+                adapter_name = get_adapter_name(text_encoder)
 
-            text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
+            # inject LoRA layers and load the state dict
+            text_encoder.load_adapter(
+                adapter_name=adapter_name,
+                adapter_state_dict=text_encoder_lora_state_dict,
+                peft_config=lora_config,
+            )
+            # scale LoRA layers with `lora_scale`
+            scale_lora_layers(text_encoder, weight=lora_scale)
 
             is_model_cpu_offload = False
             is_sequential_cpu_offload = False
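
The removed code collapsed everything to a single rank and alpha, with the TODO pointing at PEFT PR #873 for multi rank/alpha support; `get_peft_kwargs` instead derives the `LoraConfig` arguments from the checkpoint itself, so per-module ranks and alphas can be expressed through PEFT's `rank_pattern`/`alpha_pattern`. A rough sketch of the kind of config this enables, with made-up module names and values (in the diff the real kwargs come from `get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)`):

from peft import LoraConfig

lora_config = LoraConfig(
    r=4,  # default rank
    lora_alpha=4,  # default alpha; the effective scale is alpha / rank
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
    # Per-module overrides, illustrative values only:
    rank_pattern={"text_model.encoder.layers.0.self_attn.q_proj": 8},
    alpha_pattern={"text_model.encoder.layers.0.self_attn.q_proj": 16},
)
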
@@ -2178,6 +2196,81 @@ def unfuse_text_encoder_lora(text_encoder):
 
         self.num_fused_loras -= 1
 
+    def set_adapter_for_text_encoder(
+        self,
+        adapter_names: Union[List[str], str],
+        text_encoder: Optional[PreTrainedModel] = None,
+        text_encoder_weights: Optional[List[float]] = None,
+    ):
+        """
+        Sets the adapter layers for the text encoder.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to set the adapter layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+            text_encoder_weights (`List[float]`, *optional*):
+                The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the
+                adapters.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        def process_weights(adapter_names, weights):
+            if weights is None:
+                weights = [1.0] * len(adapter_names)
+            elif isinstance(weights, float):
+                weights = [weights]
+
+            if len(adapter_names) != len(weights):
+                raise ValueError(
+                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}."
+                )
+            return weights
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError(
+                "The pipeline does not have a default `pipe.text_encoder` attribute. Please make sure to pass a `text_encoder` instead."
+            )
+        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
+
+    def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Disables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=False)
+
+    def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+        """
+        Enables the LoRA layers for the text encoder.
+
+        Args:
+            text_encoder (`torch.nn.Module`, *optional*):
+                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
+                `text_encoder` attribute.
+        """
+        if not self.use_peft_backend:
+            raise ValueError("PEFT backend is required for this method.")
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+        if text_encoder is None:
+            raise ValueError("Text Encoder not found.")
+        set_adapter_layers(text_encoder, enabled=True)
+
 
 class FromSingleFileMixin:
     """
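
A short usage sketch of the text-encoder adapter controls added in the hunk above, assuming a pipeline running with the PEFT backend enabled and two text-encoder adapters already loaded as "toy" and "pixel" (the names are placeholders):

# Activate both adapters on the text encoder, weighting "toy" at full strength.
pipe.set_adapter_for_text_encoder(["toy", "pixel"], text_encoder_weights=[1.0, 0.5])

# Temporarily bypass every LoRA layer in the text encoder ...
pipe.disable_lora_for_text_encoder()

# ... and switch them back on later.
pipe.enable_lora_for_text_encoder()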