diff --git a/.github/conda/meta.yaml b/.github/conda/meta.yaml index 4b4a416d220..7263b10035a 100644 --- a/.github/conda/meta.yaml +++ b/.github/conda/meta.yaml @@ -20,11 +20,12 @@ requirements: - dill - pandas - requests >=2.19.0 + - httpx <1.0.0 - tqdm >=4.66.3 - dataclasses - multiprocess - fsspec - - huggingface_hub >=0.24.0,<1.0.0 + - huggingface_hub >=0.25.0,<2.0.0 - packaging run: - python @@ -35,11 +36,12 @@ requirements: - dill - pandas - requests >=2.19.0 + - httpx <1.0.0 - tqdm >=4.66.3 - dataclasses - multiprocess - fsspec - - huggingface_hub >=0.24.0,<1.0.0 + - huggingface_hub >=0.25.0,<2.0.0 - packaging test: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 128266b5e48..e40bc458d6f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,7 +82,7 @@ jobs: run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ - test_py312: + test_py314: needs: check_code_quality strategy: matrix: @@ -100,10 +100,10 @@ jobs: run: | sudo apt update sudo apt install -y ffmpeg - - name: Set up Python 3.12 + - name: Set up Python 3.14 uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.14" - name: Setup conda env (windows) if: ${{ matrix.os == 'windows-latest' }} uses: conda-incubator/setup-miniconda@v2 @@ -111,7 +111,7 @@ jobs: auto-update-conda: true miniconda-version: "latest" activate-environment: test - python-version: "3.12" + python-version: "3.14" - name: Setup FFmpeg (windows) if: ${{ matrix.os == 'windows-latest' }} run: conda install "ffmpeg=7.0.1" -c conda-forge @@ -127,7 +127,7 @@ jobs: run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ - test_py312_future: + test_py314_future: needs: check_code_quality strategy: matrix: @@ -145,10 +145,10 @@ jobs: run: | sudo apt update sudo apt install -y ffmpeg - - name: Set up Python 3.12 + - name: Set up Python 3.14 uses: actions/setup-python@v5 with: - python-version: "3.12" + python-version: "3.14" - name: Setup conda env (windows) if: ${{ matrix.os == 'windows-latest' }} uses: conda-incubator/setup-miniconda@v2 @@ -156,7 +156,7 @@ jobs: auto-update-conda: true miniconda-version: "latest" activate-environment: test - python-version: "3.12" + python-version: "3.14" - name: Setup FFmpeg (windows) if: ${{ matrix.os == 'windows-latest' }} run: conda install "ffmpeg=7.0.1" -c conda-forge diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f1f022b6fd7..3ae44bd4efc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -120,7 +120,7 @@ If you are a **dataset author**... you know what to do, it is your dataset after If you are a **user of a dataset**, the main source of information should be the dataset paper if it is available: we recommend pulling information from there into the relevant paragraphs of the template. We also eagerly welcome discussions on the [Considerations for Using the Data](https://github.com/huggingface/datasets/blob/main/templates/README_guide.md#considerations-for-using-the-data) based on existing scholarship or personal experience that would benefit the whole community. -Finally, if you want more information on the how and why of dataset cards, we strongly recommend reading the foundational works [Datasheets for Datasets](https://arxiv.org/abs/1803.09010) and [Data Statements for NLP](https://www.aclweb.org/anthology/Q18-1041/). 
+Finally, if you want more information on the how and why of dataset cards, we strongly recommend reading the foundational works [Datasheets for Datasets](https://huggingface.co/papers/1803.09010) and [Data Statements for NLP](https://www.aclweb.org/anthology/Q18-1041/). Thank you for your contribution! diff --git a/README.md b/README.md index d4162b9e761..0b70a39d098 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ If you're a dataset owner and wish to update any part of it (description, citati ## BibTeX -If you want to cite our 🤗 Datasets library, you can use our [paper](https://arxiv.org/abs/2109.02846): +If you want to cite our 🤗 Datasets library, you can use our [paper](https://huggingface.co/papers/2109.02846): ```bibtex @inproceedings{lhoest-etal-2021-datasets, diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 861925a7d99..cc6b7195fe2 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -88,6 +88,8 @@ title: Load document data - local: document_dataset title: Create a document dataset + - local: nifti_dataset + title: Create a medical imaging dataset title: "Vision" - sections: - local: nlp_load diff --git a/docs/source/dataset_card.mdx b/docs/source/dataset_card.mdx index f1067697fb2..5f8b998cc9e 100644 --- a/docs/source/dataset_card.mdx +++ b/docs/source/dataset_card.mdx @@ -1,7 +1,7 @@ # Create a dataset card Each dataset should have a dataset card to promote responsible usage and inform users of any potential biases within the dataset. -This idea was inspired by the Model Cards proposed by [Mitchell, 2018](https://arxiv.org/abs/1810.03993). +This idea was inspired by the Model Cards proposed by [Mitchell, 2018](https://huggingface.co/papers/1810.03993). Dataset cards help users understand a dataset's contents, the context for using the dataset, how it was created, and any other considerations a user should be aware of. Creating a dataset card is easy and can be done in just a few steps: @@ -24,4 +24,4 @@ Creating a dataset card is easy and can be done in just a few steps: YAML also allows you to customize the way your dataset is loaded by [defining splits and/or configurations](./repository_structure#define-your-splits-and-subsets-in-yaml) without the need to write any code. -Feel free to take a look at the [SNLI](https://huggingface.co/datasets/snli), [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail), and [Allociné](https://huggingface.co/datasets/allocine) dataset cards as examples to help you get started. +Feel free to take a look at the [SNLI](https://huggingface.co/datasets/stanfordnlp/snli), [CNN/DailyMail](https://huggingface.co/datasets/abisee/cnn_dailymail), and [Allociné](https://huggingface.co/datasets/tblard/allocine) dataset cards as examples to help you get started. diff --git a/docs/source/document_dataset.mdx b/docs/source/document_dataset.mdx index 30cc1bd3121..bc2a8a229ef 100644 --- a/docs/source/document_dataset.mdx +++ b/docs/source/document_dataset.mdx @@ -1,13 +1,13 @@ # Create a document dataset -This guide will show you how to create a document dataset with `PdfFolder` and some metadata. This is a no-code solution for quickly creating a document dataset with several thousand pdfs. +This guide will show you how to create a document dataset with `PdfFolder` and some metadata. This is a no-code solution for quickly creating a document dataset with several thousand PDFs. > [!TIP] > You can control access to your dataset by requiring users to share their contact information first. 
Check out the [Gated datasets](https://huggingface.co/docs/hub/datasets-gated) guide for more information about how to enable this feature on the Hub. ## PdfFolder -The `PdfFolder` is a dataset builder designed to quickly load a document dataset with several thousand pdfs without requiring you to write any code. +The `PdfFolder` is a dataset builder designed to quickly load a document dataset with several thousand PDFs without requiring you to write any code. > [!TIP] > 💡 Take a look at the [Split pattern hierarchy](repository_structure#split-pattern-hierarchy) to learn more about how `PdfFolder` creates dataset splits based on your dataset repository structure. @@ -72,14 +72,14 @@ file_name,additional_feature or using `metadata.jsonl`: ```jsonl -{"file_name": "0001.pdf", "additional_feature": "This is a first value of a text feature you added to your pdfs"} -{"file_name": "0002.pdf", "additional_feature": "This is a second value of a text feature you added to your pdfs"} -{"file_name": "0003.pdf", "additional_feature": "This is a third value of a text feature you added to your pdfs"} +{"file_name": "0001.pdf", "additional_feature": "This is a first value of a text feature you added to your PDFs"} +{"file_name": "0002.pdf", "additional_feature": "This is a second value of a text feature you added to your PDFs"} +{"file_name": "0003.pdf", "additional_feature": "This is a third value of a text feature you added to your PDFs"} ``` Here the `file_name` must be the name of the PDF file next to the metadata file. More generally, it must be the relative path from the directory containing the metadata to the PDF file. -It's possible to point to more than one pdf in each row in your dataset, for example if both your input and output are pdfs: +It's possible to point to more than one PDF in each row in your dataset, for example if both your input and output are pdfs: ```jsonl {"input_file_name": "0001.pdf", "output_file_name": "0001_output.pdf"} @@ -87,7 +87,7 @@ It's possible to point to more than one pdf in each row in your dataset, for exa {"input_file_name": "0003.pdf", "output_file_name": "0003_output.pdf"} ``` -You can also define lists of pdfs. In that case you need to name the field `file_names` or `*_file_names`. Here is an example: +You can also define lists of PDFs. In that case you need to name the field `file_names` or `*_file_names`. Here is an example: ```jsonl {"pdfs_file_names": ["0001_part1.pdf", "0001_part2.pdf"], "label": "urgent"} @@ -95,9 +95,9 @@ You can also define lists of pdfs. In that case you need to name the field `file {"pdfs_file_names": ["0003_part1.pdf", "0002_part2.pdf"], "label": "normal"} ``` -### OCR (Optical character recognition) +### OCR (Optical Character Recognition) -OCR datasets have the text contained in a pdf. An example `metadata.csv` may look like: +OCR datasets have the text contained in a PDF. An example `metadata.csv` may look like: ```csv file_name,text @@ -106,7 +106,7 @@ file_name,text 0003.pdf,Attention is all you need. Abstract. The ... 
``` -Load the dataset with `PdfFolder`, and it will create a `text` column for the pdf captions: +Load the dataset with `PdfFolder`, and it will create a `text` column for the PDF captions: ```py >>> dataset = load_dataset("pdffolder", data_dir="/path/to/folder", split="train") diff --git a/docs/source/faiss_es.mdx b/docs/source/faiss_es.mdx index 9ee6565df94..635051744de 100644 --- a/docs/source/faiss_es.mdx +++ b/docs/source/faiss_es.mdx @@ -22,7 +22,7 @@ FAISS retrieves documents based on the similarity of their vector representation ```py >>> from datasets import load_dataset ->>> ds = load_dataset('crime_and_punish', split='train[:100]') +>>> ds = load_dataset('community-datasets/crime_and_punish', split='train[:100]') >>> ds_with_embeddings = ds.map(lambda example: {'embeddings': ctx_encoder(**ctx_tokenizer(example["line"], return_tensors="pt"))[0][0].numpy()}) ``` @@ -62,7 +62,7 @@ FAISS retrieves documents based on the similarity of their vector representation 7. Reload it at a later time with [`Dataset.load_faiss_index`]: ```py ->>> ds = load_dataset('crime_and_punish', split='train[:100]') +>>> ds = load_dataset('community-datasets/crime_and_punish', split='train[:100]') >>> ds.load_faiss_index('embeddings', 'my_index.faiss') ``` diff --git a/docs/source/image_load.mdx b/docs/source/image_load.mdx index 676b3f51653..67c7eff9684 100644 --- a/docs/source/image_load.mdx +++ b/docs/source/image_load.mdx @@ -10,7 +10,7 @@ When you load an image dataset and call the image column, the images are decoded ```py >>> from datasets import load_dataset, Image ->>> dataset = load_dataset("beans", split="train") +>>> dataset = load_dataset("AI-Lab-Makerere/beans", split="train") >>> dataset[0]["image"] ``` @@ -33,7 +33,7 @@ You can load a dataset from the image path. Use the [`~Dataset.cast_column`] fun If you only want to load the underlying path to the image dataset without decoding the image object, set `decode=False` in the [`Image`] feature: ```py ->>> dataset = load_dataset("beans", split="train").cast_column("image", Image(decode=False)) +>>> dataset = load_dataset("AI-Lab-Makerere/beans", split="train").cast_column("image", Image(decode=False)) >>> dataset[0]["image"] {'bytes': None, 'path': '/root/.cache/huggingface/datasets/downloads/extracted/b0a21163f78769a2cf11f58dfc767fb458fc7cea5c05dccc0144a2c0f0bc1292/train/bean_rust/bean_rust_train.29.jpg'} diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index eb73ab84b5a..74e3a8e383d 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -327,7 +327,7 @@ Select specific rows of the `train` split: ```py >>> train_10_20_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[10:20]") ===STRINGAPI-READINSTRUCTION-SPLIT=== ->>> train_10_20_ds = datasets.load_dataset("bookcorpu", split=datasets.ReadInstruction("train", from_=10, to=20, unit="abs")) +>>> train_10_20_ds = datasets.load_dataset("rojagtap/bookcorpus", split=datasets.ReadInstruction("train", from_=10, to=20, unit="abs")) ``` Or select a percentage of a split with: diff --git a/docs/source/nifti_dataset.mdx b/docs/source/nifti_dataset.mdx new file mode 100644 index 00000000000..2770460fbaf --- /dev/null +++ b/docs/source/nifti_dataset.mdx @@ -0,0 +1,130 @@ +# Create a NIfTI dataset + +This page shows how to create and share a dataset of medical images in NIfTI format (.nii / .nii.gz) using the `datasets` library. 
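+
+Reading NIfTI files relies on the `nibabel` library. A minimal install sketch, assuming the optional `nibabel` extra is available (installing `nibabel` directly also works):
+
+```bash
+# the "nibabel" extra is an assumption; `pip install nibabel` is the direct alternative
+pip install "datasets[nibabel]"
+```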
+ +You can share a dataset with your team or with anyone in the community by creating a dataset repository on the Hugging Face Hub: + +```py +from datasets import load_dataset + +dataset = load_dataset("/my_nifti_dataset") +``` + +There are two common ways to create a NIfTI dataset: + +- Create a dataset from local NIfTI files in Python and upload it with `Dataset.push_to_hub`. +- Use a folder-based convention (one file per example) and a small helper to convert it into a `Dataset`. + +> [!TIP] +> You can control access to your dataset by requiring users to share their contact information first. Check out the [Gated datasets](https://huggingface.co/docs/hub/datasets-gated) guide for more information. + +## Local files + +If you already have a list of file paths to NIfTI files, the easiest workflow is to create a `Dataset` from that list and cast the column to the `Nifti` feature. + +```py +from datasets import Dataset +from datasets import Nifti + +# simple example: create a dataset from file paths +files = ["/path/to/scan_001.nii.gz", "/path/to/scan_002.nii.gz"] +ds = Dataset.from_dict({"nifti": files}).cast_column("nifti", Nifti()) + +# access a decoded nibabel image (if decode=True) +# ds[0]["nifti"] will be a nibabel.Nifti1Image object when decode=True +# or a dict {'bytes': None, 'path': '...'} when decode=False +``` + +The `Nifti` feature supports a `decode` parameter. When `decode=True` (the default), it loads the NIfTI file into a `nibabel.nifti1.Nifti1Image` object. You can access the image data as a numpy array with `img.get_fdata()`. When `decode=False`, it returns a dict with the file path and bytes. + +```py +from datasets import Dataset, Nifti + +ds = Dataset.from_dict({"nifti": ["/path/to/scan.nii.gz"]}).cast_column("nifti", Nifti(decode=True)) +img = ds[0]["nifti"] # instance of: nibabel.nifti1.Nifti1Image +arr = img.get_fdata() +``` + +After preparing the dataset you can push it to the Hub: + +```py +ds.push_to_hub("/my_nifti_dataset") +``` + +This will create a dataset repository containing your NIfTI dataset with a `data/` folder of parquet shards. + +## Folder conventions and metadata + +If you organize your dataset in folders you can create splits automatically (train/test/validation) by following a structure like: + +``` +dataset/train/scan_0001.nii +dataset/train/scan_0002.nii +dataset/validation/scan_1001.nii +dataset/test/scan_2001.nii +``` + +If you have labels or other metadata, provide a `metadata.csv`, `metadata.jsonl`, or `metadata.parquet` in the folder so files can be linked to metadata rows. The metadata must contain a `file_name` (or `*_file_name`) field with the relative path to the NIfTI file next to the metadata file. + +Example `metadata.csv`: + +```csv +file_name,patient_id,age,diagnosis +scan_0001.nii.gz,P001,45,healthy +scan_0002.nii.gz,P002,59,disease_x +``` + +The `Nifti` feature works with zipped datasets too — each zip can contain NIfTI files and a metadata file. This is useful when uploading large datasets as archives. 
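+
+Once the folder and metadata file are in place, you can load them directly, as shown below. This is a minimal sketch, assuming the folder-based builder is exposed under the `niftifolder` name (mirroring `imagefolder` and `pdffolder`) and that it exposes the files in a `nifti` column:
+
+```py
+from datasets import load_dataset
+
+# "niftifolder" and the "nifti" column name are assumptions based on the other folder builders
+ds = load_dataset("niftifolder", data_dir="/path/to/dataset")
+
+ds["train"][0]["nifti"]      # decoded nibabel image
+ds["train"][0]["diagnosis"]  # metadata column from metadata.csv
+```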
+Compressed and uncompressed files can also be mixed, so your dataset structure could look like this:
+```
+dataset/train/scan_0001.nii.gz
+dataset/train/scan_0002.nii
+dataset/validation/scan_1001.nii.gz
+dataset/test/scan_2001.nii
+```
+
+## Converting to PyTorch tensors
+
+Use the [`~Dataset.set_transform`] function to apply the transformation on-the-fly to batches of the dataset:
+
+```py
+import torch
+
+def transform_to_pytorch(example):
+    example["nifti_torch"] = [torch.tensor(ex.get_fdata()) for ex in example["nifti"]]
+    return example
+
+ds.set_transform(transform_to_pytorch)
+```
+
+Accessing elements now (e.g. `ds[0]`) yields torch tensors under the `"nifti_torch"` key.
+
+## Usage of Nifti1Image
+
+NIfTI is a format for storing 3-dimensional (or even 4-dimensional) brain scans: three spatial dimensions (x, y, z) and
+optionally a time dimension (t). The voxel positions are only defined relative to the scanner, so the additional
+dimensions (4, 5, 6) are used to lift them into real-world coordinates.
+
+You can visualize NIfTI files with `matplotlib`, for example:
+```python
+import matplotlib.pyplot as plt
+from datasets import load_dataset
+
+def show_slices(slices):
+    """Display a row of image slices."""
+    fig, axes = plt.subplots(1, len(slices))
+    for i, slice_ in enumerate(slices):
+        axes[i].imshow(slice_.T, cmap="gray", origin="lower")
+
+nifti_ds = load_dataset("/my_nifti_dataset", split="train")
+for example in nifti_ds:
+    nifti_img = example["nifti"].get_fdata()
+    show_slices([nifti_img[:, :, 16], nifti_img[26, :, :], nifti_img[:, 30, :]])
+    plt.show()
+```
+
+For further reading, see the [nibabel documentation](https://nipy.org/nibabel/index.html) and especially [this nibabel tutorial on coordinate systems](https://nipy.org/nibabel/coordinate_systems.html).
diff --git a/docs/source/object_detection.mdx b/docs/source/object_detection.mdx
index f612de28fdc..f1360e3fa95 100644
--- a/docs/source/object_detection.mdx
+++ b/docs/source/object_detection.mdx
@@ -8,14 +8,14 @@ To run these examples, make sure you have up-to-date versions of [albumentations
 pip install -U albumentations opencv-python
 ```
 
-In this example, you'll use the [`cppe-5`](https://huggingface.co/datasets/cppe-5) dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.
+In this example, you'll use the [`cppe-5`](https://huggingface.co/datasets/rishitdagli/cppe-5) dataset for identifying medical personal protective equipment (PPE) in the context of the COVID-19 pandemic.
Load the dataset and take a look at an example: ```py >>> from datasets import load_dataset ->>> ds = load_dataset("cppe-5") +>>> ds = load_dataset("rishitdagli/cppe-5") >>> example = ds['train'][0] >>> example {'height': 663, diff --git a/docs/source/package_reference/loading_methods.mdx b/docs/source/package_reference/loading_methods.mdx index 786679636e7..4792d1b88f7 100644 --- a/docs/source/package_reference/loading_methods.mdx +++ b/docs/source/package_reference/loading_methods.mdx @@ -103,6 +103,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t") [[autodoc]] datasets.packaged_modules.pdffolder.PdfFolder +### Nifti + +[[autodoc]] datasets.packaged_modules.niftifolder.NiftiFolderConfig + +[[autodoc]] datasets.packaged_modules.niftifolder.NiftiFolder + ### WebDataset [[autodoc]] datasets.packaged_modules.webdataset.WebDataset diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx index 299dd765d13..84e651f9171 100644 --- a/docs/source/package_reference/main_classes.mdx +++ b/docs/source/package_reference/main_classes.mdx @@ -271,6 +271,10 @@ Dictionary with split names as keys ('train', 'test' for example), and `Iterable [[autodoc]] datasets.Pdf +### Nifti + +[[autodoc]] datasets.Nifti + ## Filesystems [[autodoc]] datasets.filesystems.is_remote_filesystem diff --git a/docs/source/quickstart.mdx b/docs/source/quickstart.mdx index 6be8bee907c..a6f2dc25bef 100644 --- a/docs/source/quickstart.mdx +++ b/docs/source/quickstart.mdx @@ -288,7 +288,7 @@ pip install -U albumentations opencv-python ## NLP -Text needs to be tokenized into individual tokens by a [tokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer). For the quickstart, you'll load the [Microsoft Research Paraphrase Corpus (MRPC)](https://huggingface.co/datasets/glue/viewer/mrpc) training dataset to train a model to determine whether a pair of sentences mean the same thing. +Text needs to be tokenized into individual tokens by a [tokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer). For the quickstart, you'll load the [Microsoft Research Paraphrase Corpus (MRPC)](https://huggingface.co/datasets/nyu-mll/glue/viewer/mrpc) training dataset to train a model to determine whether a pair of sentences mean the same thing. **1**. Load the MRPC dataset by providing the [`load_dataset`] function with the dataset name, dataset configuration (not all datasets will have a configuration), and dataset split: diff --git a/docs/source/stream.mdx b/docs/source/stream.mdx index 67f1ff420cd..b721b0959c4 100644 --- a/docs/source/stream.mdx +++ b/docs/source/stream.mdx @@ -19,7 +19,8 @@ For example, the English split of the [HuggingFaceFW/fineweb](https://huggingfac >>> from datasets import load_dataset >>> dataset = load_dataset('HuggingFaceFW/fineweb', split='train', streaming=True) >>> print(next(iter(dataset))) -{'text': "How AP reported in all formats from tornado-stricken regionsMarch 8, 2012\nWhen the first serious bout of tornadoes of 2012 blew through middle America in the middle of the night, they touched down in places hours from any AP bureau... 
+{'text': 'How AP reported in all formats from tornado-stricken regionsMarch 8, 2012\nWhen the first serious bout of tornadoes of 2012 blew through middle America in the middle of the night, they touched down in places hours from any AP bureau...', ..., + 'language_score': 0.9721424579620361, 'token_count': 717} ``` Dataset streaming also lets you work with a dataset made of local files without doing any conversion. @@ -29,6 +30,7 @@ This is especially helpful when: - You don't want to wait for an extremely large local dataset to be converted to Arrow. - The converted files size would exceed the amount of available disk space on your computer. - You want to quickly explore just a few samples of a dataset. +- You want to load only certain columns or efficiently filter a Parquet dataset. For example, you can stream a local dataset of hundreds of compressed JSONL files like [oscar-corpus/OSCAR-2201](https://huggingface.co/datasets/oscar-corpus/OSCAR-2201) to use it instantly: @@ -40,6 +42,19 @@ For example, you can stream a local dataset of hundreds of compressed JSONL file {'id': 0, 'text': 'Founded in 2015, Golden Bees is a leading programmatic recruitment platform dedicated to employers, HR agencies and job boards. The company has developed unique HR-custom technologies and predictive algorithms to identify and attract the best candidates for a job opportunity.', ... ``` +Parquet is a columnar format that allows you to stream and load only a subset of columns and ignore unwanted columns. Parquet also stores metadata such as column statistics (at the file and row group level), enabling efficient filtering. Use the `columns` and `filters` arguments of [`datasets.packaged_modules.parquet.ParquetConfig`] to stream Parquet datasets, select columns, and apply filters: + +```py +>>> from datasets import load_dataset +>>> dataset = load_dataset('HuggingFaceFW/fineweb', split='train', streaming=True, columns=["url", "date"]) +>>> print(next(iter(dataset))) +{'url': 'http://%20jwashington@ap.org/Content/Press-Release/2012/How-AP-reported-in-all-formats-from-tornado-stricken-regions', 'date': '2013-05-18T05:48:54Z'} +>>> dataset = load_dataset('HuggingFaceFW/fineweb', split='train', streaming=True, filters=[("language_score", ">=", 0.99)]) +>>> print(next(iter(dataset))) +{'text': 'Everyone wishes for something. And lots of people believe they know how to make their wishes come true with magical thinking.\nWhat is it? "Magical thinking is a belief in forms of causation, with no known physical basis," said Professor Emily Pronin of Princeton...', ..., + 'language_score': 0.9900368452072144, 'token_count': 716} +``` + Loading a dataset in streaming mode creates a new dataset type instance (instead of the classic [`Dataset`] object), known as an [`IterableDataset`]. This special type of dataset has its own set of processing methods shown below. @@ -99,6 +114,7 @@ The `buffer_size` argument controls the size of the buffer to randomly sample ex ``` > [!TIP] +> > [`IterableDataset.shuffle`] will also shuffle the order of the shards if the dataset is sharded into multiple files. ## Reshuffle @@ -144,11 +160,11 @@ You can split your dataset one of two ways: 🤗 Datasets supports sharding to divide a very large dataset into a predefined number of chunks. Specify the `num_shards` parameter in [`~IterableDataset.shard`] to determine the number of shards to split the dataset into. You'll also need to provide the shard you want to return with the `index` parameter. 
-For example, the [amazon_polarity](https://huggingface.co/datasets/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files): +For example, the [amazon_polarity](https://huggingface.co/datasets/fancyzhx/amazon_polarity) dataset has 4 shards (in this case they are 4 Parquet files): ```py >>> from datasets import load_dataset ->>> dataset = load_dataset("amazon_polarity", split="train", streaming=True) +>>> dataset = load_dataset("fancyzhx/amazon_polarity", split="train", streaming=True) >>> print(dataset) IterableDataset({ features: ['label', 'title', 'content'], diff --git a/docs/source/use_with_jax.mdx b/docs/source/use_with_jax.mdx index a38dc7928ad..cb0a763ab7c 100644 --- a/docs/source/use_with_jax.mdx +++ b/docs/source/use_with_jax.mdx @@ -195,11 +195,11 @@ part. The easiest way to get JAX arrays out of a dataset is to use the `with_format('jax')` method. Lets assume that we want to train a neural network on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) available -at the HuggingFace Hub at https://huggingface.co/datasets/mnist. +at the HuggingFace Hub at https://huggingface.co/datasets/ylecun/mnist. ```py >>> from datasets import load_dataset ->>> ds = load_dataset("mnist") +>>> ds = load_dataset("ylecun/mnist") >>> ds = ds.with_format("jax") >>> ds["train"][0] {'image': DeviceArray([[ 0, 0, 0, ...], diff --git a/docs/source/use_with_numpy.mdx b/docs/source/use_with_numpy.mdx index bd0cd6877b7..b3ba45864e8 100644 --- a/docs/source/use_with_numpy.mdx +++ b/docs/source/use_with_numpy.mdx @@ -160,7 +160,7 @@ at the HuggingFace Hub at https://huggingface.co/datasets/mnist. ```py >>> from datasets import load_dataset ->>> ds = load_dataset("mnist") +>>> ds = load_dataset("ylecun/mnist") >>> ds = ds.with_format("numpy") >>> ds["train"][0] {'image': array([[ 0, 0, 0, ...], diff --git a/setup.py b/setup.py index 0dee50b1f42..30d66fc54db 100644 --- a/setup.py +++ b/setup.py @@ -124,10 +124,10 @@ # for fast hashing "xxhash", # for better multiprocessing - "multiprocess<0.70.17", # to align with dill<0.3.9 (see above) + "multiprocess<0.70.19", # to align with dill<0.3.9 (see above) # to save datasets locally or on any filesystem # minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143 - "fsspec[http]>=2023.1.0,<=2025.9.0", + "fsspec[http]>=2023.1.0,<=2025.10.0", # To get datasets from the Datasets Hub on huggingface.co "huggingface-hub>=0.25.0,<2.0", # Utilities from PyPA to e.g., compare versions @@ -153,12 +153,12 @@ TESTS_REQUIRE = [ # fix pip install issues for windows - "numba>=0.56.4", # to get recent versions of llvmlite for windows ci + "numba>=0.56.4; python_version < '3.14'", # to get recent versions of llvmlite for windows ci, not available on 3.14 # test dependencies "absl-py", "decorator", "joblib<1.3.0", # joblibspark doesn't support recent joblib versions - "joblibspark", + "joblibspark; python_version < '3.14'", # python 3.14 gives AttributeError: module 'ast' has no attribute 'Num' "pytest", "pytest-datadir", "pytest-xdist", @@ -169,7 +169,7 @@ "h5py", "jax>=0.3.14; sys_platform != 'win32'", "jaxlib>=0.3.14; sys_platform != 'win32'", - "lz4", + "lz4; python_version < '3.14'", # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame "moto[server]", "pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0 "py7zr", @@ -177,7 +177,7 @@ "sqlalchemy", "protobuf<4.0.0", 
# 4.0.0 breaks compatibility with tensorflow<2.12 "tensorflow>=2.6.0; python_version<'3.10' and sys_platform != 'win32'", # numpy-2 is not supported for Python < 3.10 - "tensorflow>=2.16.0; python_version>='3.10' and sys_platform != 'win32'", # Pins numpy < 2 + "tensorflow>=2.16.0; python_version>='3.10' and sys_platform != 'win32' and python_version < '3.14'", # Pins numpy < 2 "tiktoken", "torch>=2.8.0", "torchdata", @@ -185,7 +185,8 @@ "zstandard", "polars[timezone]>=0.20.0", "Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced - "torchcodec>=0.7.0", # minium version to get windows support + "torchcodec>=0.7.0; python_version < '3.14'", # minium version to get windows support, torchcodec doesn't have wheels for 3.14 yet + "nibabel>=5.3.1", ] NUMPY2_INCOMPATIBLE_LIBRARIES = [ @@ -207,6 +208,8 @@ PDFS_REQUIRE = ["pdfplumber>=0.11.4"] +NIBABEL_REQUIRE = ["nibabel>=5.3.2", "ipyniivue==2.4.2"] + EXTRAS_REQUIRE = { "audio": AUDIO_REQUIRE, "vision": VISION_REQUIRE, @@ -224,11 +227,12 @@ "benchmarks": BENCHMARKS_REQUIRE, "docs": DOCS_REQUIRE, "pdfs": PDFS_REQUIRE, + "nibabel": NIBABEL_REQUIRE, } setup( name="datasets", - version="4.1.2.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="4.4.2.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="HuggingFace community-driven open-source library of datasets", long_description=open("README.md", encoding="utf-8").read(), long_description_content_type="text/markdown", @@ -258,6 +262,9 @@ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], keywords="datasets machine learning datasets", diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py index 77f14553a3e..6b2dc7d8600 100644 --- a/src/datasets/__init__.py +++ b/src/datasets/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "4.1.2.dev0" +__version__ = "4.4.2.dev0" from .arrow_dataset import Column, Dataset from .arrow_reader import ReadInstruction diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index c46733e71ee..36b744a024a 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -660,7 +660,11 @@ def __init__(self, source: Union["Dataset", "Column"], column_name: str): def __iter__(self) -> Iterator[Any]: if isinstance(self.source, Dataset): - source = self.source._fast_select_column(self.column_name) + if self.source._format_type == "custom": + # the formatting transform may require all columns + source = self.source + else: + source = self.source._fast_select_column(self.column_name) else: source = self.source for example in source: @@ -670,7 +674,12 @@ def __getitem__(self, key: Union[int, str, list[int]]) -> Any: if isinstance(key, str): return Column(self, key) elif isinstance(self.source, Dataset): - return self.source._fast_select_column(self.column_name)[key][self.column_name] + if self.source._format_type == "custom": + # the formatting transform may require all columns + source = self.source + else: + source = self.source._fast_select_column(self.column_name) + return source[key][self.column_name] elif isinstance(key, int): return self.source[key][self.column_name] else: @@ -1120,6 +1129,7 @@ def from_generator( gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, split: NamedSplit = Split.TRAIN, + fingerprint: Optional[str] = None, **kwargs, ): """Create a Dataset from a generator. @@ -1146,6 +1156,12 @@ def from_generator( Split name to be assigned to the dataset. + fingerprint (`str`, *optional*): + Fingerprint that will be used to generate dataset ID. + By default `fingerprint` is generated by hashing the generator function and all the args which can be slow + if it uses large objects like AI models. + + **kwargs (additional keyword arguments): Keyword arguments to be passed to :[`GeneratorConfig`]. @@ -1183,6 +1199,7 @@ def from_generator( gen_kwargs=gen_kwargs, num_proc=num_proc, split=split, + fingerprint=fingerprint, **kwargs, ).read() @@ -1953,7 +1970,7 @@ def class_encode_column(self, column: str, include_nulls: bool = False) -> "Data ```py >>> from datasets import load_dataset - >>> ds = load_dataset("boolq", split="validation") + >>> ds = load_dataset("google/boolq", split="validation") >>> ds.features {'answer': Value('bool'), 'passage': Value('string'), @@ -4708,7 +4725,7 @@ def train_test_split( >>> ds = ds.train_test_split(test_size=0.2, seed=42) # stratified split - >>> ds = load_dataset("imdb",split="train") + >>> ds = load_dataset("stanfordnlp/imdb",split="train") Dataset({ features: ['text', 'label'], num_rows: 25000 @@ -4860,7 +4877,7 @@ def train_test_split( try: train_indices, test_indices = next( stratified_shuffle_split_generate_indices( - self.with_format("numpy")[stratify_by_column], n_train, n_test, rng=generator + np.asarray(self.with_format("numpy")[stratify_by_column]), n_train, n_test, rng=generator ) ) except Exception as error: @@ -5189,7 +5206,7 @@ def to_polars( Args: batch_size (`int`, *optional*): The size (number of rows) of the batches if `batched` is `True`. - Defaults to `genomicsml.datasets.config.DEFAULT_MAX_BATCH_SIZE`. + Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`. batched (`bool`): Set to `True` to return a generator that yields the dataset as batches of `batch_size` rows. Defaults to `False` (returns the whole datasets once). 
@@ -6158,7 +6175,7 @@ def add_faiss_index( Example: ```python - >>> ds = datasets.load_dataset('crime_and_punish', split='train') + >>> ds = datasets.load_dataset('community-datasets/crime_and_punish', split='train') >>> ds_with_embeddings = ds.map(lambda example: {'embeddings': embed(example['line']})) >>> ds_with_embeddings.add_faiss_index(column='embeddings') >>> # query @@ -6166,7 +6183,7 @@ def add_faiss_index( >>> # save index >>> ds_with_embeddings.save_faiss_index('embeddings', 'my_index.faiss') - >>> ds = datasets.load_dataset('crime_and_punish', split='train') + >>> ds = datasets.load_dataset('community-datasets/crime_and_punish', split='train') >>> # load index >>> ds.load_faiss_index('embeddings', 'my_index.faiss') >>> # query @@ -6297,7 +6314,7 @@ def add_elasticsearch_index( ```python >>> es_client = elasticsearch.Elasticsearch() - >>> ds = datasets.load_dataset('crime_and_punish', split='train') + >>> ds = datasets.load_dataset('community-datasets/crime_and_punish', split='train') >>> ds.add_elasticsearch_index(column='line', es_client=es_client, es_index_name="my_es_index") >>> scores, retrieved_examples = ds.get_nearest_examples('line', 'my new query', k=10) ``` diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py index 3bbb58a59c3..d9cf2cf0f4b 100644 --- a/src/datasets/arrow_reader.py +++ b/src/datasets/arrow_reader.py @@ -459,34 +459,34 @@ class ReadInstruction: Examples:: # The following lines are equivalent: - ds = datasets.load_dataset('mnist', split='test[:33%]') - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec('test[:33%]')) - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction('test', to=33, unit='%')) - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction( + ds = datasets.load_dataset('ylecun/mnist', split='test[:33%]') + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction.from_spec('test[:33%]')) + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction('test', to=33, unit='%')) + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction( 'test', from_=0, to=33, unit='%')) # The following lines are equivalent: - ds = datasets.load_dataset('mnist', split='test[:33%]+train[1:-1]') - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec( + ds = datasets.load_dataset('ylecun/mnist', split='test[:33%]+train[1:-1]') + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction.from_spec( 'test[:33%]+train[1:-1]')) - ds = datasets.load_dataset('mnist', split=( + ds = datasets.load_dataset('ylecun/mnist', split=( datasets.ReadInstruction('test', to=33, unit='%') + datasets.ReadInstruction('train', from_=1, to=-1, unit='abs'))) # The following lines are equivalent: - ds = datasets.load_dataset('mnist', split='test[:33%](pct1_dropremainder)') - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction.from_spec( + ds = datasets.load_dataset('ylecun/mnist', split='test[:33%](pct1_dropremainder)') + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction.from_spec( 'test[:33%](pct1_dropremainder)')) - ds = datasets.load_dataset('mnist', split=datasets.ReadInstruction( + ds = datasets.load_dataset('ylecun/mnist', split=datasets.ReadInstruction( 'test', from_=0, to=33, unit='%', rounding="pct1_dropremainder")) # 10-fold validation: tests = datasets.load_dataset( - 'mnist', + 'ylecun/mnist', [datasets.ReadInstruction('train', from_=k, to=k+10, unit='%') for k in range(0, 100, 10)]) 
trains = datasets.load_dataset( - 'mnist', + 'ylecun/mnist', [datasets.ReadInstruction('train', to=k, unit='%') + datasets.ReadInstruction('train', from_=k+10, unit='%') for k in range(0, 100, 10)]) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e63960dcabf..b88aa0bf8f9 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -313,6 +313,7 @@ def __init__( data_dir: Optional[str] = None, storage_options: Optional[dict] = None, writer_batch_size: Optional[int] = None, + config_id: Optional[str] = None, **config_kwargs, ): # DatasetBuilder name @@ -343,6 +344,7 @@ def __init__( self.config, self.config_id = self._create_builder_config( config_name=config_name, custom_features=features, + config_id=config_id, **config_kwargs, ) @@ -502,7 +504,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st return legacy_relative_data_dir def _create_builder_config( - self, config_name=None, custom_features=None, **config_kwargs + self, config_name=None, custom_features=None, config_id=None, **config_kwargs ) -> tuple[BuilderConfig, str]: """Create and validate BuilderConfig object as well as a unique config id for this config. Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None. @@ -570,10 +572,11 @@ def _create_builder_config( ) # compute the config id that is going to be used for caching - config_id = builder_config.create_config_id( - config_kwargs, - custom_features=custom_features, - ) + if config_id is None: + config_id = builder_config.create_config_id( + config_kwargs, + custom_features=custom_features, + ) is_custom = (config_id not in self.builder_configs) and config_id != "default" if is_custom: logger.info(f"Using custom data configuration {config_id}") diff --git a/src/datasets/config.py b/src/datasets/config.py index 908befa8c69..b6412682727 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -139,6 +139,7 @@ TORCHCODEC_AVAILABLE = importlib.util.find_spec("torchcodec") is not None TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None PDFPLUMBER_AVAILABLE = importlib.util.find_spec("pdfplumber") is not None +NIBABEL_AVAILABLE = importlib.util.find_spec("nibabel") is not None # Optional compression tools RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None @@ -247,6 +248,10 @@ # Streaming STREAMING_READ_MAX_RETRIES = 20 STREAMING_READ_RETRY_INTERVAL = 5 +STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL = 20 +STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL = 60 +STREAMING_OPEN_MAX_RETRIES = 20 +STREAMING_OPEN_RETRY_INTERVAL = 5 # Datasets repositories exploration DATA_FILES_MAX_NUMBER_FOR_MODULE_INFERENCE = 200 diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 9fefd4a4c69..96d4daea52d 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -349,14 +349,18 @@ def resolve_pattern( pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config) fs, fs_pattern = url_to_fs(pattern, **storage_options) files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)} - protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0] + protocol = ( + pattern.split("://")[0] + if "://" in pattern + else (fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]) + ) protocol_prefix = protocol + "://" if protocol != "file" else "" glob_kwargs = {} if protocol == "hf": # 10 times faster glob with detail=True (ignores costly info like lastCommit) 
glob_kwargs["expand_info"] = False matched_paths = [ - filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath + filepath if "://" in filepath else protocol_prefix + filepath for filepath, info in fs.glob(pattern, detail=True, **glob_kwargs).items() if (info["type"] == "file" or (info.get("islink") and os.path.isfile(os.path.realpath(filepath)))) and (xbasename(filepath) not in files_to_ignore) @@ -503,6 +507,18 @@ def _get_origin_metadata( max_workers: Optional[int] = None, ) -> list[SingleOriginMetadata]: max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS + if all("hf://" in data_file for data_file in data_files): + # No need for multithreading here since the origin metadata of HF files + # is (repo_id, revision) and is cached after first .info() call. + return [ + _get_single_origin_metadata(data_file, download_config=download_config) + for data_file in hf_tqdm( + data_files, + desc="Resolving data files", + # set `disable=None` rather than `disable=False` by default to disable progress bar when no TTY attached + disable=len(data_files) <= 16 or None, + ) + ] return thread_map( partial(_get_single_origin_metadata, download_config=download_config), data_files, diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 63a93429c45..995103d26e0 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -515,7 +515,7 @@ def class_encode_column(self, column: str, include_nulls: bool = False) -> "Data ```py >>> from datasets import load_dataset - >>> ds = load_dataset("boolq") + >>> ds = load_dataset("google/boolq") >>> ds["train"].features {'answer': Value('bool'), 'passage': Value('string'), diff --git a/src/datasets/download/download_config.py b/src/datasets/download/download_config.py index b9e539fb053..6efae72b671 100644 --- a/src/datasets/download/download_config.py +++ b/src/datasets/download/download_config.py @@ -75,7 +75,7 @@ def copy(self) -> "DownloadConfig": def __setattr__(self, name, value): if name == "token" and getattr(self, "storage_options", None) is not None: if "hf" not in self.storage_options: - self.storage_options["hf"] = {"token": value, "endpoint": config.HF_ENDPOINT} + self.storage_options["hf"] = {"endpoint": config.HF_ENDPOINT, "token": value} elif getattr(self.storage_options["hf"], "token", None) is None: self.storage_options["hf"]["token"] = value super().__setattr__(name, value) diff --git a/src/datasets/download/download_manager.py b/src/datasets/download/download_manager.py index b6ee1d28e2b..4e84d9947fe 100644 --- a/src/datasets/download/download_manager.py +++ b/src/datasets/download/download_manager.py @@ -269,7 +269,7 @@ def iter_files(self, paths: Union[str, list[str]]): Example: ```py - >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/beans/resolve/main/data/train.zip') + >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/train.zip') >>> files = dl_manager.iter_files(files) ``` """ diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index ff2dc1a64bd..6f4c6087027 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -206,7 +206,7 @@ def iter_files(self, urlpaths: Union[str, list[str]]) -> Iterable[str]: Example: ```py - >>> files = 
dl_manager.download_and_extract('https://huggingface.co/datasets/beans/resolve/main/data/train.zip') + >>> files = dl_manager.download_and_extract('https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/train.zip') >>> files = dl_manager.iter_files(files) ``` """ diff --git a/src/datasets/features/__init__.py b/src/datasets/features/__init__.py index 36133ce5e5a..40a3568039a 100644 --- a/src/datasets/features/__init__.py +++ b/src/datasets/features/__init__.py @@ -15,10 +15,12 @@ "TranslationVariableLanguages", "Video", "Pdf", + "Nifti", ] from .audio import Audio from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, List, Sequence, Value from .image import Image +from .nifti import Nifti from .pdf import Pdf from .translation import Translation, TranslationVariableLanguages from .video import Video diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index c9b894f6605..d9513f289f5 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -49,9 +49,13 @@ class Audio: Args: sampling_rate (`int`, *optional*): Target sampling rate. If `None`, the native sampling rate is used. - mono (`bool`, defaults to `True`): - Whether to convert the audio signal to mono by averaging samples across - channels. + num_channels (`int`, *optional*): + The desired number of channels of the samples. By default, the number of channels of the source is used. + Audio decoding will return samples with shape (num_channels, num_samples) + Currently `None` (number of channels of the source, default), `1` (mono) or `2` (stereo) channels are supported. + The `num_channels` argument is passed to `torchcodec.decoders.AudioDecoder`. + + decode (`bool`, defaults to `True`): Whether to decode the audio data. If `False`, returns the underlying dictionary in the format `{"path": audio_path, "bytes": audio_bytes}`. 
@@ -63,7 +67,7 @@ class Audio: ```py >>> from datasets import load_dataset, Audio >>> ds = load_dataset("PolyAI/minds14", name="en-US", split="train") - >>> ds = ds.cast_column("audio", Audio(sampling_rate=44100)) + >>> ds = ds.cast_column("audio", Audio(sampling_rate=44100, num_channels=2)) >>> ds[0]["audio"] >>> audio = ds[0]["audio"] @@ -78,6 +82,7 @@ class Audio: sampling_rate: Optional[int] = None decode: bool = True + num_channels: Optional[int] = None stream_index: Optional[int] = None id: Optional[str] = field(default=None, repr=False) # Automatically constructed @@ -126,7 +131,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder buffer = BytesIO() AudioEncoder( torch.from_numpy(value["array"].astype(np.float32)), sample_rate=value["sampling_rate"] - ).to_file_like(buffer, format="wav") + ).to_file_like(buffer, format="wav", num_channels=self.num_channels) return {"bytes": buffer.getvalue(), "path": None} elif value.get("path") is not None and os.path.isfile(value["path"]): # we set "bytes": None to not duplicate the data if they're already available locally @@ -143,7 +148,7 @@ def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder buffer = BytesIO() AudioEncoder(torch.from_numpy(bytes_value), sample_rate=value["sampling_rate"]).to_file_like( - buffer, format="wav" + buffer, format="wav", num_channels=self.num_channels ) return {"bytes": buffer.getvalue(), "path": None} else: @@ -188,7 +193,9 @@ def decode_example( raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.") if bytes is None and is_local_path(path): - audio = AudioDecoder(path, stream_index=self.stream_index, sample_rate=self.sampling_rate) + audio = AudioDecoder( + path, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels + ) elif bytes is None: token_per_repo_id = token_per_repo_id or {} @@ -201,10 +208,14 @@ def decode_example( download_config = DownloadConfig(token=token) f = xopen(path, "rb", download_config=download_config) - audio = AudioDecoder(f, stream_index=self.stream_index, sample_rate=self.sampling_rate) + audio = AudioDecoder( + f, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels + ) else: - audio = AudioDecoder(bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate) + audio = AudioDecoder( + bytes, stream_index=self.stream_index, sample_rate=self.sampling_rate, num_channels=self.num_channels + ) audio._hf_encoded = {"path": path, "bytes": bytes} audio.metadata.path = path return audio @@ -312,5 +323,8 @@ def encode_torchcodec_audio(audio: "AudioDecoder") -> dict: samples = audio.get_all_samples() buffer = BytesIO() - AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like(buffer, format="wav") + num_channels = samples.data.shape[0] + AudioEncoder(samples.data.cpu(), sample_rate=samples.sample_rate).to_file_like( + buffer, format="wav", num_channels=num_channels + ) return {"bytes": buffer.getvalue(), "path": None} diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index dbc3818e224..88259767ae0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -42,6 +42,7 @@ from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image +from .nifti import Nifti from .pdf import Pdf, encode_pdfplumber_pdf from .translation import Translation, 
TranslationVariableLanguages from .video import Video @@ -106,6 +107,8 @@ def _arrow_to_datasets_dtype(arrow_type: pa.DataType) -> str: return "binary" elif pyarrow.types.is_large_binary(arrow_type): return "large_binary" + elif pyarrow.types.is_binary_view(arrow_type): + return "binary_view" elif pyarrow.types.is_string(arrow_type): return "string" elif pyarrow.types.is_large_string(arrow_type): @@ -508,6 +511,7 @@ class Value: - `decimal256(precision, scale)` - `binary` - `large_binary` + - `binary_view` - `string` - `large_string` - `string_view` @@ -1267,6 +1271,7 @@ def __repr__(self): Image, Video, Pdf, + Nifti, ] @@ -1425,6 +1430,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Uni Image.__name__: Image, Video.__name__: Video, Pdf.__name__: Pdf, + Nifti.__name__: Nifti, } @@ -1758,6 +1764,9 @@ class Features(dict): - [`Pdf`] feature to store the absolute path to a PDF file, a `pdfplumber.pdf.PDF` object or a dictionary with the relative path to a PDF file ("path" key) and its bytes content ("bytes" key). This feature loads the PDF lazily with a PDF reader. + - [`Nifti`] feature to store the absolute path to a NIfTI neuroimaging file, a `nibabel.Nifti1Image` object + or a dictionary with the relative path to a NIfTI file ("path" key) and its bytes content ("bytes" key). + This feature loads the NIfTI file lazily with nibabel. - [`Translation`] or [`TranslationVariableLanguages`] feature specific to Machine Translation. """ diff --git a/src/datasets/features/image.py b/src/datasets/features/image.py index fecc2fc5ccd..cb746b9219d 100644 --- a/src/datasets/features/image.py +++ b/src/datasets/features/image.py @@ -215,6 +215,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr The Arrow types that can be converted to the Image pyarrow storage type are: - `pa.string()` - it must contain the "path" data + - `pa.large_string()` - it must contain the "path" data (will be cast to string if possible) - `pa.binary()` - it must contain the image bytes - `pa.struct({"bytes": pa.binary()})` - `pa.struct({"path": pa.string()})` @@ -229,6 +230,15 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr `pa.StructArray`: Array in the Image arrow storage type, that is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. """ + if pa.types.is_large_string(storage.type): + try: + storage = storage.cast(pa.string()) + except pa.ArrowInvalid as e: + raise ValueError( + f"Failed to cast large_string to string for Image feature. " + f"This can happen if string values exceed 2GB. " + f"Original error: {e}" + ) from e if pa.types.is_string(storage.type): bytes_array = pa.array([None] * len(storage), type=pa.binary()) storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) diff --git a/src/datasets/features/nifti.py b/src/datasets/features/nifti.py new file mode 100644 index 00000000000..f3c34d29266 --- /dev/null +++ b/src/datasets/features/nifti.py @@ -0,0 +1,321 @@ +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union + +import pyarrow as pa + +from .. 
import config +from ..download.download_config import DownloadConfig +from ..table import array_cast +from ..utils.file_utils import is_local_path, xopen +from ..utils.py_utils import no_op_if_value_is_null, string_to_dict + + +if TYPE_CHECKING: + import nibabel as nib + + from .features import FeatureType + +if config.NIBABEL_AVAILABLE: + import nibabel as nib + + class Nifti1ImageWrapper(nib.nifti1.Nifti1Image): + """ + A wrapper around nibabel's Nifti1Image to customize its representation. + """ + + def __init__(self, nifti_image: nib.nifti1.Nifti1Image): + super().__init__( + dataobj=nifti_image.dataobj, + affine=nifti_image.affine, + header=nifti_image.header, + extra=nifti_image.extra, + file_map=nifti_image.file_map, + dtype=nifti_image.get_data_dtype(), + ) + self.nifti_image = nifti_image + + def _repr_html_(self): + from ipyniivue import NiiVue, ShowRender, SliceType, Volume + from IPython.display import display + + bytes_ = self.nifti_image.to_bytes() + nv = NiiVue() + nv.set_slice_type(SliceType.MULTIPLANAR) + nv.opts.multiplanar_show_render = ShowRender.ALWAYS + nv.opts.show_3d_crosshair = True + nv.opts.multiplanar_force_render = True + name = None + if hasattr(self.nifti_image, "file_map"): + if ( + "image" in self.nifti_image.file_map + and getattr(self.nifti_image.file_map["image"], "filename", None) is not None + ): + name = self.nifti_image.file_map["image"].filename + if name is None: + name = "volume.nii.gz" + volume = Volume(name=name, data=bytes_) + nv.load_volumes([volume]) + display(nv) + + +@dataclass +class Nifti: + """ + **Experimental.** + Nifti [`Feature`] to read NIfTI neuroimaging files. + + Input: The Nifti feature accepts as input: + - A `str`: Absolute path to the NIfTI file (i.e. random access is allowed). + - A `pathlib.Path`: path to the NIfTI file (i.e. random access is allowed). + - A `dict` with the keys: + - `path`: String with relative path of the NIfTI file in a dataset repository. + - `bytes`: Bytes of the NIfTI file. + This is useful for archived files with sequential access. + + - A `nibabel` image object (e.g., `nibabel.nifti1.Nifti1Image`). + + Args: + decode (`bool`, defaults to `True`): + Whether to decode the NIfTI data. If `False` a string with the bytes is returned. `decode=False` is not supported when decoding examples. + + Examples: + + ```py + >>> from datasets import Dataset, Nifti + >>> ds = Dataset.from_dict({"nifti": ["path/to/file.nii.gz"]}).cast_column("nifti", Nifti()) + >>> ds.features["nifti"] + Nifti(decode=True, id=None) + >>> ds[0]["nifti"] + + >>> ds = ds.cast_column("nifti", Nifti(decode=False)) + >>> ds[0]["nifti"] + {'bytes': None, + 'path': 'path/to/file.nii.gz'} + ``` + """ + + decode: bool = True + id: Optional[str] = field(default=None, repr=False) + + # Automatically constructed + dtype: ClassVar[str] = "nibabel.nifti1.Nifti1Image" + pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()}) + _type: str = field(default="Nifti", init=False, repr=False) + + def __call__(self): + return self.pa_type + + def encode_example(self, value: Union[str, bytes, bytearray, dict, "nib.Nifti1Image"]) -> dict: + """Encode example into a format for Arrow. + + Args: + value (`str`, `bytes`, `nibabel.Nifti1Image` or `dict`): + Data passed as input to Nifti feature. 
+ + Returns: + `dict` with "path" and "bytes" fields + """ + if config.NIBABEL_AVAILABLE: + import nibabel as nib + else: + nib = None + + if isinstance(value, str): + return {"path": value, "bytes": None} + elif isinstance(value, Path): + return {"path": str(value.absolute()), "bytes": None} + elif isinstance(value, (bytes, bytearray)): + return {"path": None, "bytes": value} + elif nib is not None and isinstance(value, nib.spatialimages.SpatialImage): + # nibabel image object - try to get path or convert to bytes + return encode_nibabel_image(value) + elif isinstance(value, dict): + if value.get("path") is not None and os.path.isfile(value["path"]): + # we set "bytes": None to not duplicate the data if they're already available locally + return {"bytes": None, "path": value.get("path")} + elif value.get("bytes") is not None or value.get("path") is not None: + # store the nifti bytes, and path is used to infer the format using the file extension + return {"bytes": value.get("bytes"), "path": value.get("path")} + else: + raise ValueError( + f"A nifti sample should have one of 'path' or 'bytes' but they are missing or None in {value}." + ) + else: + raise ValueError( + f"A nifti sample should be a string, bytes, Path, nibabel image, or dict, but got {type(value)}." + ) + + def decode_example(self, value: dict, token_per_repo_id=None) -> "Nifti1ImageWrapper": + """Decode example NIfTI file into nibabel image object. + + Args: + value (`str` or `dict`): + A string with the absolute NIfTI file path, a dictionary with + keys: + + - `path`: String with absolute or relative NIfTI file path. + - `bytes`: The bytes of the NIfTI file. + + token_per_repo_id (`dict`, *optional*): + To access and decode NIfTI files from private repositories on + the Hub, you can pass a dictionary + repo_id (`str`) -> token (`bool` or `str`). + + Returns: + `nibabel.Nifti1Image` objects + """ + if config.NIBABEL_AVAILABLE: + import nibabel as nib + else: + raise ImportError("To support decoding NIfTI files, please install 'nibabel'.") + + if token_per_repo_id is None: + token_per_repo_id = {} + + path, bytes_ = value["path"], value["bytes"] + if bytes_ is None: + if path is None: + raise ValueError(f"A nifti should have one of 'path' or 'bytes' but both are None in {value}.") + else: + # gzipped files have the structure: 'gzip://T1.nii::' + if path.startswith("gzip://") and is_local_path(path.split("::")[-1]): + path = path.split("::")[-1] + if is_local_path(path): + nifti = nib.load(path) + else: + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL + if source_url.startswith(config.HF_ENDPOINT) + else config.HUB_DATASETS_HFFS_URL + ) + source_url_fields = string_to_dict(source_url, pattern) + token = ( + token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None + ) + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + nifti = nib.load(f) + else: + import gzip + + if ( + bytes_[:2] == b"\x1f\x8b" + ): # gzip magic number, see https://stackoverflow.com/a/76055284/9534390 or "Magic number" on https://en.wikipedia.org/wiki/Gzip + bytes_ = gzip.decompress(bytes_) + + nifti = nib.Nifti1Image.from_bytes(bytes_) + + return Nifti1ImageWrapper(nifti) + + def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray: + """Embed NifTI files into the Arrow array. + + Args: + storage (`pa.StructArray`): + PyArrow array to embed. 
+ + Returns: + `pa.StructArray`: Array in the NifTI arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. + """ + if token_per_repo_id is None: + token_per_repo_id = {} + + @no_op_if_value_is_null + def path_to_bytes(path): + source_url = path.split("::")[-1] + pattern = ( + config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL + ) + source_url_fields = string_to_dict(source_url, pattern) + token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None + download_config = DownloadConfig(token=token) + with xopen(path, "rb", download_config=download_config) as f: + return f.read() + + bytes_array = pa.array( + [ + (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None + for x in storage.to_pylist() + ], + type=pa.binary(), + ) + path_array = pa.array( + [os.path.basename(path) if path is not None else None for path in storage.field("path").to_pylist()], + type=pa.string(), + ) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()) + return array_cast(storage, self.pa_type) + + def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]: + """If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary.""" + from .features import Value + + return ( + self + if self.decode + else { + "bytes": Value("binary"), + "path": Value("string"), + } + ) + + def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.BinaryArray]) -> pa.StructArray: + """Cast an Arrow array to the Nifti arrow storage type. + The Arrow types that can be converted to the Nifti pyarrow storage type are: + + - `pa.string()` - it must contain the "path" data + - `pa.binary()` - it must contain the NIfTI bytes + - `pa.struct({"bytes": pa.binary()})` + - `pa.struct({"path": pa.string()})` + - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter + + Args: + storage (`Union[pa.StringArray, pa.StructArray, pa.BinaryArray]`): + PyArrow array to cast. + + Returns: + `pa.StructArray`: Array in the Nifti arrow storage type, that is + `pa.struct({"bytes": pa.binary(), "path": pa.string()})`. + """ + if pa.types.is_string(storage.type): + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_binary(storage.type): + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null()) + elif pa.types.is_struct(storage.type): + if storage.type.get_field_index("bytes") >= 0: + bytes_array = storage.field("bytes") + else: + bytes_array = pa.array([None] * len(storage), type=pa.binary()) + if storage.type.get_field_index("path") >= 0: + path_array = storage.field("path") + else: + path_array = pa.array([None] * len(storage), type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null()) + return array_cast(storage, self.pa_type) + + +def encode_nibabel_image(img: "nib.Nifti1Image") -> dict[str, Optional[Union[str, bytes]]]: + """ + Encode a nibabel image object into a dictionary. + + If the image has an associated file path, returns the path. Otherwise, serializes + the image content into bytes. 
+ + Args: + img: A nibabel image object (e.g., Nifti1Image). + + Returns: + dict: A dictionary with "path" or "bytes" field. + """ + if hasattr(img, "file_map") and img.file_map is not None: + filename = img.file_map["image"].filename + return {"path": filename, "bytes": None} + + bytes_data = img.to_bytes() + return {"path": None, "bytes": bytes_data} diff --git a/src/datasets/features/pdf.py b/src/datasets/features/pdf.py index 414c497356c..756530554d4 100644 --- a/src/datasets/features/pdf.py +++ b/src/datasets/features/pdf.py @@ -44,8 +44,6 @@ class Pdf: - A `pdfplumber.pdf.PDF`: pdfplumber pdf object. Args: - mode (`str`, *optional*): - The mode to convert the pdf to. If `None`, the native mode of the pdf is used. decode (`bool`, defaults to `True`): Whether to decode the pdf data. If `False`, returns the underlying dictionary in the format `{"path": pdf_path, "bytes": pdf_bytes}`. diff --git a/src/datasets/features/video.py b/src/datasets/features/video.py index adbfaaa30f3..8d7f3e3be51 100644 --- a/src/datasets/features/video.py +++ b/src/datasets/features/video.py @@ -45,8 +45,6 @@ class Video: Output: The Video features output data as `torchcodec.decoders.VideoDecoder` objects. Args: - mode (`str`, *optional*): - The mode to convert the video to. If `None`, the native mode of the video is used. decode (`bool`, defaults to `True`): Whether to decode the video data. If `False`, returns the underlying dictionary in the format `{"path": video_path, "bytes": video_bytes}`. diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index b10609cac23..6c1eaee9b0f 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -16,6 +16,7 @@ def __init__( gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, split: NamedSplit = Split.TRAIN, + fingerprint: Optional[str] = None, **kwargs, ): super().__init__( @@ -32,8 +33,10 @@ def __init__( generator=generator, gen_kwargs=gen_kwargs, split=split, + config_id="default-fingerprint=" + fingerprint if fingerprint else None, **kwargs, ) + self.fingerprint = fingerprint def read(self): # Build iterable dataset @@ -56,4 +59,6 @@ def read(self): dataset = self.builder.as_dataset( split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory ) + if self.fingerprint: + dataset._fingerprint = self.fingerprint return dataset diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 7d16baa7d0d..26c35a60555 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -26,7 +26,15 @@ import pandas as pd import pyarrow as pa import pyarrow.parquet as pq -from huggingface_hub import CommitInfo, CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi +from huggingface_hub import ( + CommitInfo, + CommitOperationAdd, + CommitOperationDelete, + DatasetCard, + DatasetCardData, + HfApi, + HfFileSystem, +) from huggingface_hub.hf_api import RepoFile from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError from multiprocess import Pool @@ -161,6 +169,19 @@ def _convert_to_arrow( yield new_key, pa.Table.from_pylist(cast_to_python_objects(examples, only_1d_for_numpy=True)) +def shift_ex_examples_rngs(ex_iterable: "_BaseExamplesIterable", value: int) -> "_BaseExamplesIterable": + """We need to go through the ex_iterables recursively, create a new seed and return a new iterable, then set it to the containing ex_iterable.""" + + def set_seed_recursively(ex_iterable): + if hasattr(ex_iterable, 
"shift_rngs"): + ex_iterable = ex_iterable.shift_rngs(value) + if hasattr(ex_iterable, "ex_iterable"): + ex_iterable.ex_iterable = set_seed_recursively(ex_iterable.ex_iterable) + return ex_iterable + + return set_seed_recursively(ex_iterable) + + class _BaseExamplesIterable: """Base class for the examples iterable used by an IterableDataset""" @@ -275,6 +296,14 @@ def __init__( super().__init__(generate_examples_fn, kwargs) self.generator = deepcopy(generator) + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return ShuffledDataSourcesExamplesIterable( + self.generate_examples_fn, + self.kwargs, + np.random.default_rng(seed=new_seed), + ) + def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict @@ -382,6 +411,14 @@ def __init__( super().__init__(generate_tables_fn, kwargs) self.generator = deepcopy(generator) + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return ShuffledDataSourcesArrowExamplesIterable( + self.generate_examples_fn, + self.kwargs, + np.random.default_rng(seed=new_seed), + ) + def _init_state_dict(self) -> dict: self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__} return self._state_dict @@ -1023,6 +1060,15 @@ def __init__( self.generator = deepcopy(generator) self.probabilities = probabilities + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return RandomlyCyclingMultiSourcesExamplesIterable( + ex_iterables=self.ex_iterables, + generator=np.random.default_rng(seed=new_seed), + probabilities=self.probabilities, + stopping_strategy=self.stopping_strategy, + ) + @property def is_typed(self): return self.ex_iterables[0].is_typed @@ -1620,6 +1666,14 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat self.buffer_size = buffer_size self.generator = generator + def shift_rngs(self, value: int) -> "_BaseExamplesIterable": + new_seed = self.generator.bit_generator.state["state"]["state"] + value + return BufferShuffledExamplesIterable( + ex_iterable=self.ex_iterable, + buffer_size=self.buffer_size, + generator=np.random.default_rng(seed=new_seed), + ) + @property def is_typed(self): return self.ex_iterable.is_typed @@ -2151,6 +2205,7 @@ def __init__( self._token_per_repo_id: dict[str, Union[str, bool, None]] = token_per_repo_id or {} self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0) self._starting_state_dict: Optional[dict] = None + self.__hffs_cache = HfFileSystem._cache # keep the cache on pickling (e.g. for dataloader workers) self._prepare_ex_iterable_for_iteration() # set state_dict _maybe_add_torch_iterable_dataset_parent_class(self.__class__) # subclass of torch IterableDataset @@ -2299,6 +2354,8 @@ def __setstate__(self, d): self.__dict__ = d # Re-add torch shared memory, since shared memory is not always kept when pickling self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch) + # Re-add the cache to keep on pickling (e.g. 
for dataloader workers) + self.__hffs_cache = HfFileSystem._cache # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling _maybe_add_torch_iterable_dataset_parent_class(self.__class__) @@ -2361,6 +2418,7 @@ def _iter_pytorch(self): ex_iterable = ex_iterable.shard_data_sources( num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False ) + ex_iterable = shift_ex_examples_rngs(ex_iterable=ex_iterable, value=worker_info.id) self._state_dict = { "examples_iterable": ex_iterable._init_state_dict(), "epoch": self.epoch, @@ -3160,7 +3218,7 @@ def shard( ```py >>> from datasets import load_dataset - >>> ds = load_dataset("amazon_polarity", split="train", streaming=True) + >>> ds = load_dataset("fancyzhx/amazon_polarity", split="train", streaming=True) >>> ds Dataset({ features: ['label', 'title', 'content'], @@ -3570,15 +3628,12 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData ``` """ - def batch_fn(unbatched): - return {k: [v] for k, v in unbatched.items()} - if self.features: features = Features({col: List(feature) for col, feature in self.features.items()}) else: features = None return self.map( - batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features + _batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features ) def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]: @@ -3660,7 +3715,7 @@ def to_polars( Args: batch_size (`int`, *optional*): The size (number of rows) of the batches if `batched` is `True`. - Defaults to `genomicsml.datasets.config.DEFAULT_MAX_BATCH_SIZE`. + Defaults to `datasets.config.DEFAULT_MAX_BATCH_SIZE`. batched (`bool`): Set to `True` to return a generator that yields the dataset as batches of `batch_size` rows. Defaults to `False` (returns the whole datasets once). 
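
The `shift_rngs` methods and the `shift_ex_examples_rngs` helper above exist so that each PyTorch DataLoader worker derives its own generator (the base seed is offset by `worker_info.id` in `_iter_pytorch`) instead of every worker replaying an identical shuffle order. A minimal usage sketch, mirroring the new tests added later in this diff (it assumes `torch` is installed; the generator body and the shard count are placeholders):

```python
from datasets import IterableDataset
from torch.utils.data import DataLoader


def gen(shard):
    # one shard per worker; the "shard" value is only used for sharding
    for i in range(100):
        yield {"value": i}


ds = IterableDataset.from_generator(gen, gen_kwargs={"shard": list(range(4))})
ds = ds.shuffle(buffer_size=100, seed=1234)

# With this patch, the 4 workers no longer yield identical shuffled sequences,
# while results remain deterministic across iterations for a fixed seed.
loader = DataLoader(ds, batch_size=None, num_workers=4)
examples = list(loader)
```
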
@@ -4648,3 +4703,7 @@ async def _apply_async(pool, func, x): return future.get() else: await asyncio.sleep(0) + + +def _batch_fn(unbatched): + return {k: [v] for k, v in unbatched.items()} diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 515ff147b29..9d076df44b7 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -11,6 +11,7 @@ from .hdf5 import hdf5 from .imagefolder import imagefolder from .json import json +from .niftifolder import niftifolder from .pandas import pandas from .parquet import parquet from .pdffolder import pdffolder @@ -46,6 +47,7 @@ def _hash_python_lines(lines: list[str]) -> str: "audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())), "videofolder": (videofolder.__name__, _hash_python_lines(inspect.getsource(videofolder).splitlines())), "pdffolder": (pdffolder.__name__, _hash_python_lines(inspect.getsource(pdffolder).splitlines())), + "niftifolder": (niftifolder.__name__, _hash_python_lines(inspect.getsource(niftifolder).splitlines())), "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), @@ -89,6 +91,8 @@ def _hash_python_lines(lines: list[str]) -> str: _EXTENSION_TO_MODULE.update({ext.upper(): ("videofolder", {}) for ext in videofolder.VideoFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext: ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) _EXTENSION_TO_MODULE.update({ext.upper(): ("pdffolder", {}) for ext in pdffolder.PdfFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext: ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) +_EXTENSION_TO_MODULE.update({ext.upper(): ("niftifolder", {}) for ext in niftifolder.NiftiFolder.EXTENSIONS}) # Used to filter data files based on extensions given a module name _MODULE_TO_EXTENSIONS: dict[str, list[str]] = {} @@ -106,3 +110,4 @@ def _hash_python_lines(lines: list[str]) -> str: _MODULE_TO_METADATA_FILE_NAMES["audiofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES +_MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES diff --git a/src/datasets/packaged_modules/hdf5/hdf5.py b/src/datasets/packaged_modules/hdf5/hdf5.py index fb9100e1a0a..1b0e80aa6a8 100644 --- a/src/datasets/packaged_modules/hdf5/hdf5.py +++ b/src/datasets/packaged_modules/hdf5/hdf5.py @@ -61,8 +61,9 @@ def _split_generators(self, dl_manager): # Infer features from first file if self.info.features is None: for first_file in itertools.chain.from_iterable(files): - with h5py.File(first_file, "r") as h5: - self.info.features = _recursive_infer_features(h5) + with open(first_file, "rb") as f: + with h5py.File(f, "r") as h5: + self.info.features = _recursive_infer_features(h5) break splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) return splits @@ -73,22 +74,23 @@ def _generate_tables(self, files): batch_size_cfg = self.config.batch_size for file_idx, file in enumerate(itertools.chain.from_iterable(files)): try: - with h5py.File(file, "r") as h5: - # Infer features and lengths from first file - if self.info.features is None: 
- self.info.features = _recursive_infer_features(h5) - num_rows = _check_dataset_lengths(h5, self.info.features) - if num_rows is None: - logger.warning(f"File {file} contains no data, skipping...") - continue - effective_batch = batch_size_cfg or self._writer_batch_size or num_rows - for start in range(0, num_rows, effective_batch): - end = min(start + effective_batch, num_rows) - pa_table = _recursive_load_arrays(h5, self.info.features, start, end) - if pa_table is None: + with open(file, "rb") as f: + with h5py.File(f, "r") as h5: + # Infer features and lengths from first file + if self.info.features is None: + self.info.features = _recursive_infer_features(h5) + num_rows = _check_dataset_lengths(h5, self.info.features) + if num_rows is None: logger.warning(f"File {file} contains no data, skipping...") continue - yield f"{file_idx}_{start}", cast_table_to_features(pa_table, self.info.features) + effective_batch = batch_size_cfg or self._writer_batch_size or num_rows + for start in range(0, num_rows, effective_batch): + end = min(start + effective_batch, num_rows) + pa_table = _recursive_load_arrays(h5, self.info.features, start, end) + if pa_table is None: + logger.warning(f"File {file} contains no data, skipping...") + continue + yield f"{file_idx}_{start}", cast_table_to_features(pa_table, self.info.features) except ValueError as e: logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") raise diff --git a/src/datasets/packaged_modules/niftifolder/__init__.py b/src/datasets/packaged_modules/niftifolder/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/niftifolder/niftifolder.py b/src/datasets/packaged_modules/niftifolder/niftifolder.py new file mode 100644 index 00000000000..c6d039419f1 --- /dev/null +++ b/src/datasets/packaged_modules/niftifolder/niftifolder.py @@ -0,0 +1,23 @@ +import datasets + +from ..folder_based_builder import folder_based_builder + + +logger = datasets.utils.logging.get_logger(__name__) + + +class NiftiFolderConfig(folder_based_builder.FolderBasedBuilderConfig): + """BuilderConfig for NiftiFolder.""" + + drop_labels: bool = None + drop_metadata: bool = None + + def __post_init__(self): + super().__post_init__() + + +class NiftiFolder(folder_based_builder.FolderBasedBuilder): + BASE_FEATURE = datasets.Nifti + BASE_COLUMN_NAME = "nifti" + BUILDER_CONFIG_CLASS = NiftiFolderConfig + EXTENSIONS: list[str] = [".nii", ".nii.gz"] diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 10797753657..52a675d41c7 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -1,6 +1,6 @@ import itertools from dataclasses import dataclass -from typing import Optional, Union +from typing import Literal, Optional, Union import pyarrow as pa import pyarrow.dataset as ds @@ -15,12 +15,73 @@ @dataclass class ParquetConfig(datasets.BuilderConfig): - """BuilderConfig for Parquet.""" + """ + BuilderConfig for Parquet. + + Args: + batch_size (`int`, *optional*): + Size of the RecordBatches to iterate on. + The default is the row group size (defined by the first row group). + columns (`list[str]`, *optional*) + List of columns to load, the other ones are ignored. + All columns are loaded by default. + features: (`Features`, *optional*): + Cast the data to `features`. 
+ filters (`Union[pyarrow.dataset.Expression, list[tuple], list[list[tuple]]]`, *optional*): + Return only the rows matching the filter. + If possible the predicate will be pushed down to exploit the partition information + or internal metadata found in the data source, e.g. Parquet statistics. + Otherwise filters the loaded RecordBatches before yielding them. + fragment_scan_options (`pyarrow.dataset.ParquetFragmentScanOptions`, *optional*) + Scan-specific options for Parquet fragments. + This is especially useful to configure buffering and caching. + + + on_bad_files (`Literal["error", "warn", "skip"]`, *optional*, defaults to "error") + Specify what to do upon encountering a bad file (a file that can't be read). Allowed values are : + * 'error', raise an Exception when a bad file is encountered. + * 'warn', raise a warning when a bad file is encountered and skip that file. + * 'skip', skip bad files without raising or warning when they are encountered. + + + + Example: + + Load a subset of columns: + + ```python + >>> ds = load_dataset(parquet_dataset_id, columns=["col_0", "col_1"]) + ``` + + Stream data and efficiently filter data, possibly skipping entire files or row groups: + + ```python + >>> filters = [("col_0", "==", 0)] + >>> ds = load_dataset(parquet_dataset_id, streaming=True, filters=filters) + ``` + + Increase the minimum request size when streaming from 32MiB (default) to 128MiB and enable prefetching: + + ```python + >>> import pyarrow + >>> import pyarrow.dataset + >>> fragment_scan_options = pyarrow.dataset.ParquetFragmentScanOptions( + ... cache_options=pyarrow.CacheOptions( + ... prefetch_limit=1, + ... range_size_limit=128 << 20 + ... ), + ... ) + >>> ds = load_dataset(parquet_dataset_id, streaming=True, fragment_scan_options=fragment_scan_options) + ``` + + """ batch_size: Optional[int] = None columns: Optional[list[str]] = None features: Optional[datasets.Features] = None filters: Optional[Union[ds.Expression, list[tuple], list[list[tuple]]]] = None + fragment_scan_options: Optional[ds.ParquetFragmentScanOptions] = None + on_bad_files: Literal["error", "warn", "skip"] = "error" def __post_init__(self): super().__post_init__() @@ -56,9 +117,22 @@ def _split_generators(self, dl_manager): # Infer features if they are stored in the arrow schema if self.info.features is None: for file in itertools.chain.from_iterable(files): - with open(file, "rb") as f: - self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f)) - break + try: + with open(file, "rb") as f: + self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f)) + break + except pa.ArrowInvalid as e: + if self.config.on_bad_files == "error": + logger.error(f"Failed to read schema from '{file}' with error {type(e).__name__}: {e}") + raise + elif self.config.on_bad_files == "warn": + logger.warning(f"Skipping bad schema from '{file}'. {type(e).__name__}: {e}`") + else: + logger.debug(f"Skipping bad schema from '{file}'. 
{type(e).__name__}: {e}`") + if self.info.features is None: + raise ValueError( + f"At least one valid data file must be specified, all the data_files are invalid: {self.config.data_files}" + ) splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) if self.config.columns is not None and set(self.config.columns) != set(self.info.features): self.info.features = datasets.Features( @@ -84,12 +158,13 @@ def _generate_tables(self, files): if isinstance(self.config.filters, list) else self.config.filters ) + parquet_file_format = ds.ParquetFileFormat(default_fragment_scan_options=self.config.fragment_scan_options) for file_idx, file in enumerate(itertools.chain.from_iterable(files)): - with open(file, "rb") as f: - parquet_fragment = ds.ParquetFileFormat().make_fragment(f) - if parquet_fragment.row_groups: - batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows - try: + try: + with open(file, "rb") as f: + parquet_fragment = parquet_file_format.make_fragment(f) + if parquet_fragment.row_groups: + batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows for batch_idx, record_batch in enumerate( parquet_fragment.to_batches( batch_size=batch_size, @@ -104,6 +179,11 @@ def _generate_tables(self, files): # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table) - except ValueError as e: - logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") - raise + except (pa.ArrowInvalid, ValueError) as e: + if self.config.on_bad_files == "error": + logger.error(f"Failed to read file '{file}' with error {type(e).__name__}: {e}") + raise + elif self.config.on_bad_files == "warn": + logger.warning(f"Skipping bad file '{file}'. {type(e).__name__}: {e}`") + else: + logger.debug(f"Skipping bad file '{file}'. 
{type(e).__name__}: {e}`") diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index fad95f7edf5..f3a4baba681 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -69,9 +69,7 @@ def save(self, obj, save_persistent_id=True): obj = getattr(obj, "_torchdynamo_orig_callable", obj) dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) - def _batch_setitems(self, items): - if self._legacy_no_dict_keys_sorting: - return super()._batch_setitems(items) + def _batch_setitems(self, items, *args, **kwargs): # Ignore the order of keys in a dict try: # Faster, but fails for unorderable elements @@ -80,7 +78,7 @@ def _batch_setitems(self, items): from datasets.fingerprint import Hasher items = sorted(items, key=lambda x: Hasher.hash(x[0])) - dill.Pickler._batch_setitems(self, items) + return super()._batch_setitems(items, *args, **kwargs) def memoize(self, obj): # Don't memoize strings since two identical strings can have different Python ids diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 7a07f8cd267..b57a3784547 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -61,6 +61,16 @@ class _AiohttpClientError(Exception): T = TypeVar("T", str, Path) +CONNECTION_ERRORS_TO_RETRY = ( + _AiohttpClientError, + asyncio.TimeoutError, + requests.exceptions.ConnectionError, + requests.exceptions.Timeout, + httpx.RequestError, +) +SERVER_UNAVAILABLE_CODE = 504 +RATE_LIMIT_CODE = 429 + def is_remote_url(url_or_filename: str) -> bool: return urlparse(url_or_filename).scheme != "" and not os.path.ismount(urlparse(url_or_filename).scheme + ":/") @@ -813,18 +823,28 @@ def read_with_retries(*args, **kwargs): try: out = read(*args, **kwargs) break - except ( - _AiohttpClientError, - asyncio.TimeoutError, - requests.exceptions.ConnectionError, - requests.exceptions.Timeout, - httpx.RequestError, - ) as err: + except CONNECTION_ERRORS_TO_RETRY as err: disconnect_err = err logger.warning( f"Got disconnected from remote data host. Retrying in {config.STREAMING_READ_RETRY_INTERVAL}sec [{retry}/{max_retries}]" ) time.sleep(config.STREAMING_READ_RETRY_INTERVAL) + except huggingface_hub.errors.HfHubHTTPError as err: + if err.response is not None and err.response.status_code == SERVER_UNAVAILABLE_CODE: + disconnect_err = err + logger.warning( + f"Got disconnected from remote data host. Retrying in {config.STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL}sec [{retry}/{max_retries}]" + ) + time.sleep(config.STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL) + elif err.response is not None and err.response.status_code == RATE_LIMIT_CODE: + disconnect_err = err + logger.warning(str(err)) + logger.warning( + f"Got disconnected from remote data host. 
Retrying in {config.STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL}sec [{retry}/{max_retries}]" + ) + time.sleep(config.STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL) + else: + raise else: raise ConnectionError("Server Disconnected") from disconnect_err return out @@ -895,8 +915,8 @@ def _prepare_single_hop_path_and_storage_options( storage_options["headers"] = {"Accept-Encoding": "identity", **headers} elif protocol == "hf": storage_options = { - "token": token, "endpoint": config.HF_ENDPOINT, + "token": token, **storage_options, } if storage_options: @@ -930,23 +950,37 @@ def xopen(file: str, mode="r", *args, download_config: Optional[DownloadConfig] # add headers and cookies for authentication on the HF Hub and for Google Drive file, storage_options = _prepare_path_and_storage_options(file_str, download_config=download_config) kwargs = {**kwargs, **(storage_options or {})} - try: - file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open() - except ValueError as e: - if str(e) == "Cannot seek streaming HTTP file": - raise NonStreamableDatasetError( - "Streaming is not possible for this dataset because data host server doesn't support HTTP range " - "requests. You can still load this dataset in non-streaming mode by passing `streaming=False` (default)" - ) from e - else: - raise - except FileNotFoundError: - if file.startswith(config.HF_ENDPOINT): - raise FileNotFoundError( - file + "\nIf the repo is private or gated, make sure to log in with `huggingface-cli login`." - ) from None - else: - raise + + max_retries = config.STREAMING_OPEN_MAX_RETRIES + + disconnect_err = None + for retry in range(1, max_retries + 1): + try: + file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open() + break + except CONNECTION_ERRORS_TO_RETRY as err: + disconnect_err = err + logger.warning( + f"Failed to connect to remote data host. Retrying in {config.STREAMING_OPEN_RETRY_INTERVAL}sec [{retry}/{max_retries}]" + ) + time.sleep(config.STREAMING_OPEN_RETRY_INTERVAL) + except ValueError as e: + if str(e) == "Cannot seek streaming HTTP file": + raise NonStreamableDatasetError( + "Streaming is not possible for this dataset because data host server doesn't support HTTP range " + "requests. You can still load this dataset in non-streaming mode by passing `streaming=False` (default)" + ) from e + else: + raise + except FileNotFoundError: + if file.startswith(config.HF_ENDPOINT): + raise FileNotFoundError( + file + "\nIf the repo is private or gated, make sure to log in with `huggingface-cli login`." 
+ ) from None + else: + raise + else: + raise ConnectionError("Server Disconnected") from disconnect_err file_obj = _add_retries_to_file_obj_read_method(file_obj) return file_obj diff --git a/src/datasets/utils/patching.py b/src/datasets/utils/patching.py index f245cabd970..69563f562e4 100644 --- a/src/datasets/utils/patching.py +++ b/src/datasets/utils/patching.py @@ -28,7 +28,7 @@ class patch_submodule: >>> from datasets.load import dataset_module_factory >>> from datasets.streaming import patch_submodule, xjoin >>> - >>> dataset_module = dataset_module_factory("snli") + >>> dataset_module = dataset_module_factory("stanfordnlp/snli") >>> snli_module = importlib.import_module(dataset_module.module_path) >>> patcher = patch_submodule(snli_module, "os.path.join", xjoin) >>> patcher.start() diff --git a/templates/README_guide.md b/templates/README_guide.md index 8be42708543..d8e7173c84f 100644 --- a/templates/README_guide.md +++ b/templates/README_guide.md @@ -163,7 +163,7 @@ Also describe in this section if the proposed dataset contains a low-resource or Provide descriptions of specific biases that are likely to be reflected in the data, and state whether any steps were taken to reduce their impact. -For Wikipedia text, see for example [Dinan et al 2020 on biases in Wikipedia (esp. Table 1)](https://arxiv.org/abs/2005.00614), or [Blodgett et al 2020](https://www.aclweb.org/anthology/2020.acl-main.485/) for a more general discussion of the topic. +For Wikipedia text, see for example [Dinan et al 2020 on biases in Wikipedia (esp. Table 1)](https://huggingface.co/papers/2005.00614), or [Blodgett et al 2020](https://www.aclweb.org/anthology/2020.acl-main.485/) for a more general discussion of the topic. If analyses have been run quantifying these biases, please add brief summaries and links to the studies here. 
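
For the retry additions in `read_with_retries` and `xopen` above: the retry counts and intervals are read from module-level settings in `datasets.config` (`STREAMING_OPEN_MAX_RETRIES`, `STREAMING_OPEN_RETRY_INTERVAL`, `STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL`, `STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL`), whose definitions are not part of this excerpt. A hedged sketch of how a caller might tune them before streaming, assuming they are plain module attributes like the existing `STREAMING_READ_*` settings:

```python
from datasets import config, load_dataset

# Assumption: these constants are added to datasets/config.py elsewhere in this PR;
# only their usages appear in the diff above.
config.STREAMING_OPEN_MAX_RETRIES = 5                         # retries around fsspec.open in xopen
config.STREAMING_OPEN_RETRY_INTERVAL = 10                     # seconds between open attempts
config.STREAMING_READ_SERVER_UNAVAILABLE_RETRY_INTERVAL = 10  # back off after HTTP 504
config.STREAMING_READ_RATE_LIMIT_RETRY_INTERVAL = 60          # back off after HTTP 429

ds = load_dataset("fancyzhx/amazon_polarity", split="train", streaming=True)
```
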
diff --git a/tests/features/data/test_nifti.nii b/tests/features/data/test_nifti.nii new file mode 100644 index 00000000000..c1d560658c6 Binary files /dev/null and b/tests/features/data/test_nifti.nii differ diff --git a/tests/features/data/test_nifti.nii.gz b/tests/features/data/test_nifti.nii.gz new file mode 100644 index 00000000000..d5683901665 Binary files /dev/null and b/tests/features/data/test_nifti.nii.gz differ diff --git a/tests/features/test_audio.py b/tests/features/test_audio.py index aa5b2fcda94..a6dbca799fe 100644 --- a/tests/features/test_audio.py +++ b/tests/features/test_audio.py @@ -713,6 +713,7 @@ def test_dataset_with_audio_feature_loaded_from_cache(): assert isinstance(ds, Dataset) +@require_torchcodec def test_dataset_with_audio_feature_undecoded(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") data = {"audio": [audio_path]} @@ -730,6 +731,7 @@ def test_dataset_with_audio_feature_undecoded(shared_datadir): assert column[0] == {"path": audio_path, "bytes": None} +@require_torchcodec def test_formatted_dataset_with_audio_feature_undecoded(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") data = {"audio": [audio_path]} @@ -761,6 +763,7 @@ def test_formatted_dataset_with_audio_feature_undecoded(shared_datadir): assert column[0] == {"path": audio_path, "bytes": None} +@require_torchcodec def test_dataset_with_audio_feature_map_undecoded(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") data = {"audio": [audio_path]} @@ -786,3 +789,31 @@ def test_audio_embed_storage(shared_datadir): embedded_storage = Audio().embed_storage(storage) embedded_example = embedded_storage.to_pylist()[0] assert embedded_example == {"bytes": open(audio_path, "rb").read(), "path": "test_audio_44100.wav"} + + +@require_torchcodec +def test_audio_decode_example_opus_convert_to_stereo(shared_datadir): + # GH 7837 + from torchcodec.decoders import AudioDecoder + + audio_path = str(shared_datadir / "test_audio_48000.opus") # mono file + audio = Audio(num_channels=2) + decoded_example = audio.decode_example(audio.encode_example(audio_path)) + assert isinstance(decoded_example, AudioDecoder) + samples = decoded_example.get_all_samples() + assert samples.sample_rate == 48000 + assert samples.data.shape == (2, 48000) + + +@require_torchcodec +def test_audio_decode_example_opus_convert_to_mono(shared_datadir): + # GH 7837 + from torchcodec.decoders import AudioDecoder + + audio_path = str(shared_datadir / "test_audio_44100.wav") # stereo file + audio = Audio(num_channels=1) + decoded_example = audio.decode_example(audio.encode_example(audio_path)) + assert isinstance(decoded_example, AudioDecoder) + samples = decoded_example.get_all_samples() + assert samples.sample_rate == 44100 + assert samples.data.shape == (1, 202311) diff --git a/tests/features/test_image.py b/tests/features/test_image.py index 68e6f4b91cc..136b7ee9f6b 100644 --- a/tests/features/test_image.py +++ b/tests/features/test_image.py @@ -320,6 +320,18 @@ def test_dataset_cast_to_image_features(shared_datadir, build_data): assert isinstance(item["image"], PIL.Image.Image) +def test_dataset_cast_to_image_features_polars(shared_datadir): + import PIL.Image + + pl = pytest.importorskip("polars") + image_path = str(shared_datadir / "test_image_rgb.jpg") + df = pl.DataFrame({"image_path": [image_path]}) + dataset = Dataset.from_polars(df) + item = dataset.cast_column("image_path", Image())[0] + assert item.keys() == {"image_path"} + assert isinstance(item["image_path"], 
PIL.Image.Image) + + @require_pil def test_dataset_concatenate_image_features(shared_datadir): # we use a different data structure between 1 and 2 to make sure they are compatible with each other diff --git a/tests/features/test_nifti.py b/tests/features/test_nifti.py new file mode 100644 index 00000000000..527a5083c3e --- /dev/null +++ b/tests/features/test_nifti.py @@ -0,0 +1,149 @@ +## taken from: https://github.com/yarikoptic/nitest-balls1/blob/2cd07d86e2cc2d3c612d5d4d659daccd7a58f126/NIFTI/T1.nii.gz + +from pathlib import Path + +import pyarrow as pa +import pytest + +from datasets import Dataset, Features, Nifti, load_dataset +from src.datasets.features.nifti import encode_nibabel_image + +from ..utils import require_nibabel + + +@require_nibabel +@pytest.mark.parametrize("nifti_file", ["test_nifti.nii", "test_nifti.nii.gz"]) +@pytest.mark.parametrize( + "build_example", + [ + lambda nifti_path: nifti_path, + lambda nifti_path: Path(nifti_path), + lambda nifti_path: open(nifti_path, "rb").read(), + lambda nifti_path: {"path": nifti_path}, + lambda nifti_path: {"path": nifti_path, "bytes": None}, + lambda nifti_path: {"path": nifti_path, "bytes": open(nifti_path, "rb").read()}, + lambda nifti_path: {"path": None, "bytes": open(nifti_path, "rb").read()}, + lambda nifti_path: {"bytes": open(nifti_path, "rb").read()}, + ], +) +def test_nifti_feature_encode_example(shared_datadir, nifti_file, build_example): + import nibabel + + nifti_path = str(shared_datadir / nifti_file) + nifti = Nifti() + encoded_example = nifti.encode_example(build_example(nifti_path)) + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["bytes"] is not None or encoded_example["path"] is not None + decoded_example = nifti.decode_example(encoded_example) + assert isinstance(decoded_example, nibabel.nifti1.Nifti1Image) + + +@require_nibabel +@pytest.mark.parametrize("nifti_file", ["test_nifti.nii", "test_nifti.nii.gz"]) +def test_dataset_with_nifti_feature(shared_datadir, nifti_file): + import nibabel + + nifti_path = str(shared_datadir / nifti_file) + data = {"nifti": [nifti_path]} + features = Features({"nifti": Nifti()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"nifti"} + assert isinstance(item["nifti"], nibabel.nifti1.Nifti1Image) + batch = dset[:1] + assert len(batch) == 1 + assert batch.keys() == {"nifti"} + assert isinstance(batch["nifti"], list) and all( + isinstance(item, nibabel.nifti1.Nifti1Image) for item in batch["nifti"] + ) + column = dset["nifti"] + assert len(column) == 1 + assert all(isinstance(item, nibabel.nifti1.Nifti1Image) for item in column) + + # from bytes + with open(nifti_path, "rb") as f: + data = {"nifti": [f.read()]} + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"nifti"} + assert isinstance(item["nifti"], nibabel.nifti1.Nifti1Image) + + +@require_nibabel +def test_encode_nibabel_image(shared_datadir): + import nibabel + + nifti_path = str(shared_datadir / "test_nifti.nii") + img = nibabel.load(nifti_path) + encoded_example = encode_nibabel_image(img) + nifti = Nifti() + assert isinstance(encoded_example, dict) + assert encoded_example.keys() == {"bytes", "path"} + assert encoded_example["path"] is not None and encoded_example["bytes"] is None + decoded_example = nifti.decode_example(encoded_example) + assert isinstance(decoded_example, nibabel.nifti1.Nifti1Image) + + # test bytes only + img.file_map = 
None + encoded_example_bytes = encode_nibabel_image(img) + assert isinstance(encoded_example_bytes, dict) + assert encoded_example_bytes["bytes"] is not None and encoded_example_bytes["path"] is None + # this cannot be converted back from bytes (yet) + + +@require_nibabel +def test_embed_storage(shared_datadir): + from io import BytesIO + + import nibabel as nib + + nifti_path = str(shared_datadir / "test_nifti.nii") + img = nib.load(nifti_path) + nifti = Nifti() + + bytes_array = pa.array([None], type=pa.binary()) + path_array = pa.array([nifti_path], type=pa.string()) + storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"]) + + embedded_storage = nifti.embed_storage(storage) + + embedded_bytes = embedded_storage[0]["bytes"].as_py() + + bio = BytesIO(embedded_bytes) + fh = nib.FileHolder(fileobj=bio) + nifti_img = nib.Nifti1Image.from_file_map({"header": fh, "image": fh}) + + assert embedded_bytes is not None + assert nifti_img.header == img.header + assert (nifti_img.affine == img.affine).all() + assert (nifti_img.get_fdata() == img.get_fdata()).all() + + +@require_nibabel +def test_load_zipped_file_locally(shared_datadir): + import nibabel as nib + + nifti_path = str(shared_datadir / "test_nifti.nii.gz") + + ds = load_dataset("niftifolder", data_files=nifti_path) + assert isinstance(ds["train"][0]["nifti"], nib.nifti1.Nifti1Image) + + +@require_nibabel +def test_nifti_lazy_loading(shared_datadir): + import nibabel as nib + import numpy as np + + nifti_path = str(shared_datadir / "test_nifti.nii.gz") + nifti = Nifti() + encoded_example = nifti.encode_example(nifti_path) + decoded_example = nifti.decode_example(encoded_example) + + # Verify that the data object is an ArrayProxy (lazy) and not a numpy array (dense) + assert nib.is_proxy(decoded_example.dataobj) + assert not isinstance(decoded_example.dataobj, np.ndarray) + + # Verify that we can still access the data + data = decoded_example.get_fdata() + assert data.shape == (80, 80, 10) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4661e8c6dd7..8e76952d6ca 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -4114,6 +4114,16 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path): _check_generator_dataset(dataset, expected_features, expected_split) +@pytest.mark.parametrize("fingerprint", [None, "test-dataset"]) +def test_dataset_from_generator_fingerprint(fingerprint, data_generator, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, fingerprint=fingerprint) + _check_generator_dataset(dataset, expected_features, NamedSplit("train")) + if fingerprint: + assert dataset._fingerprint == fingerprint + + @require_not_windows @require_dill_gt_0_3_2 @require_pyspark diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py index 08eb77366c1..457bd9de49b 100644 --- a/tests/test_download_manager.py +++ b/tests/test_download_manager.py @@ -131,7 +131,7 @@ def test_download_manager_delete_extracted_files(xz_file): assert extracted_path == dl_manager.extracted_paths[xz_file] extracted_path = Path(extracted_path) parts = extracted_path.parts - # import pdb; pdb.set_trace() + assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None) assert parts[-2] == extracted_subdir assert extracted_path.exists() diff --git a/tests/test_extract.py b/tests/test_extract.py index 
186d65fd0ba..489e5efa586 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,5 +1,4 @@ import os -import zipfile import pytest @@ -199,5 +198,5 @@ def test_is_zipfile_false_positive(tmpdir): ) with not_a_zip_file.open("wb") as f: f.write(data) - assert zipfile.is_zipfile(str(not_a_zip_file)) # is a false positive for `zipfile` + # zipfile.is_zipfile(str(not_a_zip_file)) could be a false positive for `zipfile` assert not ZipExtractor.is_extractable(not_a_zip_file) # but we're right diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py index aef0dfc2a89..63f627b72cc 100644 --- a/tests/test_filesystem.py +++ b/tests/test_filesystem.py @@ -1,9 +1,7 @@ -import importlib import os import fsspec import pytest -from fsspec import register_implementation from fsspec.core import url_to_fs from fsspec.registry import _registry as _fsspec_registry @@ -44,7 +42,6 @@ def test_compression_filesystems(compression_fs_class, gz_file, bz2_file, lz4_fi reason += require_zstandard.kwargs["reason"] pytest.skip(reason) fs = fsspec.filesystem(compression_fs_class.protocol, fo=input_path) - assert isinstance(fs, compression_fs_class) expected_filename = os.path.basename(input_path) expected_filename = expected_filename[: expected_filename.rindex(".")] assert fs.glob("*") == [expected_filename] @@ -61,21 +58,3 @@ def test_fs_isfile(protocol, zip_jsonl_path, jsonl_gz_path): fs, *_ = url_to_fs(path) assert fs.isfile(member_file_path) assert not fs.isfile("non_existing_" + member_file_path) - - -def test_fs_overwrites(): - protocol = "bz2" - - # Import module - import datasets.filesystems - - # Overwrite protocol and reload - register_implementation(protocol, None, clobber=True) - with pytest.warns(UserWarning) as warning_info: - importlib.reload(datasets.filesystems) - - assert len(warning_info) == 1 - assert ( - str(warning_info[0].message) - == f"A filesystem protocol was already set for {protocol} and will be overwritten." 
- ) diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py index 0b7a45458bd..e3ca7464b16 100644 --- a/tests/test_fingerprint.py +++ b/tests/test_fingerprint.py @@ -26,6 +26,7 @@ require_spacy, require_tiktoken, require_torch, + require_torch_compile, require_transformers, ) @@ -347,7 +348,7 @@ def test_hash_spacy_model(self): self.assertNotEqual(hash1, hash2) @require_not_windows - @require_torch + @require_torch_compile def test_hash_torch_compiled_function(self): import torch @@ -360,7 +361,7 @@ def f(x): self.assertEqual(hash1, hash2) @require_not_windows - @require_torch + @require_torch_compile def test_hash_torch_compiled_module(self): m = TorchModule() next(iter(m.parameters())).data.fill_(1.0) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 1bca866bdf8..bdfa60fdc01 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -1553,6 +1553,82 @@ def test_iterable_dataset_from_hub_torch_dataloader_parallel(num_workers, tmp_pa assert len(result) == 10 +def gen_with_worker_info(shard): + from torch.utils.data import get_worker_info + + worker_info = get_worker_info() + for i in range(100): + yield {"value": i, "worker_id": worker_info.id} + + +@require_torch +def test_iterable_dataset_shuffle_with_multiple_workers_different_rng(): + # GH 7567 + from torch.utils.data import DataLoader + + num_workers = 20 + ds = IterableDataset.from_generator(gen_with_worker_info, gen_kwargs={"shard": list(range(num_workers))}) + ds = ds.shuffle(buffer_size=100, seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) + + result = list(dataloader) + for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]: + values = [item["value"] for item in single_chunk] + # This will fail with the chance 1/100 ** 20! 
+ assert len(set(values)) != 1, "Make sure not all values are identical" + + +def gen_with_value(shard, value): + for i in range(100): + yield {"value": value} + + +@require_torch +def test_iterable_dataset_interleave_dataset_with_multiple_workers(): + # GH 7567 + from torch.utils.data import DataLoader + + num_workers = 20 + ds = [ + IterableDataset.from_generator(gen_with_value, gen_kwargs={"shard": list(range(num_workers)), "value": i}) + for i in range(10) + ] + ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) + + result = list(dataloader) + for single_chunk in [result[x : x + num_workers] for x in range(0, len(result), num_workers)]: + values = [item["value"] for item in single_chunk] + assert len(set(values)) != 1, "Make sure not all values are identical" + + +def gen_with_id(shard, value): + for i in range(50): + yield {"value": value, "id": i} + + +@require_torch +def test_iterable_dataset_interleave_dataset_deterministic_across_iterations(): + # GH 7567 + from torch.utils.data import DataLoader + + num_workers = 10 + ds = [ + IterableDataset.from_generator(gen_with_id, gen_kwargs={"shard": list(range(num_workers)), "value": i}) + for i in range(5) + ] + ds = interleave_datasets(ds, probabilities=[1 / len(ds)] * len(ds), seed=1234) + dataloader = DataLoader(ds, batch_size=None, num_workers=num_workers) + + # First iteration + first_result = list(dataloader) + + # Second iteration + second_result = list(dataloader) + + assert first_result == second_result, "Results should be identical across iterations when using same seed" + + @pytest.mark.parametrize("batch_size", [4, 5]) @pytest.mark.parametrize("drop_last_batch", [False, True]) def test_iterable_dataset_iter_batch(batch_size, drop_last_batch): diff --git a/tests/test_metadata_util.py b/tests/test_metadata_util.py index b6b45e1812f..cf9111fa6d9 100644 --- a/tests/test_metadata_util.py +++ b/tests/test_metadata_util.py @@ -282,7 +282,7 @@ def test_split_order_in_metadata_configs_from_exported_parquet_files_and_dataset "dataset": "AI-Lab-Makerere/beans", "config": "default", "split": "test", - "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/test/0000.parquet", + "url": "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/refs%2Fconvert%2Fparquet/default/test/0000.parquet", "filename": "0000.parquet", "size": 17707203, }, @@ -290,7 +290,7 @@ def test_split_order_in_metadata_configs_from_exported_parquet_files_and_dataset "dataset": "AI-Lab-Makerere/beans", "config": "default", "split": "train", - "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet", + "url": "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/refs%2Fconvert%2Fparquet/default/train/0000.parquet", "filename": "0000.parquet", "size": 143780164, }, @@ -298,7 +298,7 @@ def test_split_order_in_metadata_configs_from_exported_parquet_files_and_dataset "dataset": "AI-Lab-Makerere/beans", "config": "default", "split": "validation", - "url": "https://huggingface.co/datasets/beans/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet", + "url": "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/refs%2Fconvert%2Fparquet/default/validation/0000.parquet", "filename": "0000.parquet", "size": 18500862, }, @@ -332,15 +332,15 @@ def test_split_order_in_metadata_configs_from_exported_parquet_files_and_dataset }, }, download_checksums={ - 
"https://huggingface.co/datasets/beans/resolve/main/data/train.zip": { + "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/train.zip": { "num_bytes": 143812152, "checksum": None, }, - "https://huggingface.co/datasets/beans/resolve/main/data/validation.zip": { + "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/validation.zip": { "num_bytes": 18504213, "checksum": None, }, - "https://huggingface.co/datasets/beans/resolve/main/data/test.zip": { + "https://huggingface.co/datasets/AI-Lab-Makerere/beans/resolve/main/data/test.zip": { "num_bytes": 17708541, "checksum": None, }, diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py index d3e7795bf9d..aad95f74a59 100644 --- a/tests/test_py_utils.py +++ b/tests/test_py_utils.py @@ -1,4 +1,5 @@ import os +import pickle import time from dataclasses import dataclass from multiprocessing import Pool @@ -81,7 +82,7 @@ def test_map_nested(self): {k: v.tolist() for k, v in map_nested(int, sn1, map_numpy=True, num_proc=num_proc).items()}, {k: v.tolist() for k, v in expected_map_nested_sn1_int.items()}, ) - with self.assertRaises(AttributeError): # can't pickle a local lambda + with self.assertRaises((AttributeError, pickle.PicklingError)): # can't pickle a local lambda map_nested(lambda x: x + 1, sn1, num_proc=num_proc) def test_zip_dict(self): diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index d569637fdad..1fc53502ba6 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -1,5 +1,6 @@ import json import os +from pathlib import Path import pytest @@ -26,10 +27,16 @@ Bulbasaur, grass""" -@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"]) -def test_streaming_dl_manager_download_dummy_path(urlpath): +def test_streaming_dl_manager_download_dummy_path(): + path = str(Path(__file__).resolve()) dl_manager = StreamingDownloadManager() - assert dl_manager.download(urlpath) == urlpath + assert dl_manager.download(path) == path + + +def test_streaming_dl_manager_download_dummy_url(): + url = "https://f.oo/bar.txt" + dl_manager = StreamingDownloadManager() + assert dl_manager.download(url) == url @pytest.mark.parametrize( @@ -54,10 +61,16 @@ def test_streaming_dl_manager_download(text_path): assert f.read() == expected_file.read() -@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"]) -def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath): +def test_streaming_dl_manager_download_and_extract_no_extraction_dummy_path(): + path = str(Path(__file__).resolve()) + dl_manager = StreamingDownloadManager() + assert dl_manager.download_and_extract(path) == path + + +def test_streaming_dl_manager_download_and_extract_no_extraction_dummy_url(): + url = "https://f.oo/bar.txt" dl_manager = StreamingDownloadManager() - assert dl_manager.download_and_extract(urlpath) == urlpath + assert dl_manager.download_and_extract(url) == url def test_streaming_dl_manager_extract(text_gz_path, text_path): diff --git a/tests/utils.py b/tests/utils.py index 166bd4789c2..1980cf3e257 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -125,6 +125,20 @@ def require_torch(test_case): return test_case +def require_torch_compile(test_case): + """ + Decorator marking a test that requires PyTorch. + + These tests are skipped when PyTorch isn't installed. 
+ + """ + if not config.TORCH_AVAILABLE: + test_case = unittest.skip("test requires PyTorch")(test_case) + if config.PY_VERSION >= version.parse("3.14"): + test_case = unittest.skip("test requires torch compile which isn't available in python 3.14")(test_case) + return test_case + + def require_polars(test_case): """ Decorator marking a test that requires Polars. @@ -209,6 +223,18 @@ def require_pdfplumber(test_case): return test_case +def require_nibabel(test_case): + """ + Decorator marking a test that requires nibabel. + + These tests are skipped when nibabel isn't installed. + + """ + if not config.NIBABEL_AVAILABLE: + test_case = unittest.skip("test requires nibabel")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.
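
End-to-end, the new `Nifti` feature, the `niftifolder` packaged module, and the `require_nibabel` test decorator fit together roughly as follows. A minimal sketch based on the docstrings and tests in this diff (it assumes `nibabel` is installed; the file paths are placeholders):

```python
import nibabel as nib

from datasets import Dataset, Nifti, load_dataset

# No-code loading of .nii / .nii.gz files via the new packaged module
ds = load_dataset("niftifolder", data_files="path/to/scan.nii.gz", split="train")
assert isinstance(ds[0]["nifti"], nib.nifti1.Nifti1Image)

# Or cast an existing column of file paths to the new feature
ds = Dataset.from_dict({"nifti": ["path/to/scan.nii.gz"]}).cast_column("nifti", Nifti())
volume = ds[0]["nifti"]      # decoded lazily by nibabel (dataobj stays an ArrayProxy)
voxels = volume.get_fdata()  # materializes the voxel data as a numpy array
```

As with `Image` and `Pdf`, the Arrow storage is `pa.struct({"bytes": pa.binary(), "path": pa.string()})`, so `Nifti(decode=False)` returns the raw `{"path", "bytes"}` dict instead of a nibabel object.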