
Commit b578de0
FIX Skip X-LoRA tests for transformers >= 4.49.0
Since the latest transformers release, v4.49.0, the X-LoRA tests are broken. The PR that caused this is huggingface/transformers#35724. For the time being, let's skip the X-LoRA tests if this transformers version is detected and also advise users against using X-LoRA with this transformers version.
1 parent 6d03360 commit b578de0

4 files changed: 27 additions, 2 deletions

docs/source/package_reference/xlora.md (3 additions, 0 deletions)

```diff
@@ -16,6 +16,9 @@ rendered properly in your Markdown viewer.
 
 # X-LoRA
 
+> [!WARNING]
+> X-LoRA is broken for transformers version 4.49.0 and higher. It is recommended to use it with an older transformers version or to resort to other PEFT methods.
+
 Mixture of LoRA Experts ([X-LoRA](https://arxiv.org/abs/2402.07148)) is a PEFT method enabling sparse or dense mixture of LoRA experts based on a high granularity (token, layer, sequence) scalings matrix. This leverages frozen LoRA adapters and a frozen base model to drastically reduces the number of parameters that need to be fine-tuned.
 
 A unique aspect of X-LoRA is its versatility: it can be applied to any `transformers` base model with LoRA adapters. This means that, despite the mixture of experts strategy, no changes to the model code must be made.
```
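The docs warning asks users to stay below transformers 4.49.0. As a sketch of how a user could check whether their installed release falls in the affected range, using the same `packaging` comparison the commit relies on (the helper name `transformers_is_affected` is illustrative, not part of PEFT):

```python
import importlib.metadata

from packaging import version

# Threshold taken from the commit message; releases at or above it are affected.
AFFECTED_FROM = version.parse("4.49.0")


def transformers_is_affected(installed: str) -> bool:
    """Return True if the given transformers version hits the X-LoRA breakage."""
    return version.parse(installed) >= AFFECTED_FROM


# Check the local installation, if transformers is present at all.
try:
    local = importlib.metadata.version("transformers")
    print(f"transformers {local} affected: {transformers_is_affected(local)}")
except importlib.metadata.PackageNotFoundError:
    print("transformers is not installed")
```

Note that `version.parse` is pre-release aware, so a `4.49.0.dev0` build still sorts below `4.49.0`.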

src/peft/tuners/xlora/__init__.py (2 additions, 2 deletions)

```diff
@@ -14,10 +14,10 @@
 
 from peft.utils import register_peft_method
 
-from .config import XLoraConfig
+from .config import XLORA_TRANSFORMERS_MAX_VERSION, XLoraConfig
 from .model import XLoraModel
 
 
-__all__ = ["XLoraConfig", "XLoraModel"]
+__all__ = ["XLORA_TRANSFORMERS_MAX_VERSION", "XLoraConfig", "XLoraModel"]
 
 register_peft_method(name="xlora", config_cls=XLoraConfig, model_cls=XLoraModel)
```

src/peft/tuners/xlora/config.py (12 additions, 0 deletions)

```diff
@@ -13,14 +13,20 @@
 # limitations under the License.
 from __future__ import annotations
 
+import importlib
 import warnings
 from dataclasses import dataclass
 from typing import Optional
 
+from packaging import version
+
 from peft.config import PeftConfig
 from peft.utils.peft_types import PeftType
 
 
+XLORA_TRANSFORMERS_MAX_VERSION = "4.49.0"
+
+
 @dataclass
 class XLoraConfig(PeftConfig):
     r"""
@@ -79,6 +85,12 @@ def __post_init__(self):
         super().__post_init__()
         self.peft_type = PeftType.XLORA
 
+        if version.parse(importlib.metadata.version("transformers")) >= version.parse(XLORA_TRANSFORMERS_MAX_VERSION):
+            warnings.warn(
+                f"X-LoRA is currently broken with transformers {XLORA_TRANSFORMERS_MAX_VERSION}, it is recommended to "
+                "use an older transformers version or a different PEFT method"
+            )
+
         if self.hidden_size is None:
             warnings.warn(
                 "No value was provided for `hidden_size`. This will be set to 4096 by default, please ensure that this is correct."
```
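The config change follows a common pattern: compare the installed dependency's version at construction time and emit a warning instead of failing hard. A standalone sketch of that pattern, detached from PEFT so it can run anywhere (`warn_if_broken` and `BROKEN_FROM` are illustrative names, not PEFT's API):

```python
import warnings

from packaging import version

BROKEN_FROM = "4.49.0"  # mirrors XLORA_TRANSFORMERS_MAX_VERSION from the diff


def warn_if_broken(installed: str, threshold: str = BROKEN_FROM) -> bool:
    """Warn when `installed` >= `threshold`; return True if the warning fired."""
    if version.parse(installed) >= version.parse(threshold):
        warnings.warn(
            f"X-LoRA is currently broken with transformers {threshold}, it is "
            "recommended to use an older transformers version or a different PEFT method"
        )
        return True
    return False


# The warning fires only for versions at or above the threshold.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warn_if_broken("4.49.0")  # fires
    warn_if_broken("4.48.3")  # silent
print(len(caught))  # 1
```

Warning rather than raising keeps existing pipelines loadable while still surfacing the incompatibility.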

tests/test_xlora.py (10 additions, 0 deletions)

```diff
@@ -12,19 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib
 import os
 
 import huggingface_hub
 import pytest
 import torch
+from packaging import version
 from safetensors.torch import load_file
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
 from peft.peft_model import PeftModel
+from peft.tuners.xlora import XLORA_TRANSFORMERS_MAX_VERSION
 from peft.utils import infer_device
 
 
+TRANSFORMERS_VERSION = version.parse(importlib.metadata.version("transformers"))
+
+
+@pytest.mark.skipif(
+    TRANSFORMERS_VERSION >= version.parse(XLORA_TRANSFORMERS_MAX_VERSION),
+    reason="X-LoRA is currently broken with the given transformers version, thus skipping tests",
+)
 class TestXlora:
     torch_device = infer_device()
```
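The test-side guard evaluates the version check at import time, so it assumes transformers is installed. A defensive sketch of the same `pytest.mark.skipif` idea that tolerates a missing distribution (`installed_version` and `TestXLoraDemo` are illustrative names; only `XLORA_TRANSFORMERS_MAX_VERSION` comes from the diff):

```python
import importlib.metadata

import pytest
from packaging import version

XLORA_TRANSFORMERS_MAX_VERSION = "4.49.0"  # value taken from the diff above


def installed_version(dist: str, fallback: str = "0.0.0") -> version.Version:
    """Parse a distribution's version, falling back when it is not installed."""
    try:
        return version.parse(importlib.metadata.version(dist))
    except importlib.metadata.PackageNotFoundError:
        return version.parse(fallback)


SHOULD_SKIP = installed_version("transformers") >= version.parse(
    XLORA_TRANSFORMERS_MAX_VERSION
)


# Marking the class skips every test method in it when the condition holds.
@pytest.mark.skipif(
    SHOULD_SKIP,
    reason="X-LoRA is currently broken with the given transformers version",
)
class TestXLoraDemo:
    def test_placeholder(self):
        assert True
```

Applying the marker at class level, as the commit does, gates the whole suite with a single condition instead of repeating it per test.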

0 commit comments