diff --git a/tests/test_builder.py b/tests/test_builder.py index 0bfb39b45ef..20237a2b81d 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -1068,6 +1068,7 @@ def test_arrow_based_builder_download_and_prepare_as_sharded_parquet_with_max_sh assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 +@require_beam def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") builder.download_and_prepare(file_format="parquet") diff --git a/tests/utils.py b/tests/utils.py index c37be337757..8f7f4907dbf 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -75,17 +75,8 @@ def parse_flag_from_env(key, default=False): reason="test requires torchaudio>=0.12", ) - -def require_beam(test_case): - """ - Decorator marking a test that requires Apache Beam. - - These tests are skipped when Apache Beam isn't installed. - - """ - if not config.TORCH_AVAILABLE: - test_case = unittest.skip("test requires PyTorch")(test_case) - return test_case +# Beam +require_beam = pytest.mark.skipif(not config.BEAM_AVAILABLE, reason="test requires apache-beam") def require_faiss(test_case):