src/peft/tuners/xlora/model.py (1 change: 1 addition & 0 deletions)
@@ -314,6 +314,7 @@ def _maybe_freeze_all_adapters(self):
                     param.requires_grad = False
 
     def generate(self, *args, **kwargs):
+        kwargs["use_cache"] = False
         res = self.lora_model.generate(*args, **kwargs)  # type: ignore
         # This is necessary because we use PeftModel.disable_adapter() which reenables the adapters
         self._maybe_freeze_all_adapters()
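For context, the added line makes X-LoRA's generate() disable the KV cache unconditionally, even when the caller asks for it. A minimal sketch of the caller-side effect, assuming a loaded X-LoRA model (xlora_model and tokenizer are placeholders, not names from this diff):

    inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")

    # The patched generate() sets kwargs["use_cache"] = False before
    # delegating to lora_model.generate(), so a caller-supplied
    # use_cache=True is silently overridden.
    outputs = xlora_model.generate(input_ids=inputs, max_new_tokens=10, use_cache=True)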
tests/test_xlora.py (26 changes: 22 additions & 4 deletions)
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+from functools import wraps
 
 import huggingface_hub
 import pytest
@@ -25,6 +26,25 @@
 from peft.utils import infer_device
 
 
+def flaky(num_tries: int):
+    """Decorator for test functions that are flaky"""
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            for _ in range(num_tries):
+                try:
+                    return func(*args, **kwargs)
+                except AssertionError as e:
+                    print(f"Failed test {func.__name__} with error: {e}")
+                    continue
+            raise AssertionError(f"Failed test {func.__name__} after {num_tries} tries")
+
+        return wrapper
+
+    return decorator
+
+
 class TestXlora:
     torch_device = infer_device()

@@ -128,8 +148,6 @@ def test_functional(self, tokenizer, model):
         )
         assert torch.isfinite(outputs[: inputs.shape[1] :]).all()
 
-    # TODO: fix the xfailing test
-    @pytest.mark.xfail
     def test_scalings_logging_methods(self, tokenizer, model):
         model.enable_scalings_logging()
 
@@ -182,8 +200,8 @@ def test_misc_methods(self, tokenizer, model):

         assert str(model) is not None
 
-    # TODO: On CI (but not locally), this test seems to have become flaky with the latest transformers changes (v4.45).
-    @pytest.mark.xfail
+    # On CI (but not locally), this test is flaky since transformers v4.45.0.
+    @flaky(num_tries=5)
     def test_save_load_functional(self, tokenizer, model, tmp_path):
         inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
         outputs = model.generate(
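A usage note on the new flaky helper: it retries only on AssertionError, so genuine errors such as exceptions raised inside the model still fail fast, and each retry prints the original failure message. Below is a self-contained sketch of the retry behavior; the coin-flip test is hypothetical and stands in for the real CI flakiness:

    import random
    from functools import wraps


    # flaky() as added in tests/test_xlora.py above: retry on AssertionError only.
    def flaky(num_tries: int):
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                for _ in range(num_tries):
                    try:
                        return func(*args, **kwargs)
                    except AssertionError as e:
                        print(f"Failed test {func.__name__} with error: {e}")
                        continue
                raise AssertionError(f"Failed test {func.__name__} after {num_tries} tries")

            return wrapper

        return decorator


    @flaky(num_tries=5)
    def test_sometimes_fails():
        # Hypothetical flaky check: passes roughly half the time.
        assert random.random() < 0.5


    test_sometimes_fails()  # retried up to 5 times, then raises AssertionError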