1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import inspect
15- from typing import TYPE_CHECKING , Any , Dict , Optional
15+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Union
1616
1717from ..utils import (
1818 check_peft_version ,
@@ -245,20 +245,27 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non
245245
246246 self .set_adapter (adapter_name )
247247
248- def set_adapter (self , adapter_name : str ) -> None :
248+ def set_adapter (self , adapter_name : Union [ List [ str ], str ] ) -> None :
249249 """
250250 If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
251251 official documentation: https://huggingface.co/docs/peft
252252
253253 Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters.
254254
255255 Args:
256- adapter_name (`str`):
257- The name of the adapter to set.
256+ adapter_name (`Union[List[ str], str] `):
257+ The name of the adapter to set. Can be also a list of strings to set multiple adapters.
258258 """
259259 check_peft_version (min_version = MIN_PEFT_VERSION )
260260 if not self ._hf_peft_config_loaded :
261261 raise ValueError ("No adapter loaded. Please load an adapter first." )
262+ elif isinstance (adapter_name , list ):
263+ missing = set (adapter_name ) - set (self .peft_config )
264+ if len (missing ) > 0 :
265+ raise ValueError (
266+ f"Following adapter(s) could not be found: { ', ' .join (missing )} . Make sure you are passing the correct adapter name(s)."
267+ f" current loaded adapters are: { list (self .peft_config .keys ())} "
268+ )
262269 elif adapter_name not in self .peft_config :
263270 raise ValueError (
264271 f"Adapter with name { adapter_name } not found. Please pass the correct adapter name among { list (self .peft_config .keys ())} "
@@ -270,7 +277,11 @@ def set_adapter(self, adapter_name: str) -> None:
270277
271278 for _ , module in self .named_modules ():
272279 if isinstance (module , BaseTunerLayer ):
273- module .active_adapter = adapter_name
280+ # For backward compatbility with previous PEFT versions
281+ if hasattr (module , "set_adapter" ):
282+ module .set_adapter (adapter_name )
283+ else :
284+ module .active_adapter = adapter_name
274285 _adapters_has_been_set = True
275286
276287 if not _adapters_has_been_set :
@@ -294,7 +305,11 @@ def disable_adapters(self) -> None:
294305
295306 for _ , module in self .named_modules ():
296307 if isinstance (module , BaseTunerLayer ):
297- module .disable_adapters = True
308+ # The recent version of PEFT need to call `enable_adapters` instead
309+ if hasattr (module , "enable_adapters" ):
310+ module .enable_adapters (enabled = False )
311+ else :
312+ module .disable_adapters = True
298313
299314 def enable_adapters (self ) -> None :
300315 """
@@ -312,14 +327,22 @@ def enable_adapters(self) -> None:
312327
313328 for _ , module in self .named_modules ():
314329 if isinstance (module , BaseTunerLayer ):
315- module .disable_adapters = False
330+ # The recent version of PEFT need to call `enable_adapters` instead
331+ if hasattr (module , "enable_adapters" ):
332+ module .enable_adapters (enabled = True )
333+ else :
334+ module .disable_adapters = False
316335
317- def active_adapter (self ) -> str :
336+ def active_adapters (self ) -> List [ str ] :
318337 """
319338 If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
320339 official documentation: https://huggingface.co/docs/peft
321340
322- Gets the current active adapter of the model.
341+ Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters
342+ for inference) returns the list of all active adapters so that users can deal with them accordingly.
343+
344+ For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return
345+ a single string.
323346 """
324347 check_peft_version (min_version = MIN_PEFT_VERSION )
325348
@@ -333,7 +356,21 @@ def active_adapter(self) -> str:
333356
334357 for _ , module in self .named_modules ():
335358 if isinstance (module , BaseTunerLayer ):
336- return module .active_adapter
359+ active_adapters = module .active_adapter
360+ break
361+
362+ # For previous PEFT versions
363+ if isinstance (active_adapters , str ):
364+ active_adapters = [active_adapters ]
365+
366+ return active_adapters
367+
368+ def active_adapter (self ) -> str :
369+ logger .warning (
370+ "The `active_adapter` method is deprecated and will be removed in a future version. " , FutureWarning
371+ )
372+
373+ return self .active_adapters ()[0 ]
337374
338375 def get_adapter_state_dict (self , adapter_name : Optional [str ] = None ) -> dict :
339376 """
0 commit comments