
Commit 2b0e33a

Merge pull request #1853 from bghira/bugfix/lora-alpha-metadata

integrate lora_alpha fully with diffusers metadata collection

2 parents 35526b3 + d80ff3f

File tree

14 files changed: +377 −26 lines changed
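
The commit threads per-component `*_lora_adapter_metadata` keyword arguments through each pipeline's `save_lora_weights` and forwards the collected dict to diffusers' `write_lora_layers`. Carrying `lora_alpha` matters because a LoRA delta is applied at a scale of `lora_alpha / r`; a minimal numeric sketch (values illustrative, and the "alpha defaults to r" fallback is a common loader convention, not something this diff states):

# Why lora_alpha belongs in saved metadata: a LoRA update is applied as
# (lora_alpha / r) * (B @ A). A loader that cannot recover alpha often
# falls back to alpha == r (scale 1.0), silently changing adapter strength.
r, lora_alpha = 16, 32.0
scale_trained = lora_alpha / r   # 2.0 -- the scale training actually used
scale_fallback = r / r           # 1.0 -- a loader guessing alpha = r
print(scale_trained, scale_fallback)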

simpletuner/helpers/models/auraflow/pipeline.py

Lines changed: 10 additions & 0 deletions

@@ -386,6 +386,8 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the transformer and optionally controlnet.
@@ -409,6 +411,7 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (transformer_lora_layers or controlnet_lora_layers):
             raise ValueError("You must pass at least one of `transformer_lora_layers` or `controlnet_lora_layers`.")
@@ -419,6 +422,12 @@ def save_lora_weights(
         if controlnet_lora_layers:
             state_dict.update(cls.pack_weights(controlnet_lora_layers, cls.controlnet_name))

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if controlnet_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, cls.controlnet_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -427,6 +436,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
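
The metadata dicts are run through the same `pack_weights` helper as the weight dicts. A minimal self-contained sketch of the namespacing this implies, assuming (as diffusers does) that `pack_weights` prefixes every key with the component name; the real implementation may differ in detail:

def pack_weights(layers: dict, prefix: str) -> dict:
    # Prefix each key so entries from different components cannot collide.
    return {f"{prefix}.{key}": value for key, value in layers.items()}

lora_adapter_metadata = {}
lora_adapter_metadata.update(pack_weights({"r": 16, "lora_alpha": 32.0}, "transformer"))
lora_adapter_metadata.update(pack_weights({"r": 8, "lora_alpha": 8.0}, "controlnet"))
print(lora_adapter_metadata)
# {'transformer.r': 16, 'transformer.lora_alpha': 32.0,
#  'controlnet.r': 8, 'controlnet.lora_alpha': 8.0}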

simpletuner/helpers/models/chroma/pipeline.py

Lines changed: 15 additions & 0 deletions

@@ -658,6 +658,9 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters of the UNet, text encoder and optionally the ControlNet.
@@ -684,6 +687,7 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (transformer_lora_layers or text_encoder_lora_layers or controlnet_lora_layers):
             raise ValueError(
@@ -700,6 +704,16 @@ def save_lora_weights(
             controlnet_prefix = "controlnet"
             state_dict.update(cls.pack_weights(controlnet_lora_layers, controlnet_prefix))

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
+        if controlnet_lora_adapter_metadata:
+            controlnet_prefix = "controlnet"
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, controlnet_prefix))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -708,6 +722,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer

simpletuner/helpers/models/flux/pipeline.py

Lines changed: 15 additions & 0 deletions

@@ -812,6 +812,9 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet, text encoder, and optionally controlnet.
@@ -838,6 +841,7 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (transformer_lora_layers or text_encoder_lora_layers or controlnet_lora_layers):
             raise ValueError(
@@ -854,6 +858,16 @@ def save_lora_weights(
             controlnet_prefix = "controlnet"
             state_dict.update(cls.pack_weights(controlnet_lora_layers, controlnet_prefix))

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, cls.text_encoder_name))
+
+        if controlnet_lora_adapter_metadata:
+            controlnet_prefix = "controlnet"
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, controlnet_prefix))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -862,6 +876,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
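
Once `write_lora_layers` receives `lora_adapter_metadata`, recent diffusers releases JSON-encode it into the safetensors header. A hedged read-back sketch; the `lora_adapter_metadata` header key and the file path are assumptions for illustration, not taken from this diff:

import json
from safetensors import safe_open

# Read the header metadata of a saved LoRA file (path illustrative).
with safe_open("output/pytorch_lora_weights.safetensors", framework="pt") as f:
    header = f.metadata() or {}

adapter_meta = json.loads(header.get("lora_adapter_metadata", "{}"))
print(adapter_meta.get("transformer.lora_alpha"))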

simpletuner/helpers/models/hidream/pipeline.py

Lines changed: 30 additions & 0 deletions

@@ -680,6 +680,11 @@ def save_lora_weights(
         save_function: Callable = None,
         safe_serialization: bool = True,
         transformer_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_2_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_3_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_4_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the transformer, text encoders, and optionally controlnet.
@@ -711,6 +716,16 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
+            text_encoder_2_lora_adapter_metadata:
+                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
+            text_encoder_3_lora_adapter_metadata:
+                LoRA adapter metadata associated with the third text encoder to be serialized with the state dict.
+            text_encoder_4_lora_adapter_metadata:
+                LoRA adapter metadata associated with the fourth text encoder to be serialized with the state dict.
+            controlnet_lora_adapter_metadata:
+                LoRA adapter metadata associated with the controlnet to be serialized with the state dict.
         """
         state_dict = {}
         lora_adapter_metadata = {}
@@ -746,6 +761,21 @@ def save_lora_weights(
         if transformer_lora_adapter_metadata is not None:
             lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))

+        if text_encoder_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, "text_encoder"))
+
+        if text_encoder_2_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
+        if text_encoder_3_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_3_lora_adapter_metadata, "text_encoder_3"))
+
+        if text_encoder_4_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_4_lora_adapter_metadata, "text_encoder_4"))
+
+        if controlnet_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, cls.controlnet_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
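
Note that the hidream hunks guard with `is not None`, while the other pipelines in this commit guard with plain truthiness. The two differ only for an explicitly passed empty dict, as this runnable sketch shows; whether that difference is intentional is not stated in the diff:

meta = {}
if meta:                 # truthiness guard (auraflow/chroma/flux/sd3/sdxl)
    print("packed")      # not reached: an empty dict is falsy
if meta is not None:     # identity guard (hidream)
    print("packed")      # reached: the empty dict is still packed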

simpletuner/helpers/models/pixart/model.py

Lines changed: 2 additions & 0 deletions

@@ -393,6 +393,8 @@ def save_lora_weights(self, output_dir: str, **kwargs):
             save_directory=output_dir,
             transformer_lora_layers=None,  # No transformer LoRA
             controlnet_lora_layers=controlnet_lora_layers,
+            transformer_lora_adapter_metadata=kwargs.get("transformer_lora_adapter_metadata"),
+            controlnet_lora_adapter_metadata=kwargs.get("controlnet_lora_adapter_metadata"),
         )

     def custom_model_card_schedule_info(self):
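
The model-side change forwards the metadata via `kwargs.get`, which yields `None` when the caller omits the key and so lines up with the pipeline's `Optional[dict] = None` defaults. A minimal sketch of that pattern (the print stands in for the pipeline call):

def save_lora_weights(output_dir: str, **kwargs):
    # kwargs.get returns None when the key is absent, matching the
    # pipeline signature's Optional[dict] = None default.
    forwarded = kwargs.get("transformer_lora_adapter_metadata")
    print(output_dir, forwarded)

save_lora_weights("out")                                                        # out None
save_lora_weights("out", transformer_lora_adapter_metadata={"lora_alpha": 4})  # out {'lora_alpha': 4}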

simpletuner/helpers/models/pixart/pipeline.py

Lines changed: 10 additions & 0 deletions

@@ -207,9 +207,12 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         """Save LoRA weights for both transformer and controlnet."""
         state_dict = {}
+        lora_adapter_metadata = {}

         # Pack transformer weights (only the non-replaced blocks)
         if transformer_lora_layers:
@@ -220,6 +223,12 @@ def save_lora_weights(
         if controlnet_lora_layers:
             state_dict.update(controlnet_lora_layers)  # they're already packed

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if controlnet_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, cls.controlnet_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -228,6 +237,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def load_lora_weights(

simpletuner/helpers/models/sd3/pipeline.py

Lines changed: 18 additions & 0 deletions

@@ -779,6 +779,10 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_2_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the transformer, text encoders, and optionally controlnet.
@@ -808,6 +812,7 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers or controlnet_lora_layers):
             raise ValueError(
@@ -826,6 +831,18 @@ def save_lora_weights(
         if controlnet_lora_layers:
             state_dict.update(cls.pack_weights(controlnet_lora_layers, cls.controlnet_name))

+        if transformer_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(transformer_lora_adapter_metadata, cls.transformer_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, "text_encoder"))
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
+        if controlnet_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, cls.controlnet_name))
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -834,6 +851,7 @@ def save_lora_weights(
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(

simpletuner/helpers/models/sdxl/pipeline.py

Lines changed: 18 additions & 0 deletions

@@ -371,6 +371,10 @@ def save_lora_weights(
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_lora_adapter_metadata: Optional[dict] = None,
+        text_encoder_2_lora_adapter_metadata: Optional[dict] = None,
+        controlnet_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet, text encoders, and optionally controlnet.
@@ -394,6 +398,7 @@ def save_lora_weights(
                 Whether to save the model using `safetensors` or the traditional PyTorch way.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers or controlnet_lora_layers):
             raise ValueError(
@@ -413,13 +418,26 @@ def save_lora_weights(
         if controlnet_lora_layers:
             state_dict.update(cls.pack_weights(controlnet_lora_layers, cls.controlnet_name))

+        if unet_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(unet_lora_adapter_metadata, "unet"))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_lora_adapter_metadata, "text_encoder"))
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(text_encoder_2_lora_adapter_metadata, "text_encoder_2"))
+
+        if controlnet_lora_adapter_metadata:
+            lora_adapter_metadata.update(cls.pack_weights(controlnet_lora_adapter_metadata, cls.controlnet_name))
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
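
End to end, the packed metadata presumably lands next to the tensors in the safetensors header. A self-contained sketch of that serialization, assuming (per recent diffusers behavior) a JSON-encoded `lora_adapter_metadata` header key; tensor shape and values are illustrative:

import json
import torch
from safetensors.torch import save_file

state_dict = {"unet.proj.lora_A.weight": torch.zeros(4, 8)}
lora_adapter_metadata = {"unet.r": 4, "unet.lora_alpha": 8.0}

# safetensors header metadata must be str -> str, hence the JSON encoding.
save_file(
    state_dict,
    "pytorch_lora_weights.safetensors",
    metadata={"format": "pt", "lora_adapter_metadata": json.dumps(lora_adapter_metadata)},
)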
