Skip to content

Commit 1793a95

Browse files
FIX: Avoid caching in X-LoRA generate (#2384)
X-LoRA tests started failing after this transformers PR: huggingface/transformers#35724. The solution appears to be to disable caching completely when calling generate on the X-LoRA model. This also makes some previously xfail-ing tests pass. I tested this locally with transformers checked out before and after the mentioned PR, and the tests pass in both circumstances. I also tested changing the base model from "facebook/opt-125m" to "trl-internal-testing/tiny-random-LlamaForCausalLM", and the tests passed with both. Also, mark the X-LoRA test_save_load_functional test as flaky; it was marked as xfail beforehand, but it is in fact just flaky.
1 parent 1e2d6b5 commit 1793a95

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

src/peft/tuners/xlora/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ def _maybe_freeze_all_adapters(self):
314314
param.requires_grad = False
315315

316316
def generate(self, *args, **kwargs):
317+
kwargs["use_cache"] = False
317318
res = self.lora_model.generate(*args, **kwargs) # type: ignore
318319
# This is necessary because we use PeftModel.disable_adapter() which reenables the adapters
319320
self._maybe_freeze_all_adapters()

tests/test_xlora.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
from functools import wraps
1617

1718
import huggingface_hub
1819
import pytest
@@ -25,6 +26,25 @@
2526
from peft.utils import infer_device
2627

2728

29+
def flaky(num_tries: int):
    """Decorator that retries a flaky test up to ``num_tries`` times.

    The wrapped function is re-run whenever it raises an ``AssertionError``;
    any other exception propagates immediately. If every attempt fails, an
    ``AssertionError`` is raised and chained (``from``) to the last underlying
    failure, so the original assertion traceback is preserved in the report.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(num_tries):
                try:
                    return func(*args, **kwargs)
                except AssertionError as e:
                    # Log the intermediate failure and retry.
                    print(f"Failed test {func.__name__} with error: {e}")
                    last_error = e
            # All attempts failed: raise with the last failure attached as the
            # cause instead of silently discarding it.
            raise AssertionError(f"Failed test {func.__name__} after {num_tries} tries") from last_error

        return wrapper

    return decorator
46+
47+
2848
class TestXlora:
2949
torch_device = infer_device()
3050

@@ -128,8 +148,6 @@ def test_functional(self, tokenizer, model):
128148
)
129149
assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
130150

131-
# TODO: fix the xfailing test
132-
@pytest.mark.xfail
133151
def test_scalings_logging_methods(self, tokenizer, model):
134152
model.enable_scalings_logging()
135153

@@ -182,8 +200,8 @@ def test_misc_methods(self, tokenizer, model):
182200

183201
assert str(model) is not None
184202

185-
# TODO: On CI (but not locally), this test seems to have become flaky with the latest transformers changes (v4.45).
186-
@pytest.mark.xfail
203+
# On CI (but not locally), this test is flaky since transformers v4.45.0.
204+
@flaky(num_tries=5)
187205
def test_save_load_functional(self, tokenizer, model, tmp_path):
188206
inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
189207
outputs = model.generate(

0 commit comments

Comments
 (0)