-
Notifications
You must be signed in to change notification settings - Fork 31.9k
[PEFT] Peft integration alternative design
#25077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
297e311
cedeeda
cf3c325
c6d0848
0343763
867ae2c
72762de
02b5802
067fef4
e8cc945
d371173
08043d5
619a5d6
e61de3b
59b3cb3
da8dfc5
e67c3c3
eb57382
babb278
e038629
81fcf40
3cbd3c2
2345681
dfb6425
eddabd2
9e98c08
9523cd0
715d03b
38e1fe7
300243b
ec51272
7c1dc8a
e251f43
5703344
324e18d
99f6905
22284e6
35fe154
eb9efed
a8eb928
f310b33
8333a65
38969ef
a4a361d
b19bc08
6f703c7
4147341
cd99439
1fb2b9f
c0e2815
0b11f1b
1b5c501
583174f
180545f
83d0f15
f739aee
fccf419
fb6af42
616cfec
70b1570
2934e69
3dd9211
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from .peft_mixin import PeftAdapterMixin |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| # Copyright 2023 The HuggingFace Team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| from ..utils import ADAPTER_CONFIG_NAME, cached_file, logging, requires_backends | ||
|
|
||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| class PeftAdapterMixin: | ||
| """ | ||
| A class containing all functions for loading and using adapters weights that are supported in PEFT library. | ||
| Currently supported PEFT methods are all non-prefix tuning methods | ||
| """ | ||
|
|
||
| def load_adapter( | ||
| self, | ||
| peft_model_id: str, | ||
| adapter_name: Optional[str] = "default", | ||
| revision: Optional[str] = None, | ||
| use_auth_token: Optional[str] = None, | ||
| commit_hash: Optional[str] = None, | ||
| ): | ||
| """ | ||
| Load adapter weights from file. Requires peft as a backend to load the adapter weights | ||
| """ | ||
| requires_backends(self.load_adapter, "peft") | ||
|
||
|
|
||
| from peft import LoraConfig, PeftModel, create_and_replace | ||
| from peft.utils import set_peft_model_state_dict | ||
| from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING | ||
|
|
||
| self.peft_config = {} | ||
younesbelkada marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| adapter_config_file = self._find_adapter_config_file( | ||
| peft_model_id, | ||
| revision=revision, | ||
| use_auth_token=use_auth_token, | ||
| commit_hash=commit_hash, | ||
| ) | ||
|
|
||
| if adapter_config_file is None: | ||
| raise ValueError( | ||
| f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " | ||
| "adapter model." | ||
| ) | ||
|
|
||
| # TODO: automatically infer the correct config class | ||
| loaded_peft_config = LoraConfig.from_pretrained( | ||
| peft_model_id, | ||
| revision=revision, | ||
| use_auth_token=use_auth_token, | ||
| commit_hash=commit_hash, | ||
| ) | ||
|
|
||
| if not hasattr(loaded_peft_config, "target_modules"): | ||
| target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[self.config.model_type] | ||
| loaded_peft_config.target_modules = target_modules | ||
|
|
||
| # TODO: constraint this to single adapter | ||
| if adapter_name not in self.peft_config: | ||
| self.peft_config[adapter_name] = loaded_peft_config | ||
| else: | ||
younesbelkada marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") | ||
|
|
||
| # Replace the adapter with the loaded adapter | ||
| create_and_replace(loaded_peft_config.peft_type, loaded_peft_config, self, adapter_name) | ||
|
|
||
| # TODO: move that to peft.utils | ||
| adapter_state_dict = PeftModel._get_peft_state_dict( | ||
| peft_model_id, | ||
| revision=revision, | ||
| use_auth_token=use_auth_token, | ||
| ) | ||
|
|
||
| # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility | ||
| processed_adapter_state_dict = {} | ||
| for key, value in adapter_state_dict.items(): | ||
| if "base_model.model" in key: | ||
| new_key = key.replace("base_model.model.", "") | ||
younesbelkada marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else: | ||
| new_key = key | ||
| processed_adapter_state_dict[new_key] = value | ||
|
|
||
| # Load state dict | ||
| incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name) | ||
|
|
||
| if incompatible_keys is not None: | ||
| # check only for unexpected keys | ||
| if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: | ||
| logger.warning( | ||
| f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: " | ||
| f" {incompatible_keys.unexpected_keys}. " | ||
| ) | ||
|
|
||
| def _find_adapter_config_file( | ||
| self, | ||
| model_id: str, | ||
| revision: str = None, | ||
| use_auth_token: Optional[str] = None, | ||
| commit_hash: Optional[str] = None, | ||
| ) -> Optional[str]: | ||
| r""" | ||
younesbelkada marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path the the | ||
| adapter config file if it is, None otherwise. | ||
| """ | ||
| adapter_cached_filename = None | ||
| if os.path.isdir(model_id): | ||
| list_remote_files = os.listdir(model_id) | ||
| if ADAPTER_CONFIG_NAME in list_remote_files: | ||
| adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME) | ||
| else: | ||
| adapter_cached_filename = cached_file( | ||
| model_id, | ||
| ADAPTER_CONFIG_NAME, | ||
| revision=revision, | ||
| use_auth_token=use_auth_token, | ||
| _commit_hash=commit_hash, | ||
| _raise_exceptions_for_missing_entries=False, | ||
| _raise_exceptions_for_connection_errors=False, | ||
| ) | ||
|
|
||
| return adapter_cached_filename | ||
Uh oh!
There was an error while loading. Please reload this page.