Merged

Commits (62 total; the diff below shows changes from 3 of them):
297e311  a draft version (younesbelkada, Jul 19, 2023)
cedeeda  v2 integration (younesbelkada, Jul 25, 2023)
cf3c325  fix (younesbelkada, Jul 25, 2023)
c6d0848  make it more generic and works for IA3 (younesbelkada, Jul 26, 2023)
0343763  Merge remote-tracking branch 'upstream/main' into peft-integration-at… (younesbelkada, Jul 26, 2023)
867ae2c  add set adapter and multiple adapters support (younesbelkada, Jul 26, 2023)
72762de  fixup (younesbelkada, Jul 26, 2023)
02b5802  adapt a bit (younesbelkada, Jul 27, 2023)
067fef4  oops (younesbelkada, Jul 27, 2023)
e8cc945  oops (younesbelkada, Jul 27, 2023)
d371173  oops (younesbelkada, Jul 31, 2023)
08043d5  adapt more (younesbelkada, Jul 31, 2023)
619a5d6  fix (younesbelkada, Jul 31, 2023)
e61de3b  add more refactor (younesbelkada, Jul 31, 2023)
59b3cb3  now works with model class (younesbelkada, Jul 31, 2023)
da8dfc5  Merge remote-tracking branch 'upstream/main' into peft-integration-at… (younesbelkada, Aug 1, 2023)
e67c3c3  change it to instance method as it causes issues with `jit`. (younesbelkada, Aug 1, 2023)
eb57382  add CR (younesbelkada, Aug 1, 2023)
babb278  change method name (younesbelkada, Aug 2, 2023)
e038629  add `add_adapter` method (younesbelkada, Aug 2, 2023)
81fcf40  clean up (younesbelkada, Aug 2, 2023)
3cbd3c2  Update src/transformers/adapters/peft_mixin.py (younesbelkada, Aug 3, 2023)
2345681  add moe utils (younesbelkada, Aug 3, 2023)
dfb6425  Merge branch 'peft-integration-attempt-2' of https://github.com/youne… (younesbelkada, Aug 3, 2023)
eddabd2  fixup (younesbelkada, Aug 3, 2023)
9e98c08  Update src/transformers/adapters/peft_mixin.py (younesbelkada, Aug 3, 2023)
9523cd0  adapt (younesbelkada, Aug 3, 2023)
715d03b  oops (younesbelkada, Aug 3, 2023)
38e1fe7  fixup (younesbelkada, Aug 3, 2023)
300243b  add is_peft_available (younesbelkada, Aug 3, 2023)
ec51272  remove `requires_backend` (younesbelkada, Aug 3, 2023)
7c1dc8a  trainer compatibility (younesbelkada, Aug 3, 2023)
e251f43  fixup + docstring (younesbelkada, Aug 3, 2023)
5703344  more details (younesbelkada, Aug 3, 2023)
324e18d  trigger CI (younesbelkada, Aug 3, 2023)
99f6905  Merge remote-tracking branch 'upstream/main' into peft-integration-at… (younesbelkada, Aug 3, 2023)
22284e6  Apply suggestions from code review (younesbelkada, Aug 3, 2023)
35fe154  Update src/transformers/modeling_utils.py (younesbelkada, Aug 3, 2023)
eb9efed  fixup + is_main_process (younesbelkada, Aug 3, 2023)
a8eb928  added `save_peft_format` in save_pretrained (younesbelkada, Aug 3, 2023)
f310b33  up (younesbelkada, Aug 3, 2023)
8333a65  fix nits here and there (younesbelkada, Aug 3, 2023)
38969ef  nits here and there. (younesbelkada, Aug 3, 2023)
a4a361d  docs (younesbelkada, Aug 3, 2023)
b19bc08  revert `encoding="utf-8"` (younesbelkada, Aug 3, 2023)
6f703c7  comment (younesbelkada, Aug 3, 2023)
4147341  added slow tests before the PEFT release. (younesbelkada, Aug 3, 2023)
cd99439  fixup and nits (younesbelkada, Aug 3, 2023)
1fb2b9f  let's be on the safe zone (younesbelkada, Aug 3, 2023)
c0e2815  added more comments (younesbelkada, Aug 3, 2023)
0b11f1b  v1 docs (younesbelkada, Aug 3, 2023)
1b5c501  add remaining docs (stevhliu, Aug 7, 2023)
583174f  Apply suggestions from code review (younesbelkada, Aug 17, 2023)
180545f  Merge remote-tracking branch 'upstream/main' into peft-integration-at… (younesbelkada, Aug 18, 2023)
83d0f15  move to `lib_integrations` (younesbelkada, Aug 18, 2023)
f739aee  fixup (younesbelkada, Aug 18, 2023)
fccf419  this time fixup (younesbelkada, Aug 18, 2023)
fb6af42  Apply suggestions from code review (younesbelkada, Aug 18, 2023)
616cfec  address final comments (younesbelkada, Aug 18, 2023)
70b1570  refactor to use `token` (younesbelkada, Aug 18, 2023)
2934e69  add PEFT to DockerFile for slow tests. (younesbelkada, Aug 18, 2023)
3dd9211  added pipeline support. (younesbelkada, Aug 18, 2023)
15 changes: 15 additions & 0 deletions src/transformers/adapters/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .peft_mixin import PeftAdapterMixin
136 changes: 136 additions & 0 deletions src/transformers/adapters/peft_mixin.py
@@ -0,0 +1,136 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional

from ..utils import ADAPTER_CONFIG_NAME, cached_file, logging, requires_backends


logger = logging.get_logger(__name__)


class PeftAdapterMixin:
    """
    A mixin containing all the functions for loading and using adapter weights supported by the PEFT library.
    Currently supported PEFT methods are all non-prefix-tuning methods.
    """

    def load_adapter(
        self,
        peft_model_id: str,
        adapter_name: Optional[str] = "default",
        revision: Optional[str] = None,
        use_auth_token: Optional[str] = None,
        commit_hash: Optional[str] = None,
    ):
        """
        Load adapter weights from file. Requires peft as a backend to load the adapter weights.
        """
        requires_backends(self.load_adapter, "peft")
Review thread on the `requires_backends` line:

Review comment (Member): Technically, this also requires a specific peft version, right? But it seems that requiring a version is not supported.

Reply (Contributor Author): Yes, I can add that in the next line.

Reply (Contributor): Agree, it would be nice to get a very good error message here.
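A minimal sketch of what such a version gate could look like (a hypothetical helper, not part of this diff; assumes the `packaging` package is installed, and the minimum version below is a placeholder):

import importlib.metadata

from packaging import version

MIN_PEFT_VERSION = "0.4.0"  # hypothetical minimum; the real requirement may differ

def check_peft_version(min_version: str = MIN_PEFT_VERSION) -> None:
    """Raise an informative error when peft is missing or too old."""
    try:
        peft_version = importlib.metadata.version("peft")
    except importlib.metadata.PackageNotFoundError:
        raise ImportError("PEFT is not installed. Install it with `pip install peft`.")
    if version.parse(peft_version) < version.parse(min_version):
        raise ImportError(
            f"Adapter loading requires peft>={min_version}, but peft=={peft_version} was found. "
            "Upgrade it with `pip install -U peft`."
        )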


        from peft import LoraConfig, PeftModel, create_and_replace
        from peft.utils import set_peft_model_state_dict
        from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

        self.peft_config = {}

        adapter_config_file = self._find_adapter_config_file(
            peft_model_id,
            revision=revision,
            use_auth_token=use_auth_token,
            commit_hash=commit_hash,
        )

        if adapter_config_file is None:
            raise ValueError(
                f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
                "adapter model."
            )

        # TODO: automatically infer the correct config class
        loaded_peft_config = LoraConfig.from_pretrained(
            peft_model_id,
            revision=revision,
            use_auth_token=use_auth_token,
            commit_hash=commit_hash,
        )

        if not hasattr(loaded_peft_config, "target_modules"):
            target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[self.config.model_type]
            loaded_peft_config.target_modules = target_modules

        # TODO: constrain this to a single adapter
        if adapter_name not in self.peft_config:
            self.peft_config[adapter_name] = loaded_peft_config
        else:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        # Replace the adapter with the loaded adapter
        create_and_replace(loaded_peft_config.peft_type, loaded_peft_config, self, adapter_name)

        # TODO: move that to peft.utils
        adapter_state_dict = PeftModel._get_peft_state_dict(
            peft_model_id,
            revision=revision,
            use_auth_token=use_auth_token,
        )

        # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
        processed_adapter_state_dict = {}
        for key, value in adapter_state_dict.items():
            if "base_model.model" in key:
                new_key = key.replace("base_model.model.", "")
            else:
                new_key = key
            processed_adapter_state_dict[new_key] = value

        # Load state dict
        incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)

        if incompatible_keys is not None:
            # check only for unexpected keys
            if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
                logger.warning(
                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
                    f"{incompatible_keys.unexpected_keys}."
                )

    def _find_adapter_config_file(
        self,
        model_id: str,
        revision: Optional[str] = None,
        use_auth_token: Optional[str] = None,
        commit_hash: Optional[str] = None,
    ) -> Optional[str]:
        r"""
        Checks whether the model stored on the Hub or locally is an adapter model, and returns the path to the
        adapter config file if it is, None otherwise.
        """
        adapter_cached_filename = None
        if os.path.isdir(model_id):
            list_local_files = os.listdir(model_id)
            if ADAPTER_CONFIG_NAME in list_local_files:
                adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
        else:
            adapter_cached_filename = cached_file(
                model_id,
                ADAPTER_CONFIG_NAME,
                revision=revision,
                use_auth_token=use_auth_token,
                _commit_hash=commit_hash,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
            )

        return adapter_cached_filename
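End-to-end, the mixin lets a base transformers model pull PEFT adapter weights directly. A hedged usage sketch (the adapter repo id is a placeholder, and the exact signature may differ from the merged API):

from transformers import AutoModelForCausalLM

# Base model; the adapter below is a placeholder LoRA repo trained on top of it
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter("ybelkada/opt-350m-lora", adapter_name="default")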
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
@@ -34,6 +34,7 @@
 from torch.nn import CrossEntropyLoss

 from .activations import get_activation
+from .adapters import PeftAdapterMixin
 from .configuration_utils import PretrainedConfig
 from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
 from .dynamic_module_utils import custom_object_save

@@ -1021,7 +1022,7 @@ def floating_point_ops(
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
     r"""
     Base class for all models.
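Because PreTrainedModel now inherits from PeftAdapterMixin, every concrete model class picks up the adapter-loading API with no per-model changes. A quick hedged sanity check (any small checkpoint works; the one below is only for illustration):

from transformers import AutoModel

model = AutoModel.from_pretrained("distilbert-base-uncased")
print(hasattr(model, "load_adapter"))  # True: inherited through PeftAdapterMixin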
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -989,6 +989,11 @@ def is_jieba_available():
 jieba`. Please note that you may need to restart your runtime after installation.
 """

+PEFT_IMPORT_ERROR = """
+{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
+peft`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),

@@ -1022,6 +1027,7 @@ def is_jieba_available():
         ("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
         ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
         ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
+        ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
     ]
 )
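This mapping is what backs the `requires_backends(self.load_adapter, "peft")` call in the mixin: each entry pairs an availability check with an error template, and the helper raises when a requested backend is missing. A simplified sketch of that mechanism (not the exact transformers implementation):

from collections import OrderedDict

def is_peft_available() -> bool:
    # Simplified availability probe; transformers uses importlib machinery internally
    try:
        import peft  # noqa: F401
        return True
    except ImportError:
        return False

PEFT_IMPORT_ERROR = "{0} requires the peft library but it was not found in your environment."

BACKENDS_MAPPING = OrderedDict([("peft", (is_peft_available, PEFT_IMPORT_ERROR))])

def requires_backends(obj, backends):
    # Collect the error message for every requested backend that is unavailable
    if isinstance(backends, str):
        backends = [backends]
    name = getattr(obj, "__name__", obj.__class__.__name__)
    failed = [
        template.format(name)
        for available, template in (BACKENDS_MAPPING[b] for b in backends)
        if not available()
    ]
    if failed:
        raise ImportError("".join(failed))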