diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml index 6e8971b5c48..bb336a32019 100644 --- a/.github/workflows/benchmarks.yaml +++ b/.github/workflows/benchmarks.yaml @@ -16,8 +16,8 @@ jobs: pip install setuptools wheel pip install -e .[benchmarks] - # pyarrow==1.0.0 - pip install pyarrow==1.0.0 + # pyarrow==3.0.0 + pip install pyarrow==3.0.0 dvc repro --force @@ -26,7 +26,7 @@ jobs: python ./benchmarks/format.py report.json report.md - echo "
\nShow benchmarks\n\nPyArrow==1.0.0\n" > final_report.md + echo "
\nShow benchmarks\n\nPyArrow==3.0.0\n" > final_report.md cat report.md >> final_report.md # pyarrow diff --git a/datasets/parquet/dummy/0.0.0/dummy_data.zip b/datasets/parquet/dummy/0.0.0/dummy_data.zip new file mode 100644 index 00000000000..152bb8a720a Binary files /dev/null and b/datasets/parquet/dummy/0.0.0/dummy_data.zip differ diff --git a/docs/source/package_reference/main_classes.rst b/docs/source/package_reference/main_classes.rst index c3e34ee0395..3b7d9a493b9 100644 --- a/docs/source/package_reference/main_classes.rst +++ b/docs/source/package_reference/main_classes.rst @@ -25,16 +25,15 @@ The base class :class:`datasets.Dataset` implements a Dataset backed by an Apach __getitem__, cleanup_cache_files, map, filter, select, sort, shuffle, train_test_split, shard, export, save_to_disk, load_from_disk, flatten_indices, - to_csv, to_pandas, to_dict, + to_csv, to_pandas, to_dict, to_json, to_parquet, add_faiss_index, add_faiss_index_from_external_arrays, save_faiss_index, load_faiss_index, add_elasticsearch_index, load_elasticsearch_index, list_indexes, get_index, drop_index, search, search_batch, get_nearest_examples, get_nearest_examples_batch, info, split, builder_name, citation, config_name, dataset_size, description, download_checksums, download_size, features, homepage, license, size_in_bytes, supervised_keys, version, - from_csv, from_json, from_text, + from_csv, from_json, from_parquet, from_text, prepare_for_task, align_labels_with_mapping, - to_json, .. autofunction:: datasets.concatenate_datasets @@ -56,7 +55,7 @@ It also has dataset transform methods like map or filter, to process all the spl flatten_, cast_, remove_columns_, rename_column_, flatten, cast, remove_columns, rename_column, class_encode_column, save_to_disk, load_from_disk, - from_csv, from_json, from_text, + from_csv, from_json, from_parquet, from_text, prepare_for_task, align_labels_with_mapping diff --git a/setup.py b/setup.py index 22412aaa119..e0a013d343c 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ # We use numpy>=1.17 to have np.random.Generator (Dataset shuffling) "numpy>=1.17", # Backend and serialization. - # Minimum 1.0.0 to avoid permission errors on windows when using the compute layer on memory mapped data + # Minimum 3.0.0 to support mix of struct and list types in parquet, and batch iterators of parquet data # pyarrow 4.0.0 introduced segfault bug, see: https://github.com/huggingface/datasets/pull/2268 "pyarrow>=1.0.0,!=4.0.0", # For smart caching dataset processing diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 24d8aadfb60..9dbf9387e6b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -468,7 +468,7 @@ def from_csv( path_or_paths (path-like or list of path-like): Path(s) of the CSV file(s). split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. features (:class:`Features`, optional): Dataset features. - cache_dir (:obj:`str`, optional, default ``"~/datasets"``): Directory to cache data. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. **kwargs: Keyword arguments to be passed to :meth:`pandas.read_csv`. @@ -498,7 +498,7 @@ def from_json( path_or_paths (path-like or list of path-like): Path(s) of the JSON or JSON Lines file(s). split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. 
features (:class:`Features`, optional): Dataset features. - cache_dir (:obj:`str`, optional, default ``"~/datasets"``): Directory to cache data. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. field (:obj:`str`, optional): Field name of the JSON file where the dataset is contained in. **kwargs: Keyword arguments to be passed to :class:`JsonConfig`. @@ -519,6 +519,45 @@ def from_json( **kwargs, ).read() + @staticmethod + def from_parquet( + path_or_paths: Union[PathLike, List[PathLike]], + split: Optional[NamedSplit] = None, + features: Optional[Features] = None, + cache_dir: str = None, + keep_in_memory: bool = False, + columns: Optional[List[str]] = None, + **kwargs, + ): + """Create Dataset from Parquet file(s). + + Args: + path_or_paths (path-like or list of path-like): Path(s) of the Parquet file(s). + split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. + features (:class:`Features`, optional): Dataset features. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. + keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. + columns (:obj:`List[str]`, optional): If not None, only these columns will be read from the file. + A column name may be a prefix of a nested field, e.g. 'a' will select + 'a.b', 'a.c', and 'a.d.e'. + **kwargs: Keyword arguments to be passed to :class:`ParquetConfig`. + + Returns: + :class:`Dataset` + """ + # Dynamic import to avoid circular dependency + from .io.parquet import ParquetDatasetReader + + return ParquetDatasetReader( + path_or_paths, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + columns=columns, + **kwargs, + ).read() + @staticmethod def from_text( path_or_paths: Union[PathLike, List[PathLike]], @@ -534,7 +573,7 @@ def from_text( path_or_paths (path-like or list of path-like): Path(s) of the text file(s). split (:class:`NamedSplit`, optional): Split name to be assigned to the dataset. features (:class:`Features`, optional): Dataset features. - cache_dir (:obj:`str`, optional, default ``"~/datasets"``): Directory to cache data. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. **kwargs: Keyword arguments to be passed to :class:`TextConfig`. @@ -2862,6 +2901,28 @@ def to_pandas( for offset in range(0, len(self), batch_size) ) + def to_parquet( + self, + path_or_buf: Union[PathLike, BinaryIO], + batch_size: Optional[int] = None, + **parquet_writer_kwargs, + ) -> int: + """Exports the dataset to parquet + + Args: + path_or_buf (``PathLike`` or ``FileOrBuffer``): Either a path to a file or a BinaryIO. + batch_size (Optional ``int``): Size of the batch to load in memory and write at once. + Defaults to :obj:`datasets.config.DEFAULT_MAX_BATCH_SIZE`. 
+            parquet_writer_kwargs: Parameters to pass to PyArrow's :class:`pyarrow.parquet.ParquetWriter`
+
+        Returns:
+            int: The number of bytes written
+        """
+        # Dynamic import to avoid circular dependency
+        from .io.parquet import ParquetDatasetWriter
+
+        return ParquetDatasetWriter(self, path_or_buf, batch_size=batch_size, **parquet_writer_kwargs).write()
+
 @transmit_format
 @fingerprint_transform(inplace=False)
 def add_column(self, name: str, column: Union[list, np.array], new_fingerprint: str):
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index 1d80ddb3166..0249e1713c2 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -748,7 +748,7 @@ def from_csv(
         Args:
             path_or_paths (dict of path-like): Path(s) of the CSV file(s).
             features (:class:`Features`, optional): Dataset features.
-            cache_dir (str, optional, default="~/datasets"): Directory to cache data.
+            cache_dir (str, optional, default="~/.cache/huggingface/datasets"): Directory to cache data.
             keep_in_memory (bool, default=False): Whether to copy the data in-memory.
             **kwargs: Keyword arguments to be passed to :meth:`pandas.read_csv`.

@@ -775,7 +775,7 @@ def from_json(
         Args:
             path_or_paths (path-like or list of path-like): Path(s) of the JSON Lines file(s).
             features (:class:`Features`, optional): Dataset features.
-            cache_dir (str, optional, default="~/datasets"): Directory to cache data.
+            cache_dir (str, optional, default="~/.cache/huggingface/datasets"): Directory to cache data.
             keep_in_memory (bool, default=False): Whether to copy the data in-memory.
             **kwargs: Keyword arguments to be passed to :class:`JsonConfig`.

@@ -789,6 +789,42 @@ def from_json(
             path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()

+    @staticmethod
+    def from_parquet(
+        path_or_paths: Dict[str, PathLike],
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        columns: Optional[List[str]] = None,
+        **kwargs,
+    ) -> "DatasetDict":
+        """Create DatasetDict from Parquet file(s).
+
+        Args:
+            path_or_paths (dict of path-like): Path(s) of the Parquet file(s).
+            features (:class:`Features`, optional): Dataset features.
+            cache_dir (str, optional, default="~/.cache/huggingface/datasets"): Directory to cache data.
+            keep_in_memory (bool, default=False): Whether to copy the data in-memory.
+            columns (:obj:`List[str]`, optional): If not None, only these columns will be read from the file.
+                A column name may be a prefix of a nested field, e.g. 'a' will select
+                'a.b', 'a.c', and 'a.d.e'.
+            **kwargs: Keyword arguments to be passed to :class:`ParquetConfig`.
+
+        Returns:
+            :class:`DatasetDict`
+        """
+        # Dynamic import to avoid circular dependency
+        from .io.parquet import ParquetDatasetReader
+
+        return ParquetDatasetReader(
+            path_or_paths,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            columns=columns,
+            **kwargs,
+        ).read()
+
     @staticmethod
     def from_text(
         path_or_paths: Dict[str, PathLike],
@@ -802,7 +838,7 @@ def from_text(
         Args:
             path_or_paths (dict of path-like): Path(s) of the text file(s).
             features (:class:`Features`, optional): Dataset features.
-            cache_dir (str, optional, default="~/datasets"): Directory to cache data.
+            cache_dir (str, optional, default="~/.cache/huggingface/datasets"): Directory to cache data.
             keep_in_memory (bool, default=False): Whether to copy the data in-memory.
             **kwargs: Keyword arguments to be passed to :class:`TextConfig`.
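For reviewers, a quick sketch of how the API added above in `arrow_dataset.py` and `dataset_dict.py` is meant to be called. The Parquet file names are placeholders, and PyArrow >= 3.0.0 is assumed:

    from datasets import Dataset, DatasetDict

    # Read a single Parquet file into a Dataset, optionally restricting the columns.
    # "my_data.parquet" is a placeholder path, not a file shipped with this PR.
    ds = Dataset.from_parquet("my_data.parquet", split="train", columns=["col_1", "col_2"])

    # Write it back out; batch_size controls how many rows are loaded in memory and written at once.
    ds.to_parquet("out.parquet", batch_size=1000)

    # DatasetDict variant: one Parquet file per split.
    dsets = DatasetDict.from_parquet({"train": "train.parquet", "test": "test.parquet"})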
diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py
new file mode 100644
index 00000000000..dabc801dce9
--- /dev/null
+++ b/src/datasets/io/parquet.py
@@ -0,0 +1,111 @@
+import os
+from typing import BinaryIO, Optional, Union
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from packaging import version
+
+from .. import Dataset, Features, NamedSplit, config
+from ..formatting import query_table
+from ..packaged_modules import _PACKAGED_DATASETS_MODULES
+from ..packaged_modules.parquet.parquet import Parquet
+from ..utils.typing import NestedDataStructureLike, PathLike
+from .abc import AbstractDatasetReader
+
+
+class ParquetDatasetReader(AbstractDatasetReader):
+    def __init__(
+        self,
+        path_or_paths: NestedDataStructureLike[PathLike],
+        split: Optional[NamedSplit] = None,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        if version.parse(pa.__version__) < version.parse("3.0.0"):
+            raise ImportError(
+                "PyArrow >= 3.0.0 is required to use the ParquetDatasetReader: pip install --upgrade pyarrow"
+            )
+        super().__init__(
+            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+        )
+        path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
+        hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
+        self.builder = Parquet(
+            cache_dir=cache_dir,
+            data_files=path_or_paths,
+            features=features,
+            hash=hash,
+            **kwargs,
+        )
+
+    def read(self):
+        download_config = None
+        download_mode = None
+        ignore_verifications = False
+        use_auth_token = None
+        base_path = None
+
+        self.builder.download_and_prepare(
+            download_config=download_config,
+            download_mode=download_mode,
+            ignore_verifications=ignore_verifications,
+            # try_from_hf_gcs=try_from_hf_gcs,
+            base_path=base_path,
+            use_auth_token=use_auth_token,
+        )
+
+        # Build dataset for splits
+        dataset = self.builder.as_dataset(
+            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+        )
+        return dataset
+
+
+class ParquetDatasetWriter:
+    def __init__(
+        self,
+        dataset: Dataset,
+        path_or_buf: Union[PathLike, BinaryIO],
+        batch_size: Optional[int] = None,
+        **parquet_writer_kwargs,
+    ):
+        if version.parse(pa.__version__) < version.parse("3.0.0"):
+            raise ImportError(
+                "PyArrow >= 3.0.0 is required to use the ParquetDatasetWriter: pip install --upgrade pyarrow"
+            )
+        self.dataset = dataset
+        self.path_or_buf = path_or_buf
+        self.batch_size = batch_size
+        self.parquet_writer_kwargs = parquet_writer_kwargs
+
+    def write(self) -> int:
+        batch_size = self.batch_size if self.batch_size else config.DEFAULT_MAX_BATCH_SIZE
+
+        if isinstance(self.path_or_buf, (str, bytes, os.PathLike)):
+            with open(self.path_or_buf, "wb+") as buffer:
+                written = self._write(file_obj=buffer, batch_size=batch_size, **self.parquet_writer_kwargs)
+        else:
+            written = self._write(file_obj=self.path_or_buf, batch_size=batch_size, **self.parquet_writer_kwargs)
+        return written
+
+    def _write(self, file_obj: BinaryIO, batch_size: int, **parquet_writer_kwargs) -> int:
+        """Writes the pyarrow table as Parquet to a binary file handle.
+
+        Caller is responsible for opening and closing the handle.
+        """
+        written = 0
+        _ = parquet_writer_kwargs.pop("path_or_buf", None)
+        schema = pa.schema(self.dataset.features.type)
+        writer = pq.ParquetWriter(file_obj, schema=schema, **parquet_writer_kwargs)
+
+        for offset in range(0, len(self.dataset), batch_size):
+            batch = query_table(
+                table=self.dataset._data,
+                key=slice(offset, offset + batch_size),
+                indices=self.dataset._indices if self.dataset._indices is not None else None,
+            )
+            writer.write_table(batch)
+            written += batch.nbytes
+        return written
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index 4befdea6016..afd6518daa2 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -6,6 +6,7 @@
 from .csv import csv
 from .json import json
 from .pandas import pandas
+from .parquet import parquet
 from .text import text


@@ -30,5 +31,6 @@ def hash_python_lines(lines: List[str]) -> str:
     "csv": (csv.__name__, hash_python_lines(inspect.getsource(csv).splitlines())),
     "json": (json.__name__, hash_python_lines(inspect.getsource(json).splitlines())),
     "pandas": (pandas.__name__, hash_python_lines(inspect.getsource(pandas).splitlines())),
+    "parquet": (parquet.__name__, hash_python_lines(inspect.getsource(parquet).splitlines())),
     "text": (text.__name__, hash_python_lines(inspect.getsource(text).splitlines())),
 }
diff --git a/src/datasets/packaged_modules/parquet/__init__.py b/src/datasets/packaged_modules/parquet/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py
new file mode 100644
index 00000000000..59311176282
--- /dev/null
+++ b/src/datasets/packaged_modules/parquet/parquet.py
@@ -0,0 +1,75 @@
+# coding=utf-8
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import pyarrow as pa
+import pyarrow.parquet as pq
+from packaging import version
+
+import datasets
+
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+@dataclass
+class ParquetConfig(datasets.BuilderConfig):
+    """BuilderConfig for Parquet."""
+
+    batch_size: int = 10_000
+    columns: Optional[List[str]] = None
+    features: Optional[datasets.Features] = None
+
+
+class Parquet(datasets.ArrowBasedBuilder):
+    BUILDER_CONFIG_CLASS = ParquetConfig
+
+    def _info(self):
+        if version.parse(pa.__version__) < version.parse("3.0.0"):
+            raise ImportError(
+                "PyArrow >= 3.0.0 is required to use the Parquet dataset builder: pip install --upgrade pyarrow"
+            )
+        return datasets.DatasetInfo(features=self.config.features)
+
+    def _split_generators(self, dl_manager):
+        """We handle string, list and dicts in datafiles"""
+        if not self.config.data_files:
+            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+        data_files = dl_manager.download_and_extract(self.config.data_files)
+        if isinstance(data_files, (str, list, tuple)):
+            files = data_files
+            if isinstance(files, str):
+                files = [files]
+            return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": files})]
+        splits = []
+        for split_name, files in data_files.items():
+            if isinstance(files, str):
+                files = [files]
+            splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
+        return splits
+
+    def _generate_tables(self, files):
+        schema = pa.schema(self.config.features.type) if self.config.features is not None else None
+        if self.config.features is not None and self.config.columns is not None:
+            if 
sorted([field.name for field in schema]) != sorted(self.config.columns): + raise ValueError( + f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.config.features}'" + ) + for file_idx, file in enumerate(files): + with open(file, "rb") as f: + parquet_file = pq.ParquetFile(f) + try: + for batch_idx, record_batch in enumerate( + parquet_file.iter_batches(batch_size=self.config.batch_size, columns=self.config.columns) + ): + pa_table = pa.Table.from_batches([record_batch]) + if self.config.features is not None: + pa_table = pa.Table.from_arrays([pa_table[field.name] for field in schema], schema=schema) + # Uncomment for debugging (will print the Arrow table size and elements) + # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield f"{file_idx}_{batch_idx}", pa_table + except ValueError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise diff --git a/tests/conftest.py b/tests/conftest.py index f930d986a07..7e4a860d046 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import lzma import textwrap +import pyarrow as pa +import pyarrow.parquet as pq import pytest from datasets.arrow_dataset import Dataset @@ -164,6 +166,24 @@ def csv_path(tmp_path_factory): return path +@pytest.fixture(scope="session") +def parquet_path(tmp_path_factory): + path = str(tmp_path_factory.mktemp("data") / "dataset.parquet") + schema = pa.schema( + { + "col_1": pa.string(), + "col_2": pa.int64(), + "col_3": pa.float64(), + } + ) + with open(path, "wb") as f: + writer = pq.ParquetWriter(f, schema=schema) + pa_table = pa.Table.from_pydict({k: [DATA[i][k] for i in range(len(DATA))] for k in DATA[0]}, schema=schema) + writer.write_table(pa_table) + writer.close() + return path + + @pytest.fixture(scope="session") def json_list_of_dicts_path(tmp_path_factory): path = str(tmp_path_factory.mktemp("data") / "dataset.json") diff --git a/tests/io/test_json.py b/tests/io/test_json.py index 5a54c787c95..9c7473b0abd 100644 --- a/tests/io/test_json.py +++ b/tests/io/test_json.py @@ -141,7 +141,6 @@ def test_datasetdict_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_pa ) def test_datasetdict_from_json_features(features, jsonl_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py new file mode 100644 index 00000000000..1b284378780 --- /dev/null +++ b/tests/io/test_parquet.py @@ -0,0 +1,140 @@ +import pyarrow.parquet as pq +import pytest + +from datasets import Dataset, DatasetDict, Features, NamedSplit, Value +from datasets.io.parquet import ParquetDatasetReader, ParquetDatasetWriter + +from ..utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_pyarrow_at_least_3 + + +def _check_parquet_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_pyarrow_at_least_3 
+@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = ParquetDatasetReader(parquet_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory).read() + _check_parquet_dataset(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_parquet_features(features, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = ParquetDatasetReader(parquet_path, features=features, cache_dir=cache_dir).read() + _check_parquet_dataset(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_dataset_from_parquet_split(split, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = ParquetDatasetReader(parquet_path, cache_dir=cache_dir, split=split).read() + _check_parquet_dataset(dataset, expected_features) + assert dataset.split == str(split) if split else "train" + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("path_type", [str, list]) +def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path): + if issubclass(path_type, str): + path = parquet_path + elif issubclass(path_type, list): + path = [parquet_path] + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = ParquetDatasetReader(path, cache_dir=cache_dir).read() + _check_parquet_dataset(dataset, expected_features) + + +def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",)): + assert isinstance(dataset_dict, DatasetDict) + for split in splits: + dataset = dataset_dict[split] + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_parquet_datasetdict_reader_keep_in_memory(keep_in_memory, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = ParquetDatasetReader( + {"train": parquet_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory + ).read() + _check_parquet_datasetdict(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": 
"int64", "col_3": "float64"}, + {"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_parquet_datasetdict_reader_features(features, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = ParquetDatasetReader({"train": parquet_path}, features=features, cache_dir=cache_dir).read() + _check_parquet_datasetdict(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_parquet_datasetdict_reader_split(split, parquet_path, tmp_path): + if split: + path = {split: parquet_path} + else: + split = "train" + path = {"train": parquet_path, "test": parquet_path} + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = ParquetDatasetReader(path, cache_dir=cache_dir).read() + _check_parquet_datasetdict(dataset, expected_features, splits=list(path.keys())) + assert all(dataset[split].split == split for split in path.keys()) + + +@require_pyarrow_at_least_3 +def test_parquer_write(dataset, tmp_path): + writer = ParquetDatasetWriter(dataset, tmp_path / "foo.parquet") + assert writer.write() > 0 + pf = pq.ParquetFile(tmp_path / "foo.parquet") + output_table = pf.read() + assert dataset.data.table == output_table diff --git a/tests/io/test_text.py b/tests/io/test_text.py index 428399cc2b6..a93e3cff95f 100644 --- a/tests/io/test_text.py +++ b/tests/io/test_text.py @@ -35,7 +35,6 @@ def test_dataset_from_text_keep_in_memory(keep_in_memory, text_path, tmp_path): ) def test_dataset_from_text_features(features, text_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"text": "string"} expected_features = features.copy() if features else default_expected_features features = ( diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 23f917e8430..738640f4659 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -27,6 +27,7 @@ from .utils import ( assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, + require_pyarrow_at_least_3, require_s3, require_tf, require_torch, @@ -1692,6 +1693,59 @@ def test_to_pandas(self, in_memory): for col_name in dset.column_names: self.assertEqual(len(dset_to_pandas[col_name]), dset.num_rows) + @require_pyarrow_at_least_3 + def test_to_parquet(self, in_memory): + with tempfile.TemporaryDirectory() as tmp_dir: + # File path argument + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.parquet") + dset.to_parquet(path_or_buf=file_path) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) # because of compression, the number of bytes doesn't match + parquet_dset = pd.read_parquet(file_path) + + self.assertEqual(parquet_dset.shape, dset.shape) + self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + + # File buffer argument + with 
self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + file_path = os.path.join(tmp_dir, "test_buffer.parquet") + with open(file_path, "wb+") as buffer: + dset.to_parquet(path_or_buf=buffer) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) # because of compression, the number of bytes doesn't match + parquet_dset = pd.read_parquet(file_path) + + self.assertEqual(parquet_dset.shape, dset.shape) + self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + + # After a select/shuffle transform + with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset: + dset = dset.select(range(0, len(dset), 2)).shuffle() + file_path = os.path.join(tmp_dir, "test_path.parquet") + dset.to_parquet(path_or_buf=file_path) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) # because of compression, the number of bytes doesn't match + parquet_dset = pd.read_parquet(file_path) + + self.assertEqual(parquet_dset.shape, dset.shape) + self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + + # With array features + with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset: + file_path = os.path.join(tmp_dir, "test_path.parquet") + dset.to_parquet(path_or_buf=file_path) + + self.assertTrue(os.path.isfile(file_path)) + # self.assertEqual(bytes_written, os.path.getsize(file_path)) # because of compression, the number of bytes doesn't match + parquet_dset = pd.read_parquet(file_path) + + self.assertEqual(parquet_dset.shape, dset.shape) + self.assertListEqual(list(parquet_dset.columns), list(dset.column_names)) + def test_train_test_split(self, in_memory): with tempfile.TemporaryDirectory() as tmp_dir: with self._create_dummy_dataset(in_memory, tmp_dir) as dset: @@ -2586,7 +2640,6 @@ def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path): ) def test_dataset_from_json_features(features, jsonl_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( @@ -2617,6 +2670,70 @@ def test_dataset_from_json_path_type(path_type, jsonl_path, tmp_path): _check_json_dataset(dataset, expected_features) +def _check_parquet_dataset(dataset, expected_features): + assert isinstance(dataset, Dataset) + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_dataset_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, keep_in_memory=keep_in_memory) + _check_parquet_dataset(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + {"col_1": 
"string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_dataset_from_parquet_features(features, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = Dataset.from_parquet(parquet_path, features=features, cache_dir=cache_dir) + _check_parquet_dataset(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_dataset_from_parquet_split(split, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = Dataset.from_parquet(parquet_path, cache_dir=cache_dir, split=split) + _check_parquet_dataset(dataset, expected_features) + assert dataset.split == str(split) if split else "train" + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("path_type", [str, list]) +def test_dataset_from_parquet_path_type(path_type, parquet_path, tmp_path): + if issubclass(path_type, str): + path = parquet_path + elif issubclass(path_type, list): + path = [parquet_path] + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = Dataset.from_parquet(path, cache_dir=cache_dir) + _check_parquet_dataset(dataset, expected_features) + + def _check_text_dataset(dataset, expected_features): assert isinstance(dataset, Dataset) assert dataset.num_rows == 4 @@ -2646,7 +2763,6 @@ def test_dataset_from_text_keep_in_memory(keep_in_memory, text_path, tmp_path): ) def test_dataset_from_text_features(features, text_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"text": "string"} expected_features = features.copy() if features else default_expected_features features = ( diff --git a/tests/test_dataset_cards.py b/tests/test_dataset_cards.py index 0ed1009022b..be23cce3b57 100644 --- a/tests/test_dataset_cards.py +++ b/tests/test_dataset_cards.py @@ -19,6 +19,7 @@ import pytest +from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES from datasets.utils.logging import get_logger from datasets.utils.metadata import DatasetMetadata from datasets.utils.readme import ReadMe @@ -43,11 +44,12 @@ def get_changed_datasets(repo_path: Path) -> List[Path]: if f.exists() and str(f.resolve()).startswith(str(datasets_dir_path)) ) - return sorted(changed_datasets) + return sorted(dataset_name for dataset_name in changed_datasets if dataset_name not in _PACKAGED_DATASETS_MODULES) def get_all_datasets(repo_path: Path) -> List[Path]: - return [path.parts[-1] for path in (repo_path / "datasets").iterdir()] + dataset_names = [path.parts[-1] for path in (repo_path / "datasets").iterdir()] + return [dataset_name for dataset_name in dataset_names if dataset_name not in _PACKAGED_DATASETS_MODULES] @pytest.mark.parametrize("dataset_name", get_changed_datasets(repo_path)) diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py index bc49e9a8bfc..49d30c49a49 100644 --- a/tests/test_dataset_common.py +++ 
b/tests/test_dataset_common.py @@ -22,7 +22,9 @@ from typing import List, Optional from unittest import TestCase +import pyarrow as pa from absl.testing import parameterized +from packaging import version from datasets import cached_path, import_main_class, load_dataset, prepare_module from datasets.builder import BuilderConfig, DatasetBuilder @@ -74,7 +76,7 @@ def wrapper(self, dataset_name): def get_packaged_dataset_dummy_data_files(dataset_name, path_to_dummy_data): - extensions = {"text": "txt", "json": "json", "pandas": "pkl", "csv": "csv"} + extensions = {"text": "txt", "json": "json", "pandas": "pkl", "csv": "csv", "parquet": "parquet"} return { "train": os.path.join(path_to_dummy_data, "train." + extensions[dataset_name]), "test": os.path.join(path_to_dummy_data, "test." + extensions[dataset_name]), @@ -270,7 +272,10 @@ def test_load_real_dataset_all_configs(self, dataset_name): def get_packaged_dataset_names(): - return [{"testcase_name": x, "dataset_name": x} for x in _PACKAGED_DATASETS_MODULES.keys()] + packaged_datasets = [{"testcase_name": x, "dataset_name": x} for x in _PACKAGED_DATASETS_MODULES.keys()] + if version.parse(pa.__version__) < version.parse("3.0.0"): # parquet is not supported for pyarrow<3.0.0 + packaged_datasets = [pd for pd in packaged_datasets if pd["dataset_name"] != "parquet"] + return packaged_datasets @parameterized.named_parameters(get_packaged_dataset_names()) diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py index a61c8bbc41c..50b5463c979 100644 --- a/tests/test_dataset_dict.py +++ b/tests/test_dataset_dict.py @@ -16,6 +16,7 @@ from .utils import ( assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, + require_pyarrow_at_least_3, require_s3, require_tf, require_torch, @@ -563,7 +564,6 @@ def test_datasetdict_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_pa ) def test_datasetdict_from_json_features(features, jsonl_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_features = features.copy() if features else default_expected_features features = ( @@ -587,6 +587,64 @@ def test_datasetdict_from_json_splits(split, jsonl_path, tmp_path): assert all(dataset[split].split == split for split in path.keys()) +def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",)): + assert isinstance(dataset_dict, DatasetDict) + for split in splits: + dataset = dataset_dict[split] + assert dataset.num_rows == 4 + assert dataset.num_columns == 3 + assert dataset.column_names == ["col_1", "col_2", "col_3"] + for feature, expected_dtype in expected_features.items(): + assert dataset.features[feature].dtype == expected_dtype + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("keep_in_memory", [False, True]) +def test_datasetdict_from_parquet_keep_in_memory(keep_in_memory, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): + dataset = DatasetDict.from_parquet({"train": parquet_path}, cache_dir=cache_dir, keep_in_memory=keep_in_memory) + _check_parquet_datasetdict(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize( + "features", + [ + None, + {"col_1": "string", "col_2": "int64", "col_3": "float64"}, + 
{"col_1": "string", "col_2": "string", "col_3": "string"}, + {"col_1": "int32", "col_2": "int32", "col_3": "int32"}, + {"col_1": "float32", "col_2": "float32", "col_3": "float32"}, + ], +) +def test_datasetdict_from_parquet_features(features, parquet_path, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_features = features.copy() if features else default_expected_features + features = ( + Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None + ) + dataset = DatasetDict.from_parquet({"train": parquet_path}, features=features, cache_dir=cache_dir) + _check_parquet_datasetdict(dataset, expected_features) + + +@require_pyarrow_at_least_3 +@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"]) +def test_datasetdict_from_parquet_split(split, parquet_path, tmp_path): + if split: + path = {split: parquet_path} + else: + split = "train" + path = {"train": parquet_path, "test": parquet_path} + cache_dir = tmp_path / "cache" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + dataset = DatasetDict.from_parquet(path, cache_dir=cache_dir) + _check_parquet_datasetdict(dataset, expected_features, splits=list(path.keys())) + assert all(dataset[split].split == split for split in path.keys()) + + def _check_text_datasetdict(dataset_dict, expected_features, splits=("train",)): assert isinstance(dataset_dict, DatasetDict) for split in splits: @@ -618,7 +676,6 @@ def test_datasetdict_from_text_keep_in_memory(keep_in_memory, text_path, tmp_pat ) def test_datasetdict_from_text_features(features, text_path, tmp_path): cache_dir = tmp_path / "cache" - # CSV file loses col_1 string dtype information: default now is "int64" instead of "string" default_expected_features = {"text": "string"} expected_features = features.copy() if features else default_expected_features features = ( diff --git a/tests/utils.py b/tests/utils.py index 5f6206d31e4..872c969374c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,6 +8,7 @@ from unittest.mock import patch import pyarrow as pa +from packaging import version from datasets import config @@ -34,6 +35,19 @@ def parse_flag_from_env(key, default=False): _run_packaged_tests = parse_flag_from_env("RUN_PACKAGED", default=True) +def require_pyarrow_at_least_3(test_case): + """ + Decorator marking a test that requires PyArrow 3.0.0 + to allow nested types in parquet, as well as batch iterators of parquet files. + + These tests are skipped when the PyArrow version is outdated. + + """ + if version.parse(config.PYARROW_VERSION) < version.parse("3.0.0"): + test_case = unittest.skip("test requires PyArrow>=3.0.0")(test_case) + return test_case + + def require_beam(test_case): """ Decorator marking a test that requires Apache Beam.