@@ -680,6 +680,11 @@ def save_lora_weights(
680680 save_function : Callable = None ,
681681 safe_serialization : bool = True ,
682682 transformer_lora_adapter_metadata : Optional [dict ] = None ,
683+ text_encoder_lora_adapter_metadata : Optional [dict ] = None ,
684+ text_encoder_2_lora_adapter_metadata : Optional [dict ] = None ,
685+ text_encoder_3_lora_adapter_metadata : Optional [dict ] = None ,
686+ text_encoder_4_lora_adapter_metadata : Optional [dict ] = None ,
687+ controlnet_lora_adapter_metadata : Optional [dict ] = None ,
683688 ):
684689 r"""
685690 Save the LoRA parameters corresponding to the transformer, text encoders, and optionally controlnet.
@@ -711,6 +716,16 @@ def save_lora_weights(
711716 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
712717 transformer_lora_adapter_metadata:
713718 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
719+ text_encoder_lora_adapter_metadata:
720+ LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
721+ text_encoder_2_lora_adapter_metadata:
722+ LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
723+ text_encoder_3_lora_adapter_metadata:
724+ LoRA adapter metadata associated with the third text encoder to be serialized with the state dict.
725+ text_encoder_4_lora_adapter_metadata:
726+ LoRA adapter metadata associated with the fourth text encoder to be serialized with the state dict.
727+ controlnet_lora_adapter_metadata:
728+ LoRA adapter metadata associated with the controlnet to be serialized with the state dict.
714729 """
715730 state_dict = {}
716731 lora_adapter_metadata = {}
@@ -746,6 +761,21 @@ def save_lora_weights(
746761 if transformer_lora_adapter_metadata is not None :
747762 lora_adapter_metadata .update (cls .pack_weights (transformer_lora_adapter_metadata , cls .transformer_name ))
748763
764+ if text_encoder_lora_adapter_metadata is not None :
765+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_lora_adapter_metadata , "text_encoder" ))
766+
767+ if text_encoder_2_lora_adapter_metadata is not None :
768+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_2_lora_adapter_metadata , "text_encoder_2" ))
769+
770+ if text_encoder_3_lora_adapter_metadata is not None :
771+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_3_lora_adapter_metadata , "text_encoder_3" ))
772+
773+ if text_encoder_4_lora_adapter_metadata is not None :
774+ lora_adapter_metadata .update (cls .pack_weights (text_encoder_4_lora_adapter_metadata , "text_encoder_4" ))
775+
776+ if controlnet_lora_adapter_metadata is not None :
777+ lora_adapter_metadata .update (cls .pack_weights (controlnet_lora_adapter_metadata , cls .controlnet_name ))
778+
749779 # Save the model
750780 cls .write_lora_layers (
751781 state_dict = state_dict ,
0 commit comments