
Commit 01f5ac7

flash attn pytest marker (#41781)
* flash attn marker

* 111

---------

Co-authored-by: ydshieh <[email protected]>
1 parent: 2c5b888

File tree: 14 files changed (+29, −0 lines)


conftest.py (2 additions, 0 deletions)

@@ -87,6 +87,8 @@ def pytest_configure(config):
     config.addinivalue_line("markers", "not_device_test: mark the tests always running on cpu")
     config.addinivalue_line("markers", "torch_compile_test: mark test which tests torch compile functionality")
     config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality")
+    config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality")
+    config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality")

     os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"
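Registering the two markers in pytest_configure keeps pytest from emitting PytestUnknownMarkWarning when they are applied, and makes them usable as collection-time filters. A minimal sketch of how a marked test can then be selected; the test below is hypothetical and not part of this commit:

import pytest

# Once registered, the marker is selectable from the command line, e.g.:
#   pytest -m flash_attn_test tests/            # run only the flash-attention tests
#   pytest -m "not flash_attn_test" tests/      # exclude them instead
@pytest.mark.flash_attn_test
def test_example_flash_attn():
    ...  # placeholder body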

tests/models/exaone4/test_modeling_exaone4.py (1 addition, 0 deletions)

@@ -120,6 +120,7 @@ def test_model_generation_sdpa(self):
         text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT, text)

+    @pytest.mark.flash_attn_test
     @slow
     @require_torch_accelerator
     @require_flash_attn

tests/models/idefics2/test_modeling_idefics2.py (1 addition, 0 deletions)

@@ -643,6 +643,7 @@ def test_integration_test_4bit_batch2(self):
         self.assertEqual(batched_generated_texts[0], generated_text_0[0])
         self.assertEqual(batched_generated_texts[1], generated_text_1[0])

+    @pytest.mark.flash_attn_test
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes

tests/models/ministral/test_modeling_ministral.py (1 addition, 0 deletions)

@@ -208,6 +208,7 @@ def test_export_text_with_hybrid_cache(self):

         self.assertEqual(export_generated_text, eager_generated_text)

+    @pytest.mark.flash_attn_test
     @require_flash_attn
     @slow
     def test_past_sliding_window_generation(self):

tests/models/mistral/test_modeling_mistral.py (1 addition, 0 deletions)

@@ -300,6 +300,7 @@ def test_compile_static_cache(self):
         static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text)

+    @pytest.mark.flash_attn_test
     @parameterized.expand([("flash_attention_2",), ("sdpa",), ("flex_attention",), ("eager",)])
     @require_flash_attn
     @slow

tests/models/qwen2/test_modeling_qwen2.py (1 addition, 0 deletions)

@@ -274,6 +274,7 @@ def test_export_static_cache(self):
         ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True)
         self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)

+    @pytest.mark.flash_attn_test
     @require_flash_attn
     @slow
     def test_3b_generation(self):

tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py (1 addition, 0 deletions)

@@ -816,6 +816,7 @@ def test_small_model_integration_test_w_audio(self):
     @slow
     @require_flash_attn
     @require_torch_gpu
+    @pytest.mark.flash_attn_test
     def test_small_model_integration_test_batch_flashatt2(self):
         model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2.5-Omni-7B",

tests/models/qwen2_5_vl/test_modeling_qwen2_5_vl.py (3 additions, 0 deletions)

@@ -17,6 +17,7 @@
 import tempfile
 import unittest

+import pytest
 import requests

 from transformers import (

@@ -630,6 +631,7 @@ def test_small_model_integration_test_batch_different_resolutions(self):
     @slow
     @require_flash_attn
     @require_torch_gpu
+    @pytest.mark.flash_attn_test
     def test_small_model_integration_test_batch_flashatt2(self):
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",

@@ -658,6 +660,7 @@ def test_small_model_integration_test_batch_flashatt2(self):
     @slow
     @require_flash_attn
     @require_torch_gpu
+    @pytest.mark.flash_attn_test
     def test_small_model_integration_test_batch_wo_image_flashatt2(self):
         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2.5-VL-7B-Instruct",

tests/models/qwen2_vl/test_modeling_qwen2_vl.py (3 additions, 0 deletions)

@@ -18,6 +18,7 @@
 import tempfile
 import unittest

+import pytest
 import requests

 from transformers import (

@@ -562,6 +563,7 @@ def test_small_model_integration_test_batch_different_resolutions(self):
     @slow
     @require_flash_attn
     @require_torch_gpu
+    @pytest.mark.flash_attn_test
     def test_small_model_integration_test_batch_flashatt2(self):
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2-VL-7B-Instruct",

@@ -589,6 +591,7 @@ def test_small_model_integration_test_batch_flashatt2(self):
     @slow
     @require_flash_attn
     @require_torch_gpu
+    @pytest.mark.flash_attn_test
     def test_small_model_integration_test_batch_wo_image_flashatt2(self):
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2-VL-7B-Instruct",

tests/models/qwen3/test_modeling_qwen3.py (1 addition, 0 deletions)

@@ -266,6 +266,7 @@ def test_export_static_cache(self):

     @require_flash_attn
     @slow
+    @pytest.mark.flash_attn_test
     def test_600m_generation(self):
         model_id = "Qwen/Qwen3-0.6B-Base"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
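Every test-file change above follows the same pattern: the new collection-time marker is stacked on top of the existing runtime skip decorators. A condensed illustration of that stack, assuming the decorators come from transformers.testing_utils as they do in these files; the class name and test body are placeholders, not code from this commit:

import unittest

import pytest

from transformers.testing_utils import require_flash_attn, require_torch_gpu, slow


class ExampleModelIntegrationTest(unittest.TestCase):
    @pytest.mark.flash_attn_test  # lets pytest -m select or deselect this test at collection time
    @require_flash_attn  # skips at runtime if flash-attn is not installed
    @require_torch_gpu  # skips at runtime without a CUDA device
    @slow  # skips unless slow tests are enabled (RUN_SLOW=1)
    def test_generation_with_flash_attention_2(self):
        ...  # placeholder; the real assertions live in the diffs above

The marker complements the require_* decorators rather than replacing them: the decorators still guard against missing dependencies at runtime, while the marker lets a dedicated CI job target exactly the flash-attention suites with pytest -m.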
