Skip to content
33 changes: 29 additions & 4 deletions src/transformers/integrations/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,11 @@ def set_adapter(self, adapter_name: str) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True

if not _adapters_has_been_set:
Expand All @@ -294,7 +298,11 @@ def disable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = True
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
module.disable_adapters = True

def enable_adapters(self) -> None:
"""
Expand All @@ -312,7 +320,11 @@ def enable_adapters(self) -> None:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False
# The recent version of PEFT need to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
module.disable_adapters = False

def active_adapter(self) -> str:
"""
Expand All @@ -333,7 +345,11 @@ def active_adapter(self) -> str:

for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
active_adapter = module.active_adapter
if isinstance(active_adapter, list):
# In case the adapter name is a list (multiple adapters), we only consider the first one
active_adapter = active_adapter[0]
return active_adapter

def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
"""
Expand All @@ -357,6 +373,15 @@ def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
if adapter_name is None:
adapter_name = self.active_adapter()

if isinstance(adapter_name, list):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible for this method to be called with a list of str? Wouldn't this require the user to explicitly pass that argument? If not, I think this check is not necessary. If it is a valid argument, then the type annotation should also be adjusted.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I don't think we should allow users to call this method with a list of str - I refactored a bit the logic of def active_adapter to extend it for multi-adapter inference. let me know what do you think

# In case the adapter name is a list (multiple adapters), we only consider the first one
adapter_name = adapter_name[0]

logger.warning(
"Multiple adapters detected, we will only consider the first adapter, to get all adapters state dict manually loop "
"over the list of adapters and call `get_adapter_state_dict` for each adapter."
)

adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
return adapter_state_dict

Expand Down