Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,30 @@ jobs:
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

test_py310_numpy2:
needs: check_code_quality
strategy:
matrix:
test: ['unit']
os: [ubuntu-latest, windows-latest]
deps_versions: [deps-latest]
continue-on-error: false
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Upgrade pip
run: python -m pip install --upgrade pip
- name: Install uv
run: pip install --upgrade uv
- name: Install dependencies
run: uv pip install --system "datasets[tests_numpy2] @ ."
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
19 changes: 15 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
# For file locking
"filelock",
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17,<2.0.0", # Temporary upper version
"numpy>=1.17",
# Backend and serialization.
# Minimum 15.0.0 to be able to cast dictionary types to their underlying types
"pyarrow>=15.0.0",
Expand Down Expand Up @@ -166,7 +166,7 @@
"pytest-xdist",
# optional dependencies
"elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
"faiss-cpu>=1.6.4",
"faiss-cpu>=1.8.0.post1", # Pins numpy < 2
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
Expand All @@ -176,11 +176,11 @@
"sqlalchemy",
"s3fs>=2021.11.1", # aligned with fsspec[http]>=2021.11.1; test only on python 3.7 for now
"protobuf<4.0.0", # 4.0.0 breaks compatibility with tensorflow<2.12
"tensorflow>=2.6.0",
"tensorflow>=2.6.0", # Issue installing 2.16.0 with Python 3.8; we rely on other dependencies pinning numpy < 2
"tiktoken",
"torch>=2.0.0",
"soundfile>=0.12.1",
"transformers",
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
]
Expand All @@ -189,6 +189,16 @@
TESTS_REQUIRE.extend(VISION_REQUIRE)
TESTS_REQUIRE.extend(AUDIO_REQUIRE)

NUMPY2_INCOMPATIBLE_LIBRARIES = [
"faiss-cpu",
"librosa",
"tensorflow",
"transformers",
]
TESTS_NUMPY2_REQUIRE = [
library for library in TESTS_REQUIRE if library.partition(">")[0] not in NUMPY2_INCOMPATIBLE_LIBRARIES
]

QUALITY_REQUIRE = ["ruff>=0.3.0"]

DOCS_REQUIRE = [
Expand All @@ -213,6 +223,7 @@
"streaming": [], # for backward compatibility
"dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE,
"tests": TESTS_REQUIRE,
"tests_numpy2": TESTS_NUMPY2_REQUIRE,
"quality": QUALITY_REQUIRE,
"benchmarks": BENCHMARKS_REQUIRE,
"docs": DOCS_REQUIRE,
Expand Down
3 changes: 2 additions & 1 deletion tests/features/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from datasets.info import DatasetInfo
from datasets.utils.py_utils import asdict

from ..utils import require_jax, require_tf, require_torch
from ..utils import require_jax, require_numpy1_on_windows, require_tf, require_torch


class FeaturesTest(TestCase):
Expand Down Expand Up @@ -543,6 +543,7 @@ def test_cast_to_python_objects_pandas_timedelta(self):
casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]}))
self.assertDictEqual(casted_obj, {"a": [expected_obj]})

@require_numpy1_on_windows
@require_torch
def test_cast_to_python_objects_torch(self):
import torch
Expand Down
3 changes: 2 additions & 1 deletion tests/packaged_modules/test_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datasets import Audio, DownloadManager, Features, Image, Sequence, Value
from datasets.packaged_modules.webdataset.webdataset import WebDataset

from ..utils import require_librosa, require_pil, require_sndfile, require_torch
from ..utils import require_librosa, require_numpy1_on_windows, require_pil, require_sndfile, require_torch


@pytest.fixture
Expand Down Expand Up @@ -226,6 +226,7 @@ def test_webdataset_with_features(image_wds_file):
assert isinstance(decoded["jpg"], PIL.Image.Image)


@require_numpy1_on_windows
@require_torch
def test_tensor_webdataset(tensor_wds_file):
import torch
Expand Down
4 changes: 4 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
require_dill_gt_0_3_2,
require_jax,
require_not_windows,
require_numpy1_on_windows,
require_pil,
require_polars,
require_pyspark,
Expand Down Expand Up @@ -420,6 +421,7 @@ def test_set_format_numpy_multiple_columns(self, in_memory):
self.assertIsInstance(dset[0]["col_2"], np.str_)
self.assertEqual(dset[0]["col_2"].item(), "a")

@require_numpy1_on_windows
@require_torch
def test_set_format_torch(self, in_memory):
import torch
Expand Down Expand Up @@ -1525,6 +1527,7 @@ def func_return_multi_row_pd_dataframe(x):
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
self.assertRaises(ValueError, dset.map, func_return_multi_row_pd_dataframe)

@require_numpy1_on_windows
@require_torch
def test_map_torch(self, in_memory):
import torch
Expand Down Expand Up @@ -1590,6 +1593,7 @@ def func(example):
)
self.assertListEqual(dset_test[0]["tensor"], [1, 2, 3])

@require_numpy1_on_windows
@require_torch
def test_map_tensor_batched(self, in_memory):
import torch
Expand Down
2 changes: 2 additions & 0 deletions tests/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .utils import (
assert_arrow_memory_doesnt_increase,
assert_arrow_memory_increases,
require_numpy1_on_windows,
require_polars,
require_tf,
require_torch,
Expand Down Expand Up @@ -109,6 +110,7 @@ def test_set_format_numpy(self):
self.assertEqual(dset_split[0]["col_2"].item(), "a")
del dset

@require_numpy1_on_windows
@require_torch
def test_set_format_torch(self):
import torch
Expand Down
3 changes: 3 additions & 0 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from .utils import (
require_not_windows,
require_numpy1_on_windows,
require_regex,
require_spacy,
require_tiktoken,
Expand Down Expand Up @@ -303,6 +304,7 @@ def test_hash_tiktoken_encoding(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_tensor(self):
import torch
Expand All @@ -316,6 +318,7 @@ def test_hash_torch_tensor(self):
self.assertEqual(hash1, hash3)
self.assertNotEqual(hash1, hash2)

@require_numpy1_on_windows
@require_torch
def test_hash_torch_generator(self):
import torch
Expand Down
5 changes: 5 additions & 0 deletions tests/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .utils import (
require_jax,
require_librosa,
require_numpy1_on_windows,
require_pil,
require_polars,
require_sndfile,
Expand Down Expand Up @@ -353,6 +354,7 @@ def test_polars_formatter(self):
assert pl.Series.eq(batch["a"], pl.Series("a", _COL_A)).all()
assert pl.Series.eq(batch["b"], pl.Series("b", _COL_B)).all()

@require_numpy1_on_windows
@require_torch
def test_torch_formatter(self):
import torch
Expand All @@ -373,6 +375,7 @@ def test_torch_formatter(self):
torch.testing.assert_close(batch["c"], torch.tensor(_COL_C, dtype=torch.float32))
assert batch["c"].shape == np.array(_COL_C).shape

@require_numpy1_on_windows
@require_torch
def test_torch_formatter_torch_tensor_kwargs(self):
import torch
Expand All @@ -389,6 +392,7 @@ def test_torch_formatter_torch_tensor_kwargs(self):
self.assertEqual(batch["a"].dtype, torch.float16)
self.assertEqual(batch["c"].dtype, torch.float16)

@require_numpy1_on_windows
@require_torch
@require_pil
def test_torch_formatter_image(self):
Expand Down Expand Up @@ -975,6 +979,7 @@ def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table):
tf.debugging.assert_equal(batch["col_float"], tf.ragged.constant(list_float, dtype=tf.float32))


@require_numpy1_on_windows
@require_torch
@pytest.mark.parametrize(
"cast_schema",
Expand Down
2 changes: 2 additions & 0 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
is_rng_equal,
require_dill_gt_0_3_2,
require_not_windows,
require_numpy1_on_windows,
require_pyspark,
require_tf,
require_torch,
Expand Down Expand Up @@ -1279,6 +1280,7 @@ def gen(shard_names):
assert dataset.n_shards == len(shard_names)


@require_numpy1_on_windows
def test_iterable_dataset_from_file(dataset: IterableDataset, arrow_file: str):
with assert_arrow_memory_doesnt_increase():
dataset_from_file = IterableDataset.from_file(arrow_file)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
zip_dict,
)

from .utils import require_tf, require_torch
from .utils import require_numpy1_on_windows, require_tf, require_torch


def np_sum(x): # picklable for multiprocessing
Expand Down Expand Up @@ -151,6 +151,7 @@ def gen_random_output():
np.testing.assert_equal(out1, out2)
self.assertGreater(np.abs(out1 - out3).sum(), 0)

@require_numpy1_on_windows
@require_torch
def test_torch(self):
import torch
Expand Down
5 changes: 5 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def parse_flag_from_env(key, default=False):

require_faiss = pytest.mark.skipif(find_spec("faiss") is None or sys.platform == "win32", reason="test requires faiss")

require_numpy1_on_windows = pytest.mark.skipif(
version.parse(importlib.metadata.version("numpy")) >= version.parse("2.0.0") and sys.platform == "win32",
reason="test requires numpy < 2.0 on windows",
)


def require_regex(test_case):
"""
Expand Down