Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ requirements:
- dataclasses
- multiprocess
- fsspec
- huggingface_hub >=0.24.0,<1.0.0
- huggingface_hub >=0.25.0,<2.0.0
- packaging
run:
- python
Expand All @@ -41,7 +41,7 @@ requirements:
- dataclasses
- multiprocess
- fsspec
- huggingface_hub >=0.24.0,<1.0.0
- huggingface_hub >=0.25.0,<2.0.0
- packaging

test:
Expand Down
16 changes: 8 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

test_py312:
test_py314:
needs: check_code_quality
strategy:
matrix:
Expand All @@ -100,18 +100,18 @@ jobs:
run: |
sudo apt update
sudo apt install -y ffmpeg
- name: Set up Python 3.12
- name: Set up Python 3.14
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: "3.14"
- name: Setup conda env (windows)
if: ${{ matrix.os == 'windows-latest' }}
uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
miniconda-version: "latest"
activate-environment: test
python-version: "3.12"
python-version: "3.14"
- name: Setup FFmpeg (windows)
if: ${{ matrix.os == 'windows-latest' }}
run: conda install "ffmpeg=7.0.1" -c conda-forge
Expand All @@ -127,7 +127,7 @@ jobs:
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/

test_py312_future:
test_py314_future:
needs: check_code_quality
strategy:
matrix:
Expand All @@ -145,18 +145,18 @@ jobs:
run: |
sudo apt update
sudo apt install -y ffmpeg
- name: Set up Python 3.12
- name: Set up Python 3.14
uses: actions/setup-python@v5
with:
python-version: "3.12"
python-version: "3.14"
- name: Setup conda env (windows)
if: ${{ matrix.os == 'windows-latest' }}
uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
miniconda-version: "latest"
activate-environment: test
python-version: "3.12"
python-version: "3.14"
- name: Setup FFmpeg (windows)
if: ${{ matrix.os == 'windows-latest' }}
run: conda install "ffmpeg=7.0.1" -c conda-forge
Expand Down
15 changes: 9 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@
# for fast hashing
"xxhash",
# for better multiprocessing
"multiprocess<0.70.17", # to align with dill<0.3.9 (see above)
"multiprocess<0.70.19", # to align with dill<0.3.9 (see above)
# to save datasets locally or on any filesystem
# minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143
"fsspec[http]>=2023.1.0,<=2025.9.0",
Expand Down Expand Up @@ -153,12 +153,12 @@

TESTS_REQUIRE = [
# fix pip install issues for windows
"numba>=0.56.4", # to get recent versions of llvmlite for windows ci
"numba>=0.56.4; python_version < '3.14'", # to get recent versions of llvmlite for windows ci, not available on 3.14
# test dependencies
"absl-py",
"decorator",
"joblib<1.3.0", # joblibspark doesn't support recent joblib versions
"joblibspark",
"joblibspark; python_version < '3.14'", # python 3.14 gives AttributeError: module 'ast' has no attribute 'Num'
"pytest",
"pytest-datadir",
"pytest-xdist",
Expand All @@ -169,23 +169,23 @@
"h5py",
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
"lz4; python_version < '3.14'", # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame'
"moto[server]",
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
"sqlalchemy",
"protobuf<4.0.0", # 4.0.0 breaks compatibility with tensorflow<2.12
"tensorflow>=2.6.0; python_version<'3.10' and sys_platform != 'win32'", # numpy-2 is not supported for Python < 3.10
"tensorflow>=2.16.0; python_version>='3.10' and sys_platform != 'win32'", # Pins numpy < 2
"tensorflow>=2.16.0; python_version>='3.10' and sys_platform != 'win32' and python_version < '3.14'", # Pins numpy < 2
"tiktoken",
"torch>=2.8.0",
"torchdata",
"transformers>=4.42.0", # Pins numpy < 2
"zstandard",
"polars[timezone]>=0.20.0",
"Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced
"torchcodec>=0.7.0", # minium version to get windows support
"torchcodec>=0.7.0; python_version < '3.14'", # minimum version to get windows support, torchcodec doesn't have wheels for 3.14 yet
"nibabel>=5.3.1",
]

Expand Down Expand Up @@ -262,6 +262,9 @@
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords="datasets machine learning datasets",
Expand Down
3 changes: 3 additions & 0 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def test_dataset_with_audio_feature_loaded_from_cache():
assert isinstance(ds, Dataset)


@require_torchcodec
def test_dataset_with_audio_feature_undecoded(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
data = {"audio": [audio_path]}
Expand All @@ -730,6 +731,7 @@ def test_dataset_with_audio_feature_undecoded(shared_datadir):
assert column[0] == {"path": audio_path, "bytes": None}


@require_torchcodec
def test_formatted_dataset_with_audio_feature_undecoded(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
data = {"audio": [audio_path]}
Expand Down Expand Up @@ -761,6 +763,7 @@ def test_formatted_dataset_with_audio_feature_undecoded(shared_datadir):
assert column[0] == {"path": audio_path, "bytes": None}


@require_torchcodec
def test_dataset_with_audio_feature_map_undecoded(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
data = {"audio": [audio_path]}
Expand Down
3 changes: 1 addition & 2 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import zipfile

import pytest

Expand Down Expand Up @@ -199,5 +198,5 @@ def test_is_zipfile_false_positive(tmpdir):
)
with not_a_zip_file.open("wb") as f:
f.write(data)
assert zipfile.is_zipfile(str(not_a_zip_file)) # is a false positive for `zipfile`
# zipfile.is_zipfile(str(not_a_zip_file)) could be a false positive for `zipfile`
assert not ZipExtractor.is_extractable(not_a_zip_file) # but we're right
5 changes: 3 additions & 2 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
require_spacy,
require_tiktoken,
require_torch,
require_torch_compile,
require_transformers,
)

Expand Down Expand Up @@ -347,7 +348,7 @@ def test_hash_spacy_model(self):
self.assertNotEqual(hash1, hash2)

@require_not_windows
@require_torch
@require_torch_compile
def test_hash_torch_compiled_function(self):
import torch

Expand All @@ -360,7 +361,7 @@ def f(x):
self.assertEqual(hash1, hash2)

@require_not_windows
@require_torch
@require_torch_compile
def test_hash_torch_compiled_module(self):
m = TorchModule()
next(iter(m.parameters())).data.fill_(1.0)
Expand Down
39 changes: 22 additions & 17 deletions tests/test_iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,18 +1553,21 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa
assert len(result) == 10


def gen_with_worker_info(shard):
from torch.utils.data import get_worker_info

worker_info = get_worker_info()
for i in range(100):
yield {"value": i, "worker_id": worker_info.id}


@require_torch
def test_iterable_dataset_shuffle_with_multiple_workers_different_rng():
# GH 7567
from torch.utils.data import DataLoader, get_worker_info

def gen(shard):
worker_info = get_worker_info()
for i in range(100):
yield {"value": i, "worker_id": worker_info.id}
from torch.utils.data import DataLoader

num_workers = 20
ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers))})
ds = IterableDataset.from_generator(gen_with_worker_info, gen_kwargs={"shard": list(range(num_workers))})
ds = ds.shuffle(buffer_size=100, seed=1234)
dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers)

Expand All @@ -1575,18 +1578,19 @@ def gen(shard):
assert len(set(values)) != 1, "Make sure not all values are identical"


def gen_with_value(shard, value):
for i in range(100):
yield {"value": value}


@require_torch
def test_iterable_dataset_interleave_dataset_with_multiple_workers():
# GH 7567
from torch.utils.data import DataLoader

def gen(shard, value):
for i in range(100):
yield {"value": value}

num_workers = 20
ds = [
IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
IterableDataset.from_generator(gen_with_value, gen_kwargs={"shard": list(range(num_workers)), "value": i})
for i in range(10)
]
ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
Expand All @@ -1598,18 +1602,19 @@ def gen(shard, value):
assert len(set(values)) != 1, "Make sure not all values are identical"


def gen_with_id(shard, value):
for i in range(50):
yield {"value": value, "id": i}


@require_torch
def test_iterable_dataset_interleave_dataset_deterministic_across_iterations():
# GH 7567
from torch.utils.data import DataLoader

def gen(shard, value):
for i in range(50):
yield {"value": value, "id": i}

num_workers = 10
ds = [
IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(num_workers)), "value": i})
IterableDataset.from_generator(gen_with_id, gen_kwargs={"shard": list(range(num_workers)), "value": i})
for i in range(5)
]
ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_py_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pickle
import time
from dataclasses import dataclass
from multiprocessing import Pool
Expand Down Expand Up @@ -81,7 +82,7 @@ def test_map_nested(self):
{k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True, num_proc=num_proc).items()},
{k: v.tolist() for k, v in expected_map_nested_sn1_int.items()},
)
with self.assertRaises(AttributeError): # can't pickle a local lambda
with self.assertRaises((AttributeError, pickle.PicklingError)): # can't pickle a local lambda
map_nested(lambda x: x + 1, sn1, num_proc=num_proc)

def test_zip_dict(self):
Expand Down
25 changes: 19 additions & 6 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from pathlib import Path

import pytest

Expand All @@ -26,10 +27,16 @@
Bulbasaur, grass"""


@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_dummy_path(urlpath):
def test_streaming_dl_manager_download_dummy_path():
path = str(Path(__file__).resolve())
dl_manager = StreamingDownloadManager()
assert dl_manager.download(urlpath) == urlpath
assert dl_manager.download(path) == path


def test_streaming_dl_manager_download_dummy_url():
url = "https://f.oo/bar.txt"
dl_manager = StreamingDownloadManager()
assert dl_manager.download(url) == url


@pytest.mark.parametrize(
Expand All @@ -54,10 +61,16 @@ def test_streaming_dl_manager_download(text_path):
assert f.read() == expected_file.read()


@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath):
def test_streaming_dl_manager_download_and_extract_no_extraction_dummy_path():
path = str(Path(__file__).resolve())
dl_manager = StreamingDownloadManager()
assert dl_manager.download_and_extract(path) == path


def test_streaming_dl_manager_download_and_extract_no_extraction_dummy_url():
url = "https://f.oo/bar.txt"
dl_manager = StreamingDownloadManager()
assert dl_manager.download_and_extract(urlpath) == urlpath
assert dl_manager.download_and_extract(url) == url


def test_streaming_dl_manager_extract(text_gz_path, text_path):
Expand Down
14 changes: 14 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ def require_torch(test_case):
return test_case


def require_torch_compile(test_case):
    """
    Decorator marking a test that requires `torch.compile`.

    These tests are skipped when PyTorch isn't installed, and also when running
    on Python 3.14 or later, where torch compile isn't available.

    """
    if not config.TORCH_AVAILABLE:
        test_case = unittest.skip("test requires PyTorch")(test_case)
    # Checked independently of the PyTorch availability check above, so the
    # Python-version skip is applied even when PyTorch is installed.
    if config.PY_VERSION >= version.parse("3.14"):
        test_case = unittest.skip("test requires torch compile which isn't available in python 3.14")(test_case)
    return test_case


def require_polars(test_case):
"""
Decorator marking a test that requires Polars.
Expand Down
Loading