
@pacman100 (Contributor) commented Aug 28, 2023

What does this PR do?

  1. Add support for different ranks and alphas (via rank_pattern and alpha_pattern) for the LoRA tuner.
  2. Add support for multiple active adapters (see the usage sketch below the to-do list).

To Do:

  • Tests
  • Documentation
  • Example
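
For illustration, a rough usage sketch of both features (the base model, adapter names, and target module names below are placeholders, and the exact API may differ slightly from what finally lands):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # placeholder base model

# 1. Per-module ranks/alphas: default r=8 / alpha=16, but modules matching
#    "q_proj" get rank 16 and alpha 32 instead.
config_a = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    rank_pattern={"q_proj": 16},
    alpha_pattern={"q_proj": 32},
)
model = get_peft_model(base, config_a, adapter_name="adapter_a")

# A second adapter on the same base model.
config_b = LoraConfig(r=4, lora_alpha=8, target_modules=["q_proj", "v_proj"])
model.add_adapter("adapter_b", config_b)

# 2. Multiple active adapters: assuming the tuner-level set_adapter accepts a
#    list of adapter names after this PR, both deltas are applied in forward.
model.base_model.set_adapter(["adapter_a", "adapter_b"])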

@HuggingFaceDocBuilderDev commented Aug 28, 2023

The documentation is not available anymore as the PR was closed or merged.

@pacman100 pacman100 marked this pull request as ready for review August 28, 2023 11:54
@BenjaminBossan (Member) left a comment

LGTM, thanks for adding this so quickly.

There are a few things I'd like to see in relation to this PR, but given that it helps with the diffusers PR, I'm fine with adding those later to speed things up:

  • unit tests
  • an example in the docs or examples folder
  • similar config options for other tuners where it makes sense, like AdaLoRA
  • maybe: regex matching

@kovalexal (Contributor)

Hi, @pacman100! This feature is a super useful addition!

I am currently working on proper support of LoRA-C3Lier for Stable Diffusion. It's a special type of LoRA for which the Conv2d 3x3 layers may have different ranks and alphas. So this code will be super-useful for me to finish implementing this type of LoRA.

Do you by chance know what the ETA is for this feature to get into the main branch?

@kovalexal (Contributor)

@pacman100, @BenjaminBossan sorry to disturb you. I've been trying out the code from this branch and I've run into an issue related to using LoRAs with a provided rank_pattern and alpha_pattern.

In short: if you create two adapters for a model, the first without rank_pattern and alpha_pattern and the second with them provided, then after adding the second adapter the custom ranks (and alphas) for those layers are not taken into account.

Here is a demo notebook that demonstrates how to reproduce this issue.

As far as I can see, there is a possible problem in these specific lines:

# In the LoRA tuner, the per-module rank and alpha are looked up by the exact module key:
current_key = optional_kwargs["current_key"]
r = lora_config.rank_pattern.get(current_key, lora_config.r)
alpha = lora_config.alpha_pattern.get(current_key, lora_config.lora_alpha)

When the first adapter is initialized, the module name is passed as usual (e.g. "path.to.module"), but when the second adapter is added, the module name is nested under the PEFT prefix (e.g. "base_model.model.path.to.module"), which is why the custom alphas and ranks are not propagated properly.
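
To illustrate the mismatch in isolation, here is a small standalone sketch (not the actual PEFT code; the key names are placeholders): an exact-key lookup misses the prefixed key, while a suffix-aware lookup would still resolve it.

rank_pattern = {"path.to.module": 16}
default_r = 8

def get_rank_exact(key):
    # Exact lookup: only works for the unprefixed key seen by the first adapter.
    return rank_pattern.get(key, default_r)

def get_rank_suffix(key):
    # Suffix-aware lookup: also matches keys wrapped in the PEFT prefix.
    for pattern, rank in rank_pattern.items():
        if key == pattern or key.endswith("." + pattern):
            return rank
    return default_r

print(get_rank_exact("path.to.module"))                    # 16
print(get_rank_exact("base_model.model.path.to.module"))   # 8, falls back to the default
print(get_rank_suffix("base_model.model.path.to.module"))  # 16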

@kovalexal (Contributor)

@pacman100 @BenjaminBossan I've addressed this issue and worked on 3 out of the 4 comments @BenjaminBossan left (all except AdaLoRA).

Please let me know if there is a way for me to add my modifications to this PR.

@pacman100 (Contributor, PR author)

Hello @kovalexal, thank you for actively working on this! Please raise a PR to merge your changes into this PR; we will be happy to review it quickly and incorporate the changes.

@pacman100 (Contributor, PR author) commented Sep 15, 2023

Do you by chance know what the ETA is for this feature to get into the main branch?

Next week is the plan

kovalexal and others added 4 commits September 18, 2023 17:19
Fixed multirank + multialpha for sequential LoRAs, added correct support of LoRA-C3Lier conversion (#937)

* Fixed multirank multialpha for sequential loras, added tests, fixed docs

* Refactored kohya_ss conversion script for proper support of LoRA-C3Lier

* Fixed styling

* Removed old comment from docstring
@younesbelkada (Contributor) left a comment

Thanks a lot Sourab! I left a few minor comments as a first pass! Let me know what you think.

@BenjaminBossan (Member) left a comment

Great work, this is a huge addition to PEFT.

I left a couple of comments; please take a look and see if they make sense. I have mostly reviewed the LoRA code closely, as I think IA³ and AdaLoRA just follow the same logic, so some of my comments apply there as well.

@BenjaminBossan (Member) left a comment

Overall, this LGTM; thanks for addressing the issues. I think there is one bug, however, so please check my comment.

It is important that we also add tests for multiple active adapters, but if this is time critical, I'm okay with adding them later. Docs/examples can also be added later.

@BenjaminBossan (Member) left a comment

LGTM, thanks for addressing the comments.

Let's not forget to add tests + docs + examples later.

I think using multiple adapters with dropout could be problematic, as more and more values will be set to 0. But I guess that's on the user to address, not sure if there is anything that we could do about it.

@pacman100 (Contributor, PR author)

I think using multiple adapters with dropout could be problematic, as more and more values will be set to 0. But I guess that's on the user to address, not sure if there is anything that we could do about it.

I believe the support for multiple active adapters should be used only during inference in eval mode. In that case, dropout won't be an issue.
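
As a quick standalone check of that point (a toy sketch, not PEFT code; as far as I can tell, lora_dropout is a plain nn.Dropout under the hood), dropout becomes an identity in eval mode, so stacking adapters adds no extra zeroing there:

import torch
import torch.nn as nn

x = torch.ones(8)
drop = nn.Dropout(p=0.5)

drop.train()
print(drop(x))  # roughly half the entries are zeroed (survivors scaled by 1/(1-p) = 2)

drop.eval()
print(drop(x))  # identity: no zeroing, so multi-adapter inference in eval mode is unaffected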

@BenjaminBossan (Member)

I believe the support for multiple active adapters should be used only during inference in eval mode. In that case, dropout won't be an issue.

I see. But in theory, it would also work during training, right? And there isn't anything in the code right now that would prevent that.

@younesbelkada (Contributor) left a comment

Great work @pacman100, as always! Thanks!

@whybeyoung

Any clear examples for this PR?

@mcc311 commented Oct 14, 2023

Could anyone provide an example here? Thanks!

@BenjaminBossan (Member)

Sorry that we haven't added any documentation yet. The idea is that you can do something like this:

from peft import LoraConfig

config = LoraConfig(
    r=8,                                   # default LoRA rank for all targeted layers
    rank_pattern={"foo": 16},              # layers matching "foo" use rank 16 instead
    lora_alpha=10,                         # default LoRA alpha
    alpha_pattern={"bar": 20, "baz": 30},  # per-layer alpha overrides
)

In this case, all LoRA layers would use rank 8, except for layers that match "foo", which would have rank 16. Similarly, LoRA alpha would be 10 by default except for "bar", which would be 20, and "baz", which would be 30.
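
To show it end to end, here is a hedged sketch with a toy model whose submodule names happen to match those patterns (the module names and sizes are made up for illustration):

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(32, 32)  # matched by rank_pattern -> rank 16
        self.bar = nn.Linear(32, 32)  # matched by alpha_pattern -> alpha 20
        self.baz = nn.Linear(32, 32)  # matched by alpha_pattern -> alpha 30

    def forward(self, x):
        return self.baz(self.bar(self.foo(x)))

config = LoraConfig(
    r=8,
    rank_pattern={"foo": 16},
    lora_alpha=10,
    alpha_pattern={"bar": 20, "baz": 30},
    target_modules=["foo", "bar", "baz"],
)
model = get_peft_model(Toy(), config)
model.print_trainable_parameters()  # foo contributes more LoRA parameters than bar/baz due to its higher rank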

@pacman100 pacman100 deleted the smangrul/lora-supoort-multirank-multialpha branch February 20, 2024 05:48
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
* support multiple ranks and alphas

* Update lora.py

* Update lora.py

* commit suggestions

Co-Authored-By: Benjamin Bossan <[email protected]>

* address comments

Co-Authored-By: Benjamin Bossan <[email protected]>

* Fixed multirank + multialpha for sequential LoRAs, added correct support of LoRA-C3Lier conversion (huggingface#937)

* Fixed multirank multialpha for sequential loras, added tests, fixed docs

* Refactored kohya_ss conversion script for proper support of LoRA-C3Lier

* Fixed styling

* Removed old comment from docstring

* shift `scale_layer`/`unscale_layer` to `LoraLayer` class to support all the child classes

* support multiple active adapters

* add `active_adapters` property

Co-Authored-By: Benjamin Bossan <[email protected]>

* fix bug related to active adapter of `ModulesToSaveWrapper`

* revert the change wrt active_adapter assignment

Co-Authored-By: Younes Belkada <[email protected]>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* addressing comments

* address comments

* address comment

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Alexander Kovalchuk <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>