Skip to content
Merged
4 changes: 2 additions & 2 deletions tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device
from transformers.testing_utils import backend_empty_cache, torch_device
from transformers.utils import is_peft_available

from trl import DPOConfig, DPOTrainer

from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_torch_accelerator
from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST


Expand Down
18 changes: 10 additions & 8 deletions tests/slow/test_grpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,21 @@
AutoTokenizer,
BitsAndBytesConfig,
)
from transformers.testing_utils import (
backend_empty_cache,
require_flash_attn,
require_liger_kernel,
require_torch_accelerator,
torch_device,
)
from transformers.testing_utils import backend_empty_cache, torch_device
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.utils import get_kbit_device_map

from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm
from ..testing_utils import (
TrlTestCase,
require_bitsandbytes,
require_flash_attn,
require_liger_kernel,
require_peft,
require_torch_accelerator,
require_vllm,
)
from .testing_constants import MODELS_TO_TEST


Expand Down
17 changes: 9 additions & 8 deletions tests/slow/test_sft_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.testing_utils import (
backend_empty_cache,
require_liger_kernel,
require_torch_accelerator,
require_torch_multi_accelerator,
torch_device,
)
from transformers.testing_utils import backend_empty_cache, torch_device
from transformers.utils import is_peft_available

from trl import SFTConfig, SFTTrainer

from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft
from ..testing_utils import (
TrlTestCase,
require_bitsandbytes,
require_liger_kernel,
require_peft,
require_torch_accelerator,
require_torch_multi_accelerator,
)
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS


Expand Down
4 changes: 2 additions & 2 deletions tests/test_activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import torch
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_accelerator, torch_device
from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available

from trl.models.activation_offloading import NoOpManager, OffloadActivations

from .testing_utils import TrlTestCase, require_peft
from .testing_utils import TrlTestCase, require_peft, require_torch_accelerator


if is_peft_available():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
from transformers.testing_utils import require_wandb
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_peft_available

Expand All @@ -33,7 +32,7 @@
)
from trl.mergekit_utils import MergeConfig

from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft
from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft, require_wandb


if is_peft_available():
Expand Down
6 changes: 2 additions & 4 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,14 @@
PreTrainedTokenizerBase,
is_vision_available,
)
from transformers.testing_utils import (
get_device_properties,
require_liger_kernel,
)
from transformers.testing_utils import get_device_properties

from trl import DPOConfig, DPOTrainer, FDivergenceType

from .testing_utils import (
TrlTestCase,
require_bitsandbytes,
require_liger_kernel,
require_no_wandb,
require_peft,
require_torch_gpu_if_bnb_not_multi_backend_enabled,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.testing_utils import require_liger_kernel

from trl import GKDConfig, GKDTrainer
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from .testing_utils import TrlTestCase
from .testing_utils import TrlTestCase, require_liger_kernel


class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
AutoModelForSequenceClassification,
AutoTokenizer,
)
from transformers.testing_utils import require_liger_kernel
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
Expand All @@ -35,7 +34,7 @@
)
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer

from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm
from .testing_utils import TrlTestCase, require_liger_kernel, require_peft, require_vision, require_vllm


if is_peft_available():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.testing_utils import require_liger_kernel

from trl import KTOConfig, KTOTrainer
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize

from .testing_utils import TrlTestCase, require_no_wandb, require_peft
from .testing_utils import TrlTestCase, require_liger_kernel, require_no_wandb, require_peft


class TestKTOTrainer(TrlTestCase):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from packaging.version import Version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_torch_accelerator
from transformers.utils import is_peft_available, is_vision_available

from trl import OnlineDPOConfig, OnlineDPOTrainer
Expand All @@ -28,6 +27,7 @@
TrlTestCase,
require_llm_blender,
require_peft,
require_torch_accelerator,
require_vision,
require_vllm,
)
Expand Down
11 changes: 9 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@
from packaging.version import parse as parse_version
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_flash_attn, require_liger_kernel
from transformers.utils import is_peft_available

from trl import SFTConfig, SFTTrainer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss

from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft, require_vision
from .testing_utils import (
TrlTestCase,
ignore_warnings,
require_bitsandbytes,
require_flash_attn,
require_liger_kernel,
require_peft,
require_vision,
)


if is_peft_available():
Expand Down
10 changes: 8 additions & 2 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@

import pytest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_accelerator, torch_device
from transformers.testing_utils import torch_device

from trl.extras.vllm_client import VLLMClient
from trl.scripts.vllm_serve import chunk_list

from .testing_utils import TrlTestCase, kill_process, require_3_accelerators, require_vllm
from .testing_utils import (
TrlTestCase,
kill_process,
require_3_accelerators,
require_torch_multi_accelerator,
require_vllm,
)


class TestChunkList(TrlTestCase):
Expand Down
37 changes: 35 additions & 2 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,20 @@
import pytest
import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available, is_rich_available, is_vision_available
from transformers.testing_utils import backend_device_count, torch_device
from transformers.utils import (
is_flash_attn_2_available,
is_kernels_available,
is_peft_available,
is_rich_available,
is_torch_available,
is_vision_available,
)

from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import (
is_joblib_available,
is_liger_kernel_available,
is_llm_blender_available,
is_math_verify_available,
is_mergekit_available,
Expand All @@ -37,6 +45,7 @@

# Skip markers for optional dependencies: a test carrying one of these markers is skipped when the matching
# `is_*_available()` check returns False. Each check is evaluated once, when this module is imported.
require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_liger_kernel = pytest.mark.skipif(not is_liger_kernel_available(), reason="test requires liger-kernel")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
Expand All @@ -45,8 +54,15 @@
# Skipped unless both scikit-learn and joblib are available.
require_sklearn = pytest.mark.skipif(
not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn"
)
# Skipped when `torch_device` is None or is the CPU.
require_torch_accelerator = pytest.mark.skipif(
torch_device is None or torch_device == "cpu", reason="test requires accelerator"
)
# Skipped when torch is unavailable or `backend_device_count(torch_device)` reports at most one device.
require_torch_multi_accelerator = pytest.mark.skipif(
not is_torch_available() or backend_device_count(torch_device) <= 1, reason="test requires multiple accelerators"
)
require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision")
require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm")
# Opposite conditions: `require_wandb` skips when wandb is missing, `require_no_wandb` skips when it is present.
require_wandb = pytest.mark.skipif(not is_wandb_available(), reason="test requires wandb")
require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb")
require_3_accelerators = pytest.mark.skipif(
not (getattr(torch, torch_device, torch.cuda).device_count() >= 3),
Expand All @@ -69,6 +85,23 @@ def is_bitsandbytes_multi_backend_available() -> bool:
)


def is_flash_attn_available() -> bool:
    """
    Report whether a Flash Attention backend can be used in the current environment.

    Two backends are accepted: the FlashAttention-2 package (checked with `is_flash_attn_2_available`), or the
    `kernels-community/flash-attn` kernel obtained through the `kernels` package.

    Returns:
        `bool`: `True` if at least one of the two backends is usable, `False` otherwise.
    """
    # FlashAttention-2 alone is sufficient, so the kernel probe below is skipped when it is installed.
    if is_flash_attn_2_available():
        return True
    # Without the `kernels` package the probe cannot succeed; skip it as well.
    if not is_kernels_available():
        return False
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception:
        # Deliberately broad, best-effort probe: any failure to import or load the kernel means it is unusable.
        return False
    return True


# Skip-marker counterpart of `require_flash_attn` ported from transformers.testing_utils: tests carrying it are
# skipped when `is_flash_attn_available()` returns False. The check is evaluated once, when this module is imported.
require_flash_attn = pytest.mark.skipif(not is_flash_attn_available(), reason="test requires Flash Attention")

Comment on lines +88 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this one is 100% correct, because if kernels is available but flash-attn is not, and you try to use attn_implementation="flash_attention_2", it will most likely fail.

Copy link
Member Author

@albertvillanova albertvillanova Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just ported it from transformers. 😅
https://github.com/huggingface/transformers/blob/12a50f294d50e3d0e124511f2b6f43625f73ffce/src/transformers/testing_utils.py#L575-L591

def require_flash_attn(test_case):
    flash_attn_available = is_flash_attn_2_available()
    kernels_available = is_kernels_available()
    try:
        from kernels import get_kernel

        get_kernel("kernels-community/flash-attn")
    except Exception as _:
        kernels_available = False

    return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)

Do you think I made an error in porting it? @qgallouedec

Copy link
Member Author

@albertvillanova albertvillanova Oct 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic on transformers side was recently changed by this PR:

Modified require_flash_attn in testing_utils.py to allow tests to run if either FlashAttention2 or the community kernel is available, broadening test coverage and reliability.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I see.
A lot of our tests rely on flash-attn lib; it's probably a good time to drop flash-attn and rely only on kernels:

  • replace all attn_implementation="flash_attention_2" -> attn_implementation="kernels-community/flash-attn"
  • replace require_flash_attn by require_kernels

@albertvillanova we can do this in a future PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree! I created an issue for that:


class RandomBinaryJudge(BaseBinaryJudge):
"""
Random binary judge, for testing purposes.
Expand Down
Loading