Commit 3ca18d6

younesbelkada, BenjaminBossan, and patrickvonplaten authored
[PEFT] Fix PEFT multi adapters support (#26407)
* fix PEFT multi adapters support
* refactor a bit
* save pretrained + BC + added tests
* Update src/transformers/integrations/peft.py
  Co-authored-by: Benjamin Bossan <[email protected]>
* add more tests
* add suggestion
* final changes
* adapt a bit
* fixup
* Update src/transformers/integrations/peft.py
  Co-authored-by: Patrick von Platen <[email protected]>
* adapt from suggestions

---------

Co-authored-by: Benjamin Bossan <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 946bac7 commit 3ca18d6

File tree: 3 files changed, +76 -11 lines

src/transformers/integrations/peft.py

Lines changed: 47 additions & 10 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from ..utils import (
     check_peft_version,
@@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
 
         self.set_adapter(adapter_name)
 
-    def set_adapter(self, adapter_name: str) -> None:
+    def set_adapter(self, adapter_name: Union[List[str], str]) -> None:
         """
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         official documentation: https://huggingface.co/docs/peft
 
         Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
 
         Args:
-            adapter_name (`str`):
-                The name of the adapter to set.
+            adapter_name (`Union[List[str], str]`):
+                The name of the adapter to set. Can be also a list of strings to set multiple adapters.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
         if not self._hf_peft_config_loaded:
             raise ValueError("No adapter loaded. Please load an adapter first.")
+        elif isinstance(adapter_name, list):
+            missing = set(adapter_name) - set(self.peft_config)
+            if len(missing) > 0:
+                raise ValueError(
+                    f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                    f" current loaded adapters are: {list(self.peft_config.keys())}"
+                )
         elif adapter_name not in self.peft_config:
             raise ValueError(
                 f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
@@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:
 
         for _, module in self.named_modules():
             if isinstance(module, BaseTunerLayer):
-                module.active_adapter = adapter_name
+                # For backward compatbility with previous PEFT versions
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                else:
+                    module.active_adapter = adapter_name
                 _adapters_has_been_set = True
 
         if not _adapters_has_been_set:
@@ -294,7 +305,11 @@ def disable_adapters(self) -> None:
 
         for _, module in self.named_modules():
             if isinstance(module, BaseTunerLayer):
-                module.disable_adapters = True
+                # The recent version of PEFT need to call `enable_adapters` instead
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    module.disable_adapters = True
 
     def enable_adapters(self) -> None:
         """
@@ -312,14 +327,22 @@ def enable_adapters(self) -> None:
 
         for _, module in self.named_modules():
             if isinstance(module, BaseTunerLayer):
-                module.disable_adapters = False
+                # The recent version of PEFT need to call `enable_adapters` instead
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    module.disable_adapters = False
 
-    def active_adapter(self) -> str:
+    def active_adapters(self) -> List[str]:
         """
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
         official documentation: https://huggingface.co/docs/peft
 
-        Gets the current active adapter of the model.
+        Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
+        for inference) returns the list of all active adapters so that users can deal with them accordingly.
+
+        For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
+        a single string.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
@@ -333,7 +356,21 @@ def active_adapter(self) -> str:
 
         for _, module in self.named_modules():
             if isinstance(module, BaseTunerLayer):
-                return module.active_adapter
+                active_adapters = module.active_adapter
+                break
+
+        # For previous PEFT versions
+        if isinstance(active_adapters, str):
+            active_adapters = [active_adapters]
+
+        return active_adapters
+
+    def active_adapter(self) -> str:
+        logger.warning(
+            "The `active_adapter` method is deprecated and will be removed in a future version. ", FutureWarning
+        )
+
+        return self.active_adapters()[0]
 
     def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
         """

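Taken together, the changes above make `set_adapter` accept either a single adapter name or a list of names, add an `active_adapters()` method that always returns a list, and keep `active_adapter()` as a deprecated alias. A minimal usage sketch of that behaviour follows; the base model and the two adapter repository ids are placeholders, not taken from this commit:

# Sketch of the multi-adapter API touched by this commit.
# The adapter repo ids below are assumed placeholders.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter("user/opt-350m-lora-a", adapter_name="adapter-1")
model.load_adapter("user/opt-350m-lora-b", adapter_name="adapter-2")

# Single-adapter activation works as before.
model.set_adapter("adapter-1")
print(model.active_adapters())  # ["adapter-1"]

# With a recent PEFT version, a list activates several adapters at once
# (multi-adapter inference).
model.set_adapter(["adapter-1", "adapter-2"])
print(model.active_adapters())  # ["adapter-1", "adapter-2"]

# `active_adapter()` is deprecated and returns only the first active adapter.
print(model.active_adapter())   # "adapter-1"
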
src/transformers/modeling_utils.py

Lines changed: 10 additions & 1 deletion
@@ -2006,7 +2006,16 @@ def save_pretrained(
                     peft_state_dict[f"base_model.model.{key}"] = value
                 state_dict = peft_state_dict
 
-            current_peft_config = self.peft_config[self.active_adapter()]
+            active_adapter = self.active_adapters()
+
+            if len(active_adapter) > 1:
+                raise ValueError(
+                    "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
+                    "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
+                )
+            active_adapter = active_adapter[0]
+
+            current_peft_config = self.peft_config[active_adapter]
             current_peft_config.save_pretrained(save_directory)
 
         # Save the model

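The error message above suggests a workaround when several adapters are active: activate and save them one at a time. A short sketch of that workflow, assuming `model` already has adapters named "adapter-1" and "adapter-2" loaded as in the previous sketch and that `./saved_adapters` is a hypothetical output directory:

# Save each adapter separately, since save_pretrained() now raises a
# ValueError when more than one adapter is active.
import os

for adapter_name in ["adapter-1", "adapter-2"]:
    model.set_adapter(adapter_name)  # exactly one active adapter
    model.save_pretrained(os.path.join("./saved_adapters", adapter_name))
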
tests/peft_integration/test_peft_integration.py

Lines changed: 19 additions & 0 deletions
@@ -265,9 +265,11 @@ def test_peft_add_multi_adapter(self):
                 _ = model.generate(input_ids=dummy_input)
 
                 model.set_adapter("default")
+                self.assertTrue(model.active_adapters() == ["default"])
                 self.assertTrue(model.active_adapter() == "default")
 
                 model.set_adapter("adapter-2")
+                self.assertTrue(model.active_adapters() == ["adapter-2"])
                 self.assertTrue(model.active_adapter() == "adapter-2")
 
                 # Logits comparison
@@ -276,6 +278,23 @@
                 )
                 self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))
 
+                model.set_adapter(["adapter-2", "default"])
+                self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
+                self.assertTrue(model.active_adapter() == "adapter-2")
+
+                logits_adapter_mixed = model(dummy_input)
+                self.assertFalse(
+                    torch.allclose(logits_adapter_1.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
+                )
+
+                self.assertFalse(
+                    torch.allclose(logits_adapter_2.logits, logits_adapter_mixed.logits, atol=1e-6, rtol=1e-6)
+                )
+
+                # multi active adapter saving not supported
+                with self.assertRaises(ValueError), tempfile.TemporaryDirectory() as tmpdirname:
+                    model.save_pretrained(tmpdirname)
+
     @require_torch_gpu
     def test_peft_from_pretrained_kwargs(self):
         """
