Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
297e311
a draft version
younesbelkada Jul 19, 2023
cedeeda
v2 integration
younesbelkada Jul 25, 2023
cf3c325
fix
younesbelkada Jul 25, 2023
c6d0848
make it more generic and works for IA3
younesbelkada Jul 26, 2023
0343763
Merge remote-tracking branch 'upstream/main' into peft-integration-at…
younesbelkada Jul 26, 2023
867ae2c
add set adapter and multiple adapters support
younesbelkada Jul 26, 2023
72762de
fixup
younesbelkada Jul 26, 2023
02b5802
adapt a bit
younesbelkada Jul 27, 2023
067fef4
oops
younesbelkada Jul 27, 2023
e8cc945
oops
younesbelkada Jul 27, 2023
d371173
oops
younesbelkada Jul 31, 2023
08043d5
adapt more
younesbelkada Jul 31, 2023
619a5d6
fix
younesbelkada Jul 31, 2023
e61de3b
add more refactor
younesbelkada Jul 31, 2023
59b3cb3
now works with model class
younesbelkada Jul 31, 2023
da8dfc5
Merge remote-tracking branch 'upstream/main' into peft-integration-at…
younesbelkada Aug 1, 2023
e67c3c3
change it to instance method as it causes issues with `jit`.
younesbelkada Aug 1, 2023
eb57382
add CR
younesbelkada Aug 1, 2023
babb278
change method name
younesbelkada Aug 2, 2023
e038629
add `add_adapter` method
younesbelkada Aug 2, 2023
81fcf40
clean up
younesbelkada Aug 2, 2023
3cbd3c2
Update src/transformers/adapters/peft_mixin.py
younesbelkada Aug 3, 2023
2345681
add moe utils
younesbelkada Aug 3, 2023
dfb6425
Merge branch 'peft-integration-attempt-2' of https://github.com/youne…
younesbelkada Aug 3, 2023
eddabd2
fixup
younesbelkada Aug 3, 2023
9e98c08
Update src/transformers/adapters/peft_mixin.py
younesbelkada Aug 3, 2023
9523cd0
adapt
younesbelkada Aug 3, 2023
715d03b
oops
younesbelkada Aug 3, 2023
38e1fe7
fixup
younesbelkada Aug 3, 2023
300243b
add is_peft_available
younesbelkada Aug 3, 2023
ec51272
remove `requires_backend`
younesbelkada Aug 3, 2023
7c1dc8a
trainer compatibility
younesbelkada Aug 3, 2023
e251f43
fixup + docstring
younesbelkada Aug 3, 2023
5703344
more details
younesbelkada Aug 3, 2023
324e18d
trigger CI
younesbelkada Aug 3, 2023
99f6905
Merge remote-tracking branch 'upstream/main' into peft-integration-at…
younesbelkada Aug 3, 2023
22284e6
Apply suggestions from code review
younesbelkada Aug 3, 2023
35fe154
Update src/transformers/modeling_utils.py
younesbelkada Aug 3, 2023
eb9efed
fixup + is_main_process
younesbelkada Aug 3, 2023
a8eb928
added `save_peft_format` in save_pretrained
younesbelkada Aug 3, 2023
f310b33
up
younesbelkada Aug 3, 2023
8333a65
fix nits here and there
younesbelkada Aug 3, 2023
38969ef
nits here and there.
younesbelkada Aug 3, 2023
a4a361d
docs
younesbelkada Aug 3, 2023
b19bc08
revert `encoding="utf-8"`
younesbelkada Aug 3, 2023
6f703c7
comment
younesbelkada Aug 3, 2023
4147341
added slow tests before the PEFT release.
younesbelkada Aug 3, 2023
cd99439
fixup and nits
younesbelkada Aug 3, 2023
1fb2b9f
let's be on the safe zone
younesbelkada Aug 3, 2023
c0e2815
added more comments
younesbelkada Aug 3, 2023
0b11f1b
v1 docs
younesbelkada Aug 3, 2023
1b5c501
add remaining docs
stevhliu Aug 7, 2023
583174f
Apply suggestions from code review
younesbelkada Aug 17, 2023
180545f
Merge remote-tracking branch 'upstream/main' into peft-integration-at…
younesbelkada Aug 18, 2023
83d0f15
move to `lib_integrations`
younesbelkada Aug 18, 2023
f739aee
fixup
younesbelkada Aug 18, 2023
fccf419
this time fixup
younesbelkada Aug 18, 2023
fb6af42
Apply suggestions from code review
younesbelkada Aug 18, 2023
616cfec
address final comments
younesbelkada Aug 18, 2023
70b1570
refactor to use `token`
younesbelkada Aug 18, 2023
2934e69
add PEFT to DockerFile for slow tests.
younesbelkada Aug 18, 2023
3dd9211
added pipeline support.
younesbelkada Aug 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

# Base objects, independent of any specific backend
_import_structure = {
"adapters": [],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure we will have that many adapters to justify a new folder. How about an integrations folder where we would have PEFT, maybe move bitsandbytes to it and deepspeed, and more generally put submodules likes to integrations with other libs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, what about moving all other folders in a follow up PR?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be done within the same release cycle so good for me as long as this is merged shortly after the next release branch cut (so that the PR moving everything is done before the release after)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds great!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

decided to name the folder lib_integrations as a file integrations.py already exists (which created conflicts), let me know if you have a better name in mind (I can also rename integrations.py file)

"audio_utils": [],
"benchmark": [],
"commands": [],
Expand Down
15 changes: 15 additions & 0 deletions src/transformers/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .peft_mixin import PeftAdapterMixin
217 changes: 217 additions & 0 deletions src/transformers/adapters/peft_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Optional

from ..utils import find_adapter_config_file, is_accelerate_available, is_peft_available, logging, requires_backends


if is_accelerate_available():
from accelerate import dispatch_model
from accelerate.utils import get_balanced_memory, infer_auto_device_map


logger = logging.get_logger(__name__)


class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library.
Currently supported PEFT methods are all non-prefix tuning methods
"""

_hf_peft_config_loaded = False

def load_adapter(
self,
peft_model_id: str,
adapter_name: Optional[str] = "default",
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
commit_hash: Optional[str] = None,
device_map: Optional[str] = "auto",
max_memory: Optional[int] = None,
offload_dir: Optional[str] = None,
offload_index: Optional[int] = None,
) -> None:
"""
Load adapter weights from file. Requires peft as a backend to load the adapter weights
"""
requires_backends(self.load_adapter, "peft")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically, this also requires a specific peft version, right? But it seems that requiring a version is not supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I can add that in the next line

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree it would be nice to get a very good error message here


from peft import PeftConfig, create_and_replace, load_peft_weights
from peft.utils import set_peft_model_state_dict
from peft.utils.other import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING

if not self._hf_peft_config_loaded:
self.peft_config = {}
self._hf_peft_config_loaded = True

adapter_config_file = find_adapter_config_file(
peft_model_id,
revision=revision,
use_auth_token=use_auth_token,
commit_hash=commit_hash,
)

if adapter_config_file is None:
raise ValueError(
f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
"adapter model."
)

loaded_peft_config = PeftConfig.from_pretrained(
peft_model_id,
revision=revision,
use_auth_token=use_auth_token,
commit_hash=commit_hash,
)

if not hasattr(loaded_peft_config, "target_modules"):
target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[self.config.model_type]
loaded_peft_config.target_modules = target_modules

if adapter_name not in self.peft_config:
self.peft_config[adapter_name] = loaded_peft_config
else:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

# Replace the adapter with the loaded adapter
create_and_replace(loaded_peft_config, self, adapter_name)

adapter_state_dict = load_peft_weights(
peft_model_id,
revision=revision,
use_auth_token=use_auth_token,
)

# We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
processed_adapter_state_dict = {}
for key, value in adapter_state_dict.items():
if "base_model.model" in key:
new_key = key.replace("base_model.model.", "")
else:
new_key = key
processed_adapter_state_dict[new_key] = value

# Load state dict
incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)

if incompatible_keys is not None:
# check only for unexpected keys
if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
logger.warning(
f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
f" {incompatible_keys.unexpected_keys}. "
)

# @pacman100 why this was needed?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be addressed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it is addressed, I have discussed with @pacman100 offline, and this is needed to correctly dispatch the model into CPU /Disk in case the base model is offloaded

if (
(getattr(self, "hf_device_map", None) is not None)
and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
and len(self.peft_config) == 1
):
self._dispatch_accelerate_model(
device_map=device_map, max_memory=max_memory, offload_dir=offload_dir, offload_index=offload_index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think PeftModel.from_pretrained(self, peft_model_id) address all of this, WDYT?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes correct ! However I preferred to not use PeftModel and have a granular control on each thing we do (inject adapter, get state_dict, load it, dispatch if cpu offload) so that it will be easier to maintain if in the future we add more features or some important changes in PeftModel.from_pretrained.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, on the contrary, if we fix PeftModel.from_pretrained, we need to update all the changes here too increasing the maintenance and breaking changes because from_pretrained would be always updated to be in sync with changes of PeftModel whereas here we might get an updated version of PeftModel but the current loading logic might not handle it. WDYT?

)

def set_adapter(self, adapter_name: str) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is an example of an "adapter_name"?
Does the user decide themselves what the adapter name should be?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should there also be a "remove_adapter" function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes users can decide what the adapter name should be when loading an adapter or adding an adapter, and switch between them using set_adapter
I am not sure we should add a remove_adapter method as we now have a disable_adapters method here: 2345681.
What kind of usage would you see for remove_adapter?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok!

remove_adapter could be useful to make the model lighter again and eventually completely remove all adapters. But probably better in a next iteration.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK having it in the next iteration sounds good!

r"""
Sets an adapter to switch easily between multiple adapters.
"""
requires_backends(self.set_adapter, "peft")
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
elif adapter_name not in self.peft_config:
raise ValueError(
f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
)

from peft.tuners.tuners_utils import BaseTunerLayer

_adapters_has_been_set = False

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)

@property
def current_active_adapter(self) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we really want to have a property maybe return none + a warning in case PEFT is not installed

r"""
Gets the current active adapter of the model.
"""
if not is_peft_available():
raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

from peft.tuners.tuners_utils import BaseTunerLayer

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter

def _dispatch_accelerate_model(
self,
device_map: str,
max_memory: Optional[int] = None,
offload_dir: Optional[str] = None,
offload_index: Optional[int] = None,
) -> None:
r"""
Optionnal re-dispatch the model and attach new hooks to the model in case the model has been loaded with
accelerate (i.e. with `device_map=xxx`)

Args:
device_map (`str`):
The device map used to load the model with accelerate.
max_memory (`int`, `optional`):
The maximum memory argument to be passed to `accelerate.get_balanced_memory` method.
offload_dir (`str`, `optional`):
The offload_dir argument to be passed to `accelerate.dispatch_model` method.
offload_index (`int`, `optional`):
The offload_index argument to be passed to `accelerate.dispatch_model` method.
"""
dispatch_model_kwargs = {}
# Safety checker for previous `accelerate` versions
# `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
if "offload_index" in inspect.signature(dispatch_model).parameters:
dispatch_model_kwargs["offload_index"] = offload_index

no_split_module_classes = self._no_split_modules

if device_map != "sequential":
max_memory = get_balanced_memory(
self,
max_memory=max_memory,
no_split_module_classes=no_split_module_classes,
low_zero=(device_map == "balanced_low_0"),
)
if isinstance(device_map, str):
device_map = infer_auto_device_map(
self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
)
dispatch_model(
self,
device_map=device_map,
offload_dir=offload_dir,
**dispatch_model_kwargs,
)
40 changes: 39 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from torch.nn import CrossEntropyLoss

from .activations import get_activation
from .adapters import PeftAdapterMixin
from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
Expand Down Expand Up @@ -67,6 +68,7 @@
is_bitsandbytes_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
is_remote_url,
is_safetensors_available,
is_torch_tpu_available,
Expand Down Expand Up @@ -113,6 +115,9 @@
else:
IS_SAGEMAKER_MP_POST_1_10 = False

if is_peft_available():
from .utils import find_adapter_config_file


@contextmanager
def no_init_weights(_enable=True):
Expand Down Expand Up @@ -1025,7 +1030,7 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.

Expand Down Expand Up @@ -2211,6 +2216,7 @@ def from_pretrained(
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
_adapter_model_path = kwargs.pop("_adapter_model_path", None)

if use_auth_token is not None:
warnings.warn(
Expand All @@ -2236,6 +2242,29 @@ def from_pretrained(
" ignored."
)

if is_peft_available() and _adapter_model_path is None:
maybe_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
revision=revision,
subfolder=subfolder,
use_auth_token=use_auth_token,
commit_hash=commit_hash,
)
elif is_peft_available() and _adapter_model_path is not None:
maybe_adapter_model_path = _adapter_model_path
else:
maybe_adapter_model_path = None

has_adapter_config = maybe_adapter_model_path is not None

if has_adapter_config:
if _adapter_model_path is not None:
adapter_model_id = _adapter_model_path
else:
with open(maybe_adapter_model_path, "r", encoding="utf-8") as f:
adapter_model_id = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]

# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
device_map = {"": device_map}
Expand Down Expand Up @@ -2981,6 +3010,15 @@ def from_pretrained(
kwargs["skip_keys"] = model._skip_keys_device_placement
dispatch_model(model, **kwargs)

if has_adapter_config:
model.load_adapter(
adapter_model_id,
adapter_name="default",
revision=revision,
use_auth_token=use_auth_token,
commit_hash=commit_hash,
)

if output_loading_info:
if loading_info is None:
loading_info = {
Expand Down
21 changes: 20 additions & 1 deletion src/transformers/models/auto/auto_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
"""Factory function to build auto-model classes."""
import copy
import importlib
import json
import os
import warnings
from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ...utils import copy_func, logging, requires_backends
from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


Expand Down Expand Up @@ -469,6 +470,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if token is not None:
hub_kwargs["token"] = token

if is_peft_available():
revision = kwargs.get("revision", None)
subfolder = kwargs.get("subfolder", None)

maybe_adapter_path = find_adapter_config_file(
pretrained_model_name_or_path,
revision=revision,
use_auth_token=use_auth_token,
subfolder=subfolder,
)

if maybe_adapter_path is not None:
with open(maybe_adapter_path, "r") as f:
adapter_config = json.load(f)

kwargs["_adapter_model_path"] = pretrained_model_name_or_path
pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]

if not isinstance(config, PretrainedConfig):
kwargs_orig = copy.deepcopy(kwargs)
# ensure not to pollute the config object with torch_dtype="auto" - since it's
Expand Down
4 changes: 1 addition & 3 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,11 @@
requires_backends,
torch_only_method,
)
from .peft_utils import ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, find_adapter_config_file


WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
Expand Down
Loading