Conversation

@younesbelkada
Contributor

@younesbelkada younesbelkada commented Jul 25, 2023

What does this PR do?

Following the offline discussion and the comments from @patrickvonplaten in #24827 (comment), I propose a new design for tightly integrating PEFT into transformers.
This integration enables loading any PEFT adapter saved locally or on the Hub directly into transformers, without dispatching the entire model creation process to PEFT as introduced in #24827.

This would also enable an easier pipeline integration (a one-liner to load adapter weights). | EDIT: the pipeline should work out of the box.
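For example, something along these lines should work (a sketch, assuming the pipeline resolves the base model from the adapter config):

from transformers import pipeline

# sketch: pass the adapter repo id directly to pipeline()
pipe = pipeline("text-generation", model="ybelkada/opt-350m-lora")
print(pipe("Hello", max_new_tokens=20)[0]["generated_text"])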

Let's constrain this integration to a few PEFT methods only, for simplicity, and redirect users to PEFT for advanced features (e.g. merge and unload) and advanced PEFT methods (adaptation prompt, prompt learning).

Current API:

Load a model with an adapter locally or from the Hub:

import torch
from transformers import AutoModelForCausalLM, OPTForCausalLM

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"

# load the adapter checkpoint directly with from_pretrained
model = AutoModelForCausalLM.from_pretrained(adapter_model_id)
print(model)

# also works with the concrete model class
model = OPTForCausalLM.from_pretrained(adapter_model_id)
print(model)

Load and attach adapter to an existing model

import torch
from transformers import AutoModelForCausalLM

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"

# with load_adapter
model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(adapter_model_id)

print(model)

# 8-bit + multi-GPU compatibility
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="balanced")
model.load_adapter(adapter_model_id)

print(model)
print(set(model.hf_device_map.values()))

_ = model(torch.LongTensor([[0, 1, 2, 3]]).to(0))

Attach an adapter, iteratively enable / disable adapters

from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)

# To get random weights
peft_config.init_lora_weights = False

model.add_adapter(peft_config)
print(model)

model.disable_adapters()
output_disabled = model.generate(**inputs)
print(tokenizer.decode(output_disabled[0], skip_special_tokens=True))
>>> Hello, I'm a newbie to this sub. I'm looking for a good place to

model.enable_adapters()
output_enabled = model.generate(**inputs)
print(tokenizer.decode(output_enabled[0], skip_special_tokens=True))
>>> Hello, MMMMMMMM

Add multiple adapters

from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig, LoraConfig

model_id = "facebook/opt-350m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt")

# load the base model
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    init_lora_weights=False
)

model.add_adapter(lora_config, adapter_name="adapter_1")

# attach new adapter with same config
model.add_adapter(lora_config, adapter_name="adapter_2")

model.set_adapter("adapter_1")
output_adapter_1 = model.generate(**inputs)
print(tokenizer.decode(output_adapter_1[0], skip_special_tokens=True))
>>> Hello, I'm a newbie to this sub. I'm looking for a good place to

model.set_adapter("adapter_2")
output_adapter_2 = model.generate(**inputs)
print(tokenizer.decode(output_adapter_2[0], skip_special_tokens=True))
>>> Hello, I'm a newbie to the game. I'm looking for a good way to

Save adapters

from transformers import AutoModelForCausalLM
from peft import LoraConfig

model_id = "facebook/opt-350m"
save_dir = "opt-350m-lora"  # placeholder local directory

# load the base model
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
)

model.add_adapter(lora_config)

... # train here

model.save_pretrained(save_dir) 

# you can either load it back with transformers or PEFT

from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(save_dir)

# or

model = AutoModelForCausalLM.from_pretrained(save_dir)

Train adapters using Trainer

Check this gist: https://gist.github.com/younesbelkada/cdda6e4abcb09e58f6324d75e0d88862
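For reference, a minimal sketch of the pattern from the gist (train_dataset and the training arguments below are placeholders you need to provide):

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)

# attach a fresh LoRA adapter; only the adapter weights are trained
model.add_adapter(LoraConfig(target_modules=["q_proj", "k_proj"]))

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="opt-350m-lora", per_device_train_batch_size=4),
    train_dataset=train_dataset,  # placeholder: a tokenized dataset
)
trainer.train()
trainer.save_model("opt-350m-lora")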

This PR is on par with: huggingface/peft#749

Features to support:

  • loading PEFT adapters
  • using multiple adapters
  • dealing with models loaded with accelerate
  • loading directly from from_pretrained
  • merging adapter weights - not to support
  • unloading adapter weights - not to support
  • training with backward compatibility with the expected PEFT checkpoint format (do we really want to support training? Shall we just redirect users to load a classic PeftModel if they want to train a model?)
  • what about save_pretrained?

Features to not support:

  • disabling adapters
  • prompt tuning / prompt learning methods

TODOs:

  • docs
  • tests

cc @sgugger @patrickvonplaten @BenjaminBossan @pacman100

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jul 25, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger
Collaborator

sgugger commented Jul 25, 2023

Mmmm, but now we can't do model = AutoModelForCausalLM.from_pretrained("ybelkada/opt-350m-lora") despite the checkpoint having all the correct info to load the model (compared to the alternative PR). Can we still add the necessary code in from_pretrained so the load takes one line instead of two?

Apart from that, the design looks great to me!

Contributor

@patrickvonplaten patrickvonplaten left a comment

Design looks very nice to me! We'll be able to fully leverage this from diffusers, I believe.

Member

@BenjaminBossan BenjaminBossan left a comment

This looks very nice. I can't comment on the overall design, as my knowledge about transformers auto models is not very deep, but the approach looks promising from my point of view.

Let's constrain this integration to a few PEFT methods only, for simplicity, and redirect users to PEFT for advanced features and advanced PEFT methods.

Features to not support: ...

Yes, this seems to be the way to go. As mentioned by others, this should be very well documented, with an error message explaining what works and what doesn't if the user tries to do something that is not supported.

I have a few small comments on the code, please take a look.

Regarding the extension of the PreTrainedModel base classes with PeftAdapterMixin, it seems that it would be possible to extend it dynamically:

https://stackoverflow.com/questions/11042424/adding-base-class-to-existing-object-in-python

So in theory, we could check whether peft is installed and only then add the mixin to the base classes, which means that PreTrainedModel would stay the same for all users who don't use peft. I'm not sure if that's a good idea or not, just throwing it out there.
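To make that idea concrete, here is a rough instance-level sketch following the pattern from that answer (the helper name is made up and the PeftAdapterMixin import path is assumed):

def maybe_add_peft_mixin(model):
    # made-up helper: extend a model instance with the mixin only when peft is importable,
    # leaving PreTrainedModel itself untouched otherwise
    try:
        import peft  # noqa: F401
    except ImportError:
        return model

    from transformers.integrations import PeftAdapterMixin  # assumed import path

    if not isinstance(model, PeftAdapterMixin):
        # swap in a dynamically created subclass that also carries the mixin
        model.__class__ = type(
            model.__class__.__name__, (PeftAdapterMixin, model.__class__), {}
        )
    return model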

"""
Load adapter weights from file. Requires peft as a backend to load the adapter weights
"""
requires_backends(self.load_adapter, "peft")
Member

Technically, this also requires a specific peft version, right? But it seems that requiring a version is not supported.

Contributor Author

Yes, I can add that in the next line.
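Something along these lines, for example (the minimum version below is just a placeholder):

import importlib.metadata

from packaging import version

MIN_PEFT_VERSION = "0.4.0"  # placeholder minimum version


def check_peft_version():
    # raise a clear error if the installed peft is too old for the integration
    installed = version.parse(importlib.metadata.version("peft"))
    if installed < version.parse(MIN_PEFT_VERSION):
        raise ImportError(
            f"This feature requires peft>={MIN_PEFT_VERSION}, but peft=={installed} is installed. "
            "Please upgrade with `pip install -U peft`."
        )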

Contributor

Agreed, it would be nice to get a very good error message here.

Contributor

@pacman100 pacman100 left a comment

I really like the design, very clever way to use class methods to inject the modules and load multiple adapters. I believe with some effort, it can support all the goodies that aren't yet present such as merge_and_unload, unload, create_weighted_adapter and delete_adapter. We can take these up later on. Overall, great work @younesbelkada 🔥.

Btw, to answer Sylvain's question, we can also use load_adapter in from_pretrained so that only one line of code is needed instead of two. WDYT?
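Purely as an illustration of that idea (not the actual implementation), the flow could look like this: resolve the base model from the adapter config, then attach the adapter.

from transformers import AutoModelForCausalLM
from peft import PeftConfig


def from_pretrained_with_adapter(adapter_model_id, **kwargs):
    # illustrative helper: load the base model pointed to by the adapter config,
    # then attach the adapter in a single call
    peft_config = PeftConfig.from_pretrained(adapter_model_id)
    model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, **kwargs)
    model.load_adapter(adapter_model_id)
    return model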

)

# TODO: change it to a property but torch.jit fails. Maybe we should return None if PEFT is not available
def current_active_adapter(self) -> str:
Contributor Author

If we really want to have a property, maybe return None + a warning in case PEFT is not installed.

Comment on lines 91 to 126
create_and_replace(loaded_peft_config, self, adapter_name)

adapter_state_dict = load_peft_weights(
    peft_model_id,
    revision=revision,
    use_auth_token=use_auth_token,
)

# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {}
for key, value in adapter_state_dict.items():
    if "base_model.model" in key:
        new_key = key.replace("base_model.model.", "")
    else:
        new_key = key
    processed_adapter_state_dict[new_key] = value

# Load state dict
incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)

if incompatible_keys is not None:
    # check only for unexpected keys
    if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
        logger.warning(
            f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
            f" {incompatible_keys.unexpected_keys}. "
        )

# @pacman100 why this was needed?
if (
    (getattr(self, "hf_device_map", None) is not None)
    and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
    and len(self.peft_config) == 1
):
    self._dispatch_accelerate_model(
        device_map=device_map, max_memory=max_memory, offload_dir=offload_dir, offload_index=offload_index
Contributor

I think PeftModel.from_pretrained(self, peft_model_id) addresses all of this, WDYT?

Contributor Author

Yes, correct! However, I preferred not to use PeftModel so that we keep granular control over each step (inject adapter, get state_dict, load it, dispatch if CPU offload); that will make it easier to maintain if we add more features in the future or if there are important changes in PeftModel.from_pretrained.

Contributor

But on the contrary, if we fix PeftModel.from_pretrained, we also need to update all the code here, which increases maintenance and the risk of breaking changes: PeftModel.from_pretrained would always be updated to stay in sync with changes to PeftModel, whereas here we might get an updated version of PeftModel that the current loading logic cannot handle. WDYT?

@younesbelkada younesbelkada marked this pull request as ready for review August 18, 2023 08:24
@younesbelkada younesbelkada requested a review from sgugger August 18, 2023 08:24
Collaborator

@sgugger sgugger left a comment

Thanks for all your work on this! I have a couple of nits; also make sure to replace use_auth_token with token.

Comment on lines +211 to +216
<!--
TODO: (@younesbelkada @stevhliu)
- Link to PEFT docs for further details
- Trainer
- 8-bit / 4-bit examples ?
-->
Collaborator

Would be nice in a followup PR to add docs about training a model in 4/8-bit with peft and LoRA as there are many issues around this :-)

Contributor Author

Agreed on this!
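For reference, a rough sketch of the kind of snippet such docs could show (4-bit loading plus a fresh LoRA adapter; the parameter values are illustrative):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig

model_id = "facebook/opt-350m"

# load the base model in 4-bit with bitsandbytes
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map="auto")

# attach a trainable LoRA adapter on top of the quantized weights
model.add_adapter(LoraConfig(target_modules=["q_proj", "k_proj"]))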

@pacman100
Contributor

Thank you @younesbelkada for the impressive work on adding PEFT as a utility library in Transformers 🔥🚀✨

@younesbelkada younesbelkada mentioned this pull request Aug 22, 2023
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* a draft version

* v2 integration

* fix

* make it more generic and works for IA3

* add set adapter and multiple adapters support

* fixup

* adapt a bit

* oops

* oops

* oops

* adapt more

* fix

* add more refactor

* now works with model class

* change it to instance method as it causes issues with `jit`.

* add CR

* change method name

* add `add_adapter` method

* clean up

* Update src/transformers/adapters/peft_mixin.py

Co-authored-by: Patrick von Platen <[email protected]>

* add moe utils

* fixup

* Update src/transformers/adapters/peft_mixin.py

Co-authored-by: Patrick von Platen <[email protected]>

* adapt

* oops

* fixup

* add is_peft_available

* remove `requires_backend`

* trainer compatibility

* fixup + docstring

* more details

* trigger CI

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* Update src/transformers/modeling_utils.py

* fixup + is_main_process

* added `save_peft_format` in save_pretrained

* up

* fix nits here and there

* nits here and there.

* docs

* revert `encoding="utf-8"`

* comment

* added slow tests before the PEFT release.

* fixup and nits

* let's be on the safe zone

* added more comments

* v1 docs

* add remaining docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* move to `lib_integrations`

* fixup

* this time fixup

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>

* address final comments

* refactor to use `token`

* add PEFT to DockerFile for slow tests.

* added pipeline support.

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
@BoccheseGiacomo
(quotes the full PR description above)

What packages do I need to install to run this code?

@younesbelkada
Contributor Author

What packages do I need to install to run this code?

Just install the latest transformers & peft

pip install -U peft transformers

Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025
… method (huggingface#749)

Refactors the internals of some PEFT models a bit and introduces a new
method, inject_adapter_in_model, for users who want to pass a bare model
and a peft config to inject adapters in-place into the model. These
changes are fully backward compatible with previous PEFT versions.

This PR makes things easier for the PEFT integration in transformers
huggingface/transformers#25077

The main goal of the PR is to expose a new API for advanced users who
want to integrate a PEFT method without needing the PeftModel wrapper. A
simple use case is someone who wants to inject adapters into a model
while keeping the model's original class, without having to hand the
model over to peft, which would create a PeftModel. I have faced this
issue in huggingface/transformers#25077. Among other things, this PR
refactors some internals of the PEFT library while keeping it fully
backward compatible.

To tackle the main motivation, I propose to differentiate between two
types of adapters:

1- adapters that are injectable (LoRA, AdaLoRA, IA3)
2- adapters that are not injectable (the rest)

As a first iteration this API would be supported only for scenario 1-,
therefore I decided to create 2 abstract classes to make it easy to
determine whether an adapter layer (e.g. LoraLayer) / adapter module
(e.g. LoraModel) follows the minimal requirements (i.e. needed
attributes, etc.)

Other related changes:

1- Creates a new property method is_prompt_learning to avoid importing
   PromptLearningConfig all the way down
2- Introduces a new object TUNERS_MAPPING, which is a mapping of
   supported pluggable adapters
3- Creates two abstract classes
3.1- BaseTunerLayer: a mixin to check for minimal required attributes
     that a tuner layer should have active_adapter / _is_plugable
3.2- BaseTuner: a higher level module mixin that should be used for any
     injectable adapters in the future.

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
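To illustrate what inject_adapter_in_model enables, a minimal usage sketch (assuming a peft version that ships the method):

from transformers import AutoModelForCausalLM
from peft import LoraConfig, inject_adapter_in_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
lora_config = LoraConfig(target_modules=["q_proj", "k_proj"])

# inject LoRA layers in-place; the object keeps its original transformers class
model = inject_adapter_in_model(lora_config, model)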
csqaiub added a commit to csqaiub/peft that referenced this pull request Sep 28, 2025
… method (#749)

(same commit message as above)