[core] PEFT refactor + introducing `inject_adapter_in_model` public method (#749)
Changes from all commits
```diff
@@ -17,6 +17,9 @@
 from typing import TYPE_CHECKING, Any, Dict
 
+import torch
+
+from .config import PeftConfig
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
@@ -28,21 +31,22 @@
 )
 from .tuners import (
     AdaLoraConfig,
+    AdaLoraModel,
     AdaptionPromptConfig,
     IA3Config,
+    IA3Model,
     LoraConfig,
+    LoraModel,
     PrefixTuningConfig,
     PromptEncoderConfig,
     PromptTuningConfig,
 )
-from .utils import PromptLearningConfig, _prepare_prompt_learning_config
+from .utils import _prepare_prompt_learning_config
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-
-    from .utils.config import PeftConfig
 
 
 MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
     "SEQ_CLS": PeftModelForSequenceClassification,
@@ -63,6 +67,12 @@
     "IA3": IA3Config,
 }
 
+PEFT_TYPE_TO_TUNER_MAPPING = {
+    "LORA": LoraModel,
+    "ADALORA": AdaLoraModel,
+    "IA3": IA3Model,
+}
+
 
 def get_peft_config(config_dict: Dict[str, Any]):
     """
@@ -89,10 +99,38 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name
 
     peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
 
-    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
-        peft_config, PromptLearningConfig
-    ):
+    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
         return PeftModel(model, peft_config, adapter_name=adapter_name)
-    if isinstance(peft_config, PromptLearningConfig):
+    if peft_config.is_prompt_learning:
         peft_config = _prepare_prompt_learning_config(peft_config, model_config)
     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
+
+
+def inject_adapter_in_model(peft_config: PeftConfig, model: torch.nn.Module, adapter_name: str):
+    r"""
+    A simple API to create and inject an adapter in-place into a model. The API currently does not support prompt
+    learning methods or adaption prompt. Make sure to have the correct `target_modules` set in the `peft_config`
+    object. Unlike `get_peft_model`, this API instantiates the tuner class directly, so it is restricted to
+    non-prompt-learning methods.
+
+    Args:
+        peft_config (`PeftConfig`):
+            Configuration object containing the parameters of the Peft model.
+        model (`torch.nn.Module`):
+            The input model where the adapter will be injected.
+        adapter_name (`str`):
+            The name of the adapter to be injected.
+    """
+    if peft_config.is_prompt_learning or peft_config.is_adaption_prompt:
+        raise ValueError("`inject_adapter_in_model` does not support prompt learning and adaption prompt yet.")
+
+    if peft_config.peft_type not in PEFT_TYPE_TO_TUNER_MAPPING.keys():
+        raise ValueError(
+            f"`inject_adapter_in_model` does not support {peft_config.peft_type} yet. Please use `get_peft_model`."
+        )
+
+    tuner_cls = PEFT_TYPE_TO_TUNER_MAPPING[peft_config.peft_type]
+
+    # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
+    peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)
```
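One effect of this refactor is that `isinstance(peft_config, PromptLearningConfig)` checks are replaced by properties on the config object itself. A minimal sketch of the new check, assuming `is_prompt_learning` is exposed as a config property as the diff above suggests:

```python
from peft import LoraConfig, PrefixTuningConfig

# LoRA is not a prompt learning method; prefix tuning is.
lora_config = LoraConfig(target_modules=["q_proj", "v_proj"])
prefix_config = PrefixTuningConfig(num_virtual_tokens=20)

assert not lora_config.is_prompt_learning
assert prefix_config.is_prompt_learning
```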
|
Contributor:

If I went down the path correctly, it looks like the line `peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name)` will freeze the base model, i.e. set `requires_grad=False` on its parameters. Also it seems like the peft config decides this further down the road (see line 203 in ec267c6). Can we add a flag here that would allow keeping the base model trainable? Also I'd maybe add a comment to make it a bit easier for the user to understand what this line does.
Contributor (Author):

I think you are right, freezing the base model should be optional when you load a pretrained adapter. However, when you attach fresh new adapters, it is usually (for 99% of the use cases) in order to train them, so maybe we should make a distinction:

1- load a pretrained adapter --> not necessarily freeze the base model
2- attach fresh new adapters --> freeze the base model

WDYT? @pacman100 @BenjaminBossan
Contributor:

That makes sense. I would maybe give the user full flexibility here then and add a new function argument for it.
Member:

Personally, I would prefer to give the user an explicit argument, rather than (or in addition to) trying to guess based on the circumstances, so I would go with that option.

@patrickvonplaten: You bring up some good points. I will still merge the PR as is, because Younes will be OoO for some time and we don't want this big refactor to become stale. We should be able to address your concerns in a follow-up PR.
Contributor:

Sounds good! Agree that giving the user an explicit argument is the right option here.
```diff
+
+    return peft_model.model
```
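For context, a rough usage sketch of the new API. The `DummyModel` class and the target module name are illustrative, not taken from this PR, and it assumes `inject_adapter_in_model` is re-exported from the `peft` package root:

```python
import torch

from peft import LoraConfig, inject_adapter_in_model


class DummyModel(torch.nn.Module):
    """Toy model used only to demonstrate the API."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 10)
        self.linear = torch.nn.Linear(10, 10)
        self.lm_head = torch.nn.Linear(10, 10)

    def forward(self, input_ids):
        x = self.embedding(input_ids)
        x = self.linear(x)
        return self.lm_head(x)


lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["linear"],  # must match module names inside the model
)

model = inject_adapter_in_model(lora_config, DummyModel(), adapter_name="my_adapter")

# The adapter was injected in-place: `model.linear` is now a LoRA layer
# wrapping the original nn.Linear, and a forward pass works as before.
out = model(torch.LongTensor([[0, 1, 2, 3]]))
```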
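Until the explicit argument discussed in the review thread exists, a user-side workaround can undo the freezing after injection. A minimal sketch, assuming a LoRA adapter; the wrapper below is hypothetical and not part of this PR:

```python
import torch

from peft import inject_adapter_in_model


def inject_adapter_keep_base_trainable(peft_config, model: torch.nn.Module, adapter_name: str):
    """Hypothetical wrapper sketching the explicit 'do not freeze' behaviour discussed above."""
    model = inject_adapter_in_model(peft_config, model, adapter_name)
    # Instantiating the tuner froze every non-adapter parameter; re-enable
    # gradients on parameters that do not belong to the LoRA layers.
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = True
    return model
```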