-
Notifications
You must be signed in to change notification settings - Fork 2.2k
[core] PEFT refactor + introducing inject_adapter_in_model public method
#749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
4cb6f6e
d41717c
375a238
a08e858
420d39c
9a9138e
088a1c1
57fa268
d04a91b
4a08ecb
1e99c8e
6acfeff
10a9ab2
8fc3a63
2817d2f
66a5e3d
221ab34
3da5b6d
8827879
2cac1d6
2d54ea6
b8bd2e2
8efb5f4
a92e5ea
45e8877
dbf764f
88c91ae
4ce8541
ead8545
c3e6b61
7648c27
8f09a8e
4d1dffb
a6be082
a032f62
fccae04
2e632f1
f8fcba7
84a56e0
8b88401
8053c88
b4adb6d
7171188
758ecb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| from typing import TYPE_CHECKING, Any, Dict | ||
|
|
||
| from .config import PeftConfig | ||
| from .peft_model import ( | ||
| PeftModel, | ||
| PeftModelForCausalLM, | ||
|
|
@@ -27,6 +28,7 @@ | |
| PeftModelForTokenClassification, | ||
| ) | ||
| from .tuners import ( | ||
| TUNERS_MAPPING, | ||
| AdaLoraConfig, | ||
| AdaptionPromptConfig, | ||
| IA3Config, | ||
|
|
@@ -35,14 +37,12 @@ | |
| PromptEncoderConfig, | ||
| PromptTuningConfig, | ||
| ) | ||
| from .utils import PromptLearningConfig, _prepare_prompt_learning_config | ||
| from .utils import _get_submodules, _prepare_prompt_learning_config | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers import PreTrainedModel | ||
|
|
||
| from .utils.config import PeftConfig | ||
|
|
||
|
|
||
| MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { | ||
| "SEQ_CLS": PeftModelForSequenceClassification, | ||
|
|
@@ -89,10 +89,51 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name | |
|
|
||
| peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) | ||
|
|
||
| if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance( | ||
| peft_config, PromptLearningConfig | ||
| ): | ||
| if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: | ||
| return PeftModel(model, peft_config, adapter_name=adapter_name) | ||
| if isinstance(peft_config, PromptLearningConfig): | ||
| if peft_config.is_prompt_learning: | ||
| peft_config = _prepare_prompt_learning_config(peft_config, model_config) | ||
| return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) | ||
|
|
||
|
|
||
def create_and_replace(peft_config: PeftConfig, model, adapter_name: str) -> None:
    """Inject the adapter layers described by ``peft_config`` into ``model`` in place.

    Freezes every base-model parameter (modules flagged with
    ``_is_peft_tuner_layer`` are left trainable), then walks the model's
    modules and, for each key matching the config's target modules, delegates
    to the tuner class's ``create_and_replace`` to swap in the adapter layer.

    Args:
        peft_config: The adapter configuration; must be a ``PeftConfig``
            instance whose ``peft_type`` is registered in ``TUNERS_MAPPING``.
        model: The base model to modify in place. Presumably a
            ``torch.nn.Module`` / ``transformers.PreTrainedModel`` — it must
            expose ``modules()`` and ``named_modules()``.
        adapter_name: Name under which the adapter is registered on each
            replaced module.

    Raises:
        ValueError: If ``peft_config`` is not a ``PeftConfig``, if its
            ``peft_type`` is not supported, or if no module of ``model``
            matches ``peft_config.target_modules``.
    """
    if not isinstance(peft_config, PeftConfig):
        raise ValueError(f"peft_config must be an instance of PeftConfig got {type(peft_config)} instead.")

    peft_type = peft_config.peft_type

    if peft_type not in TUNERS_MAPPING:
        # NOTE: the mapping is keyed by PEFT type, not task type.
        raise ValueError(
            f"PEFT type {peft_type} is not supported. Supported PEFT types are {list(TUNERS_MAPPING.keys())}"
        )
    tuner_cls = TUNERS_MAPPING[peft_type]

    # Freeze the whole base model; tuner layers opt out of freezing via the
    # `_is_peft_tuner_layer` marker attribute.
    # TODO(review): add a test covering the freezing behavior.
    for module in model.modules():
        if not getattr(module, "_is_peft_tuner_layer", False):
            module.requires_grad_(False)

    is_target_modules_in_base_model = False
    # Snapshot the keys first: the loop below mutates the module tree, so we
    # must not iterate `named_modules()` lazily while replacing children.
    key_list = [key for key, _ in model.named_modules()]

    for key in key_list:
        if not tuner_cls._check_target_module_exists(peft_config, key):
            continue

        is_target_modules_in_base_model = True
        parent, target, target_name = _get_submodules(model, key)

        # Fixed typo: was `optionnal_kwargs`.
        optional_kwargs = {
            "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
            "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False),
            "current_key": key,
        }

        tuner_cls.create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)

    if not is_target_modules_in_base_model:
        raise ValueError(
            f"Target modules {peft_config.target_modules} not found in the base model. "
            f"Please check the target modules and try again."
        )
Uh oh!
There was an error while loading. Please reload this page.