Merged
44 commits
4cb6f6e  v2 integration (younesbelkada, Jul 25, 2023)
d41717c  sanity check (younesbelkada, Jul 25, 2023)
375a238  Merge remote-tracking branch 'upstream/main' into integration-v2 (younesbelkada, Jul 26, 2023)
a08e858  major refactor (younesbelkada, Jul 26, 2023)
420d39c  fix doc (younesbelkada, Jul 26, 2023)
9a9138e  some refactor (younesbelkada, Jul 26, 2023)
088a1c1  add abstract class (younesbelkada, Jul 26, 2023)
57fa268  refactor a bit (younesbelkada, Jul 27, 2023)
d04a91b  properly freeze the base model (younesbelkada, Jul 27, 2023)
4a08ecb  more refactor (younesbelkada, Jul 27, 2023)
1e99c8e  v1 of a better abstraction (younesbelkada, Jul 27, 2023)
6acfeff  better refactoring (younesbelkada, Jul 28, 2023)
10a9ab2  use classmethod instead (younesbelkada, Jul 28, 2023)
8fc3a63  Merge branch 'main' into integration-v2 (younesbelkada, Jul 28, 2023)
2817d2f  style (younesbelkada, Jul 28, 2023)
66a5e3d  fix test (younesbelkada, Jul 28, 2023)
221ab34  addressed general comments (younesbelkada, Jul 28, 2023)
3da5b6d  fix failing CIs (younesbelkada, Jul 31, 2023)
8827879  not checking the type of the args (younesbelkada, Jul 31, 2023)
2cac1d6  remove `pre_init` (younesbelkada, Jul 31, 2023)
2d54ea6  Update src/peft/tuners/tuners_utils.py (younesbelkada, Jul 31, 2023)
b8bd2e2  Update src/peft/tuners/tuners_utils.py (younesbelkada, Jul 31, 2023)
8efb5f4  tiny comments to remove later (younesbelkada, Jul 31, 2023)
a92e5ea  adapt from suggestions (younesbelkada, Jul 31, 2023)
45e8877  added docstrings. (younesbelkada, Jul 31, 2023)
dbf764f  remove check (younesbelkada, Jul 31, 2023)
88c91ae  remove manual freezing (younesbelkada, Jul 31, 2023)
4ce8541  added comments to explain hack. (younesbelkada, Jul 31, 2023)
ead8545  fix docstring (younesbelkada, Jul 31, 2023)
c3e6b61  fix docstring (younesbelkada, Jul 31, 2023)
7648c27  remove `supports_merging` (younesbelkada, Jul 31, 2023)
8f09a8e  remove unneeded comments (younesbelkada, Jul 31, 2023)
4d1dffb  extensive docstring on `BaseTuner` (younesbelkada, Jul 31, 2023)
a6be082  improve type hints (younesbelkada, Jul 31, 2023)
a032f62  add more comments. (younesbelkada, Jul 31, 2023)
fccae04  add final suggestion (younesbelkada, Jul 31, 2023)
2e632f1  Merge branch 'main' into integration-v2 (younesbelkada, Aug 1, 2023)
f8fcba7  minor tweaks (younesbelkada, Aug 1, 2023)
84a56e0  a bit of refactoring and addressing comments. (younesbelkada, Aug 2, 2023)
8b88401  use `child` instead of `target` (younesbelkada, Aug 2, 2023)
8053c88  few other nits (younesbelkada, Aug 2, 2023)
b4adb6d  Merge branch 'main' into integration-v2 (younesbelkada, Aug 3, 2023)
7171188  nits (younesbelkada, Aug 3, 2023)
758ecb8  Update src/peft/mapping.py (younesbelkada, Aug 3, 2023)
2 changes: 1 addition & 1 deletion docs/source/package_reference/config.mdx
@@ -4,7 +4,7 @@ The configuration classes stores the configuration of a [`PeftModel`], PEFT adap
 
 ## PeftConfigMixin
 
-[[autodoc]] utils.config.PeftConfigMixin
+[[autodoc]] config.PeftConfigMixin
     - all
 
 ## PeftConfig
12 changes: 9 additions & 3 deletions src/peft/__init__.py
@@ -28,7 +28,13 @@
     AutoPeftModelForQuestionAnswering,
     AutoPeftModelForFeatureExtraction,
 )
-from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
+from .mapping import (
+    MODEL_TYPE_TO_PEFT_MODEL_MAPPING,
+    PEFT_TYPE_TO_CONFIG_MAPPING,
+    get_peft_config,
+    get_peft_model,
+    create_and_replace,
+)
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
@@ -58,14 +64,14 @@
 )
 from .utils import (
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
-    PeftConfig,
     PeftType,
-    PromptLearningConfig,
     TaskType,
     bloom_model_postprocess_past_key_value,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
     prepare_model_for_kbit_training,
     set_peft_model_state_dict,
     shift_tokens_right,
+    load_peft_weights,
 )
+from .config import PeftConfig, PromptLearningConfig
2 changes: 1 addition & 1 deletion src/peft/auto.py
@@ -27,6 +27,7 @@
     AutoModelForTokenClassification,
 )
 
+from .config import PeftConfig
 from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING
 from .peft_model import (
     PeftModel,
@@ -37,7 +38,6 @@
     PeftModelForSequenceClassification,
     PeftModelForTokenClassification,
 )
-from .utils import PeftConfig
 
 
 class _BaseAutoPeftModel:
47 changes: 25 additions & 22 deletions src/peft/utils/config.py → src/peft/config.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import enum
 import inspect
 import json
 import os
@@ -22,26 +21,7 @@
 from huggingface_hub import hf_hub_download
 from transformers.utils import PushToHubMixin
 
-from .other import CONFIG_NAME
-
-
-class PeftType(str, enum.Enum):
-    PROMPT_TUNING = "PROMPT_TUNING"
-    P_TUNING = "P_TUNING"
-    PREFIX_TUNING = "PREFIX_TUNING"
-    LORA = "LORA"
-    ADALORA = "ADALORA"
-    ADAPTION_PROMPT = "ADAPTION_PROMPT"
-    IA3 = "IA3"
-
-
-class TaskType(str, enum.Enum):
-    SEQ_CLS = "SEQ_CLS"
-    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
-    CAUSAL_LM = "CAUSAL_LM"
-    TOKEN_CLS = "TOKEN_CLS"
-    QUESTION_ANS = "QUESTION_ANS"
-    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"
+from .utils import CONFIG_NAME, PeftType, TaskType
 
 
 @dataclass
@@ -102,6 +82,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs
             kwargs (additional keyword arguments, *optional*):
                 Additional keyword arguments passed along to the child class initialization.
         """
+        # Avoid circular dependency .. TODO: fix this with a larger refactor
+        from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
+
        path = (
            os.path.join(pretrained_model_name_or_path, subfolder)
            if subfolder is not None
@@ -122,7 +105,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, subfolder=None, **kwargs
 
         loaded_attributes = cls.from_json_file(config_file)
 
-        config = cls(**class_kwargs)
+        if "peft_type" in loaded_attributes:
+            peft_type = loaded_attributes["peft_type"]
+            config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type]
+        else:
+            config_cls = cls
+
+        config = config_cls(**class_kwargs)
 
         for key, value in loaded_attributes.items():
             if hasattr(config, key):
@@ -185,6 +174,13 @@ def _get_peft_type(
         loaded_attributes = cls.from_json_file(config_file)
         return loaded_attributes["peft_type"]
 
+    @property
+    def is_prompt_learning(self):
+        r"""
+        Utility method to check if the configuration is for prompt learning.
+        """
+        return False
+
 
 @dataclass
 class PeftConfig(PeftConfigMixin):
@@ -227,3 +223,10 @@ class PromptLearningConfig(PeftConfig):
     )
     num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"})
     num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"})
+
+    @property
+    def is_prompt_learning(self):
+        r"""
+        Utility method to check if the configuration is for prompt learning.
+        """
+        return True
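
With this change, `PeftConfigMixin.from_pretrained` resolves the concrete config class from the `peft_type` stored in the saved file instead of always instantiating `cls`, so loading through the base class returns the matching subclass. A minimal usage sketch (the adapter directory name and the LoRA settings below are made up for illustration):

```python
from peft import LoraConfig, PeftConfig

# Save a config so that "peft_type": "LORA" ends up in adapter_config.json.
# The directory name is hypothetical, used only for this example.
LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"]).save_pretrained("my-lora-adapter")

# Loading through the base class now dispatches on the stored peft_type,
# so a LoraConfig comes back rather than a bare PeftConfig.
config = PeftConfig.from_pretrained("my-lora-adapter")
print(type(config).__name__)  # LoraConfig
```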
55 changes: 48 additions & 7 deletions src/peft/mapping.py
@@ -17,6 +17,7 @@
 
 from typing import TYPE_CHECKING, Any, Dict
 
+from .config import PeftConfig
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
@@ -27,6 +28,7 @@
     PeftModelForTokenClassification,
 )
 from .tuners import (
+    TUNERS_MAPPING,
     AdaLoraConfig,
     AdaptionPromptConfig,
     IA3Config,
@@ -35,14 +37,12 @@
     PromptEncoderConfig,
     PromptTuningConfig,
 )
-from .utils import PromptLearningConfig, _prepare_prompt_learning_config
+from .utils import _get_submodules, _prepare_prompt_learning_config
 
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-
-    from .utils.config import PeftConfig
 
 
 MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
     "SEQ_CLS": PeftModelForSequenceClassification,
@@ -89,10 +89,51 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name
 
     peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None)
 
-    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not isinstance(
-        peft_config, PromptLearningConfig
-    ):
+    if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning:
         return PeftModel(model, peft_config, adapter_name=adapter_name)
-    if isinstance(peft_config, PromptLearningConfig):
+    if peft_config.is_prompt_learning:
         peft_config = _prepare_prompt_learning_config(peft_config, model_config)
     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name)
+
+
+# TODO: docstring and typehints
+def create_and_replace(peft_config, model, adapter_name):
+    if not isinstance(peft_config, PeftConfig):
+        raise ValueError(f"peft_config must be an instance of PeftConfig got {type(peft_config)} instead.")
+
+    peft_type = peft_config.peft_type
+
+    if peft_type not in TUNERS_MAPPING:
+        raise ValueError(
+            f"Task type {peft_type} is not supported. Supported task types are {list(TUNERS_MAPPING.keys())}"
+        )
+    tuner_cls = TUNERS_MAPPING[peft_type]
+
+    # TODO: test that
+    for module in model.modules():
+        if not getattr(module, "_is_peft_tuner_layer", False):
Review comment (Member):
Just wondering: could we check `isinstance(module, BaseTunerLayerMixin)` instead, or do we expect a `BaseTunerLayerMixin` with `_is_peft_tuner_layer=False`, or do we want to allow layers with `_is_peft_tuner_layer=True` that are not `BaseTunerLayerMixin`? I'm not suggesting to change, rather I want to understand the intent.

Also, I wonder if this is redundant, because `add_adapter` calls `_mark_only_adapters_as_trainable`, which seems to do the same job, but it is more specialized, as it is implemented on each tuner class.

Reply (Contributor, Author):
Agreed with all the points you have shared here.
+            module.requires_grad_(False)
+
+    is_target_modules_in_base_model = False
+    key_list = [key for key, _ in model.named_modules()]
+
+    for key in key_list:
+        if not tuner_cls._check_target_module_exists(peft_config, key):
+            continue
+
+        is_target_modules_in_base_model = True
+        parent, target, target_name = _get_submodules(model, key)
+
+        optional_kwargs = {
+            "loaded_in_8bit": getattr(model, "is_loaded_in_8bit", False),
+            "loaded_in_4bit": getattr(model, "is_loaded_in_4bit", False),
+            "current_key": key,
+        }
+
+        tuner_cls.create_and_replace(peft_config, adapter_name, target, target_name, parent, **optional_kwargs)
+
+    if not is_target_modules_in_base_model:
+        raise ValueError(
+            f"Target modules {peft_config.target_modules} not found in the base model. "
+            f"Please check the target modules and try again."
+        )
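
The freezing loop in `create_and_replace` keys off a `_is_peft_tuner_layer` attribute; the review thread above asks whether an `isinstance` check against the tuner layer mixin would express the same intent. A rough sketch of that alternative, assuming a `BaseTunerLayerMixin` marker class as named in the comment (the class and helper here are illustrative only, not part of this PR):

```python
from torch import nn


class BaseTunerLayerMixin:
    """Marker mixin that adapter layers would inherit from (hypothetical name from the review)."""


def freeze_non_adapter_modules(model: nn.Module) -> None:
    # Freeze only the parameters owned directly by non-adapter modules,
    # leaving the adapter layers' own parameters trainable.
    for module in model.modules():
        if not isinstance(module, BaseTunerLayerMixin):
            for param in module.parameters(recurse=False):
                param.requires_grad = False
```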