diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 048dcafcd36..f2da377568f 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -1,6 +1,8 @@ # Cloud storage -🤗 Datasets supports access to cloud storage providers through a S3 filesystem implementation: [`filesystems.S3FileSystem`]. You can save and load datasets from your Amazon S3 bucket in a Pythonic way. Take a look at the following table for other supported cloud storage providers: +🤗 Datasets supports access to cloud storage providers through `fsspec` FileSystem implementations. +You can save and load datasets from any cloud storage in a Pythonic way. +Take a look at the following table for some examples of supported cloud storage providers: | Storage provider | Filesystem implementation | |----------------------|---------------------------------------------------------------| @@ -10,11 +12,12 @@ | Dropbox | [dropboxdrivefs](https://github.com/MarineChap/dropboxdrivefs)| | Google Drive | [gdrivefs](https://github.com/intake/gdrivefs) | -This guide will show you how to save and load datasets with **s3fs** to a S3 bucket, but other filesystem implementations can be used similarly. An example is shown also for Google Cloud Storage and Azure Blob Storage. +This guide will show you how to save and load datasets with any cloud storage. +Here are examples for S3, Google Cloud Storage and Azure Blob Storage. -## Amazon S3 +## Set up your cloud storage FileSystem -### Listing datasets +### Amazon S3 1. Install the S3 dependency with 🤗 Datasets: @@ -22,163 +25,178 @@ This guide will show you how to save and load datasets with **s3fs** to a S3 buc >>> pip install datasets[s3] ``` -2. List files from a public S3 bucket with `s3.ls`: +2. Define your credentials + +To use an anonymous connection, use `anon=True`. +Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket. 
```py ->>> import datasets ->>> s3 = datasets.filesystems.S3FileSystem(anon=True) ->>> s3.ls('public-datasets/imdb/train') -['dataset_info.json.json','dataset.arrow','state.json'] +>>> storage_options = {"anon": True} # for anonymous connection +# or use your credentials +>>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} # for private buckets +# or use a botocore session +>>> import botocore +>>> s3_session = botocore.session.Session(profile="my_profile_name") +>>> storage_options = {"session": s3_session} ``` -Access a private S3 bucket by entering your `aws_access_key_id` and `aws_secret_access_key`: +3. Create your FileSystem instance ```py ->>> import datasets ->>> s3 = datasets.filesystems.S3FileSystem(key=aws_access_key_id, secret=aws_secret_access_key) ->>> s3.ls('my-private-datasets/imdb/train') -['dataset_info.json.json','dataset.arrow','state.json'] +>>> import s3fs +>>> fs = s3fs.S3FileSystem(**storage_options) ``` -### Saving datasets +### Google Cloud Storage -After you have processed your dataset, you can save it to S3 with [`Dataset.save_to_disk`]: +1. Install the Google Cloud Storage implementation: + +``` +>>> conda install -c conda-forge gcsfs +# or install with pip +>>> pip install gcsfs +``` + +2. Define your credentials ```py ->>> from datasets.filesystems import S3FileSystem +>>> storage_options={"token": "anon"} # for anonymous connection +# or use your credentials of your default gcloud credentials or from the google metadata service +>>> storage_options={"project": "my-google-project"} +# or use your credentials from elsewhere, see the documentation at https://gcsfs.readthedocs.io/ +>>> storage_options={"project": "my-google-project", "token": TOKEN} +``` -# create S3FileSystem instance ->>> s3 = S3FileSystem(anon=True) +3. 
Create your FileSystem instance -# saves encoded_dataset to your s3 bucket ->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train', fs=s3) +```py +>>> import gcsfs +>>> fs = gcsfs.GCSFileSystem(**storage_options) ``` - +### Azure Blob Storage -Remember to include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket. +1. Install the Azure Blob Storage implementation: - +``` +>>> conda install -c conda-forge adlfs +# or install with pip +>>> pip install adlfs +``` -Save your dataset with `botocore.session.Session` and a custom AWS profile: +2. Define your credentials ```py ->>> import botocore ->>> from datasets.filesystems import S3FileSystem - -# creates a botocore session with the provided AWS profile ->>> s3_session = botocore.session.Session(profile='my_profile_name') +>>> storage_options = {"anon": True} # for anonymous connection +# or use your credentials +>>> storage_options = {"account_name": ACCOUNT_NAME, "account_key": ACCOUNT_KEY} # gen 2 filesystem +# or use your credentials with the gen 1 filesystem +>>> storage_options={"tenant_id": TENANT_ID, "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET} +``` -# create S3FileSystem instance with s3_session ->>> s3 = S3FileSystem(session=s3_session) +3. 
Create your FileSystem instance -# saves encoded_dataset to your s3 bucket ->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train',fs=s3) +```py +>>> import adlfs +>>> fs = adlfs.AzureBlobFileSystem(**storage_options) ``` -### Loading datasets +## Load and Save your datasets using your cloud storage FileSystem -When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]: +### Download and prepare a dataset into a cloud storage -```py ->>> from datasets import load_from_disk ->>> from datasets.filesystems import S3FileSystem +You can download and prepare a dataset into your cloud storage by specifying a remote `output_dir` in `download_and_prepare`. +Don't forget to use the previously defined `storage_options` containing your credentials to write into a private cloud storage. -# create S3FileSystem without credentials ->>> s3 = S3FileSystem(anon=True) +The `download_and_prepare` method works in two steps: +1. it first downloads the raw data files (if any) in your local cache. You can set your cache directory by passing `cache_dir` to [`load_dataset_builder`] +2. then it generates the dataset in Arrow or Parquet format in your cloud storage by iterating over the raw data files. 
-# load encoded_dataset to from s3 bucket ->>> dataset = load_from_disk('s3://a-public-datasets/imdb/train',fs=s3) +Load a dataset builder from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)): ->>> print(len(dataset)) ->>> # 25000 +```py +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("imdb") +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` -Load with `botocore.session.Session` and custom AWS profile: +Load a dataset builder using a loading script (see [how to load a local loading script](./loading#local-loading-script)): ```py ->>> import botocore ->>> from datasets.filesystems import S3FileSystem - -# create S3FileSystem instance with aws_access_key_id and aws_secret_access_key ->>> s3_session = botocore.session.Session(profile='my_profile_name') - -# create S3FileSystem instance with s3_session ->>> s3 = S3FileSystem(session=s3_session) +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("path/to/local/loading_script/loading_script.py") +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") +``` -# load encoded_dataset to from s3 bucket ->>> dataset = load_from_disk('s3://my-private-datasets/imdb/train',fs=s3) +Use your own data files (see [how to load local and remote files](./loading#local-and-remote-files)): ->>> print(len(dataset)) ->>> # 25000 +```py +>>> data_files = {"train": ["path/to/train.csv"]} +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("csv", data_files=data_files) +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` -## Google Cloud Storage +It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. +Otherwise the dataset is saved as an uncompressed Arrow file. -1. 
Install the Google Cloud Storage implementation: +#### Dask -``` ->>> conda install -c conda-forge gcsfs -# or install with pip ->>> pip install gcsfs -``` +Dask is a parallel computing library and it has a pandas-like API for working with larger than memory Parquet datasets in parallel. +Dask can use multiple threads or processes on a single machine, or a cluster of machines to process data in parallel. +Dask supports local data but also data from a cloud storage. -2. Save your dataset: +Therefore you can load a dataset saved as sharded Parquet files in Dask with ```py ->>> import gcsfs +import dask.dataframe as dd -# create GCSFileSystem instance using default gcloud credentials with project ->>> gcs = gcsfs.GCSFileSystem(project='my-google-project') +df = dd.read_parquet(output_dir, storage_options=storage_options) -# saves encoded_dataset to your gcs bucket ->>> encoded_dataset.save_to_disk('gcs://my-private-datasets/imdb/train', fs=gcs) +# or if your dataset is split into train/valid/test +df_train = dd.read_parquet(output_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options) +df_valid = dd.read_parquet(output_dir + f"/{builder.name}-validation-*.parquet", storage_options=storage_options) +df_test = dd.read_parquet(output_dir + f"/{builder.name}-test-*.parquet", storage_options=storage_options) ``` -3. Load your dataset: +You can find more about dask dataframes in their [documentation](https://docs.dask.org/en/stable/dataframe.html). 
-```py ->>> import gcsfs ->>> from datasets import load_from_disk +## Saving serialized datasets -# create GCSFileSystem instance using default gcloud credentials with project ->>> gcs = gcsfs.GCSFileSystem(project='my-google-project') +After you have processed your dataset, you can save it to your cloud storage with [`Dataset.save_to_disk`]: -# loads encoded_dataset from your gcs bucket ->>> dataset = load_from_disk('gcs://my-private-datasets/imdb/train', fs=gcs) +```py +# saves encoded_dataset to amazon s3 +>>> encoded_dataset.save_to_disk("s3://my-private-datasets/imdb/train", fs=fs) +# saves encoded_dataset to google cloud storage +>>> encoded_dataset.save_to_disk("gcs://my-private-datasets/imdb/train", fs=fs) +# saves encoded_dataset to microsoft azure blob/datalake +>>> encoded_dataset.save_to_disk("adl://my-private-datasets/imdb/train", fs=fs) ``` -## Azure Blob Storage - -1. Install the Azure Blob Storage implementation: + -``` ->>> conda install -c conda-forge adlfs -# or install with pip ->>> pip install adlfs -``` +Remember to define your credentials in your [FileSystem instance](#set-up-your-cloud-storage-filesystem) `fs` whenever you are interacting with a private cloud storage. -2. Save your dataset: + -```py ->>> import adlfs +## Listing serialized datasets -# create AzureBlobFileSystem instance with account_name and account_key ->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX") +List files from a cloud storage with your FileSystem instance `fs`, using `fs.ls`: -# saves encoded_dataset to your azure container ->>> encoded_dataset.save_to_disk('abfs://my-private-datasets/imdb/train', fs=abfs) +```py +>>> fs.ls("my-private-datasets/imdb/train") +["dataset_info.json","dataset.arrow","state.json"] ``` -3. 
Load your dataset: +### Load serialized datasets + +When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]: ```py ->>> import adlfs >>> from datasets import load_from_disk - -# create AzureBlobFileSystem instance with account_name and account_key ->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX") - -# loads encoded_dataset from your azure container ->>> dataset = load_from_disk('abfs://my-private-datasets/imdb/train', fs=abfs) +# load encoded_dataset from cloud storage +>>> dataset = load_from_disk("s3://a-public-datasets/imdb/train", fs=fs) +>>> print(len(dataset)) +25000 ``` diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bc2c5ab37ba..54b5b7726d7 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1134,7 +1134,7 @@ def save_to_disk(self, dataset_path: str, fs=None): fs.makedirs(dataset_path, exist_ok=True) with fs.open(Path(dataset_path, config.DATASET_ARROW_FILENAME).as_posix(), "wb") as dataset_file: with ArrowWriter(stream=dataset_file) as writer: - writer.write_table(dataset._data) + writer.write_table(dataset._data.table) writer.finalize() with fs.open( Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8" diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 6ffdb0f715d..0230b840faf 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -18,8 +18,10 @@ import sys from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import fsspec import numpy as np import pyarrow as pa +import pyarrow.parquet as pq from . 
import config from .features import Features, Image, Value @@ -33,6 +35,7 @@ numpy_to_pyarrow_listarray, to_pyarrow_listarray, ) +from .filesystems import is_remote_filesystem from .info import DatasetInfo from .keyhash import DuplicatedKeysError, KeyHasher from .table import array_cast, cast_array_to_feature, table_cast @@ -268,6 +271,8 @@ def __init__( class ArrowWriter: """Shuffles and writes Examples to Arrow files.""" + _WRITER_CLASS = pa.RecordBatchStreamWriter + def __init__( self, schema: Optional[pa.Schema] = None, @@ -282,6 +287,7 @@ def __init__( update_features: bool = False, with_metadata: bool = True, unit: str = "examples", + storage_options: Optional[dict] = None, ): if path is None and stream is None: raise ValueError("At least one of path and stream must be provided.") @@ -304,11 +310,19 @@ def __init__( self._check_duplicates = check_duplicates self._disable_nullable = disable_nullable - self._path = path if stream is None: - self.stream = pa.OSFile(self._path, "wb") + fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options) + self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + self._path = ( + fs_token_paths[2][0] + if not is_remote_filesystem(self._fs) + else self._fs.unstrip_protocol(fs_token_paths[2][0]) + ) + self.stream = self._fs.open(fs_token_paths[2][0], "wb") self._closable_stream = True else: + self._fs = None + self._path = None self.stream = stream self._closable_stream = False @@ -367,7 +381,7 @@ def _build_writer(self, inferred_schema: pa.Schema): if self.with_metadata: schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint)) self._schema = schema - self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema) + self.pa_writer = self._WRITER_CLASS(self.stream, schema) @property def schema(self): @@ -522,11 +536,9 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non if self.pa_writer is None: 
self._build_writer(inferred_schema=pa_table.schema) pa_table = table_cast(pa_table, self._schema) - batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size) - self._num_bytes += sum(batch.nbytes for batch in batches) + self._num_bytes += pa_table.nbytes self._num_examples += pa_table.num_rows - for batch in batches: - self.pa_writer.write_batch(batch) + self.pa_writer.write_table(pa_table, writer_batch_size) def finalize(self, close_stream=True): self.write_rows_on_file() @@ -542,6 +554,7 @@ def finalize(self, close_stream=True): else: raise ValueError("Please pass `features` or at least one example when writing data") self.pa_writer.close() + self.pa_writer = None if close_stream: self.stream.close() logger.debug( @@ -550,6 +563,10 @@ def finalize(self, close_stream=True): return self._num_examples, self._num_bytes +class ParquetWriter(ArrowWriter): + _WRITER_CLASS = pq.ParquetWriter + + class BeamWriter: """ Shuffles and writes Examples to Arrow files. @@ -616,35 +633,50 @@ def finalize(self, metrics_query_result: dict): from .utils import beam_utils # Convert to arrow - logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") - shards = [ - metadata.path - for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list - ] - try: # stream conversion - sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] - with beam.io.filesystems.FileSystems.create(self._path) as dest: - parquet_to_arrow(sources, dest) - except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead - if e.errno != errno.EPIPE: # not a broken pipe - raise - logger.warning("Broken Pipe during stream conversion from parquet to arrow. 
Using local convert instead") - local_convert_dir = os.path.join(self._cache_dir, "beam_convert") - os.makedirs(local_convert_dir, exist_ok=True) - local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") - local_shards = [] - for shard in shards: - local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") - local_shards.append(local_parquet_path) - beam_utils.download_remote_to_local(shard, local_parquet_path) - parquet_to_arrow(local_shards, local_arrow_path) - beam_utils.upload_local_to_remote(local_arrow_path, self._path) + if self._path.endswith(".arrow"): + logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") + shards = [ + metadata.path + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ + 0 + ].metadata_list + ] + try: # stream conversion + sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] + with beam.io.filesystems.FileSystems.create(self._path) as dest: + parquet_to_arrow(sources, dest) + except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead + if e.errno != errno.EPIPE: # not a broken pipe + raise + logger.warning( + "Broken Pipe during stream conversion from parquet to arrow. 
Using local convert instead" + ) + local_convert_dir = os.path.join(self._cache_dir, "beam_convert") + os.makedirs(local_convert_dir, exist_ok=True) + local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") + local_shards = [] + for shard in shards: + local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") + local_shards.append(local_parquet_path) + beam_utils.download_remote_to_local(shard, local_parquet_path) + parquet_to_arrow(local_shards, local_arrow_path) + beam_utils.upload_local_to_remote(local_arrow_path, self._path) + output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] + num_bytes = output_file_metadata.size_in_bytes + else: + num_bytes = sum( + [ + metadata.size_in_bytes + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ + 0 + ].metadata_list + ] + ) # Save metrics counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]} self._num_examples = counters_dict["num_examples"] - output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] - self._num_bytes = output_file_metadata.size_in_bytes + self._num_bytes = num_bytes return self._num_examples, self._num_bytes diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 380ed031bce..955078dff27 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -27,8 +27,11 @@ import warnings from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import Dict, Mapping, Optional, Tuple, Union +import fsspec + from . 
import config, utils from .arrow_dataset import Dataset from .arrow_reader import ( @@ -38,7 +41,7 @@ MissingFilesOnHfGcsError, ReadInstruction, ) -from .arrow_writer import ArrowWriter, BeamWriter +from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter from .data_files import DataFilesDict, sanitize_patterns from .dataset_dict import DatasetDict, IterableDatasetDict from .download.download_config import DownloadConfig @@ -46,6 +49,7 @@ from .download.mock_download_manager import MockDownloadManager from .download.streaming_download_manager import StreamingDownloadManager from .features import Features +from .filesystems import is_remote_filesystem from .fingerprint import Hasher from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper @@ -345,6 +349,10 @@ def __init__( ) os.rmdir(self._cache_dir) + # Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare + self._output_dir = self._cache_dir + self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file") + # Set download manager self.dl_manager = None @@ -536,7 +544,7 @@ def _other_versions_on_disk(): version_dirnames.sort(reverse=True) return version_dirnames - # Check and warn if other versions exist on disk + # Check and warn if other versions exist if not is_remote_url(builder_data_dir): version_dirs = _other_versions_on_disk() if version_dirs: @@ -570,6 +578,7 @@ def get_imported_module_dir(cls): def download_and_prepare( self, + output_dir: Optional[str] = None, download_config: Optional[DownloadConfig] = None, download_mode: Optional[DownloadMode] = None, ignore_verifications: bool = False, @@ -577,11 +586,17 @@ def download_and_prepare( dl_manager: Optional[DownloadManager] = None, base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, + file_format: Optional[str] = None, + storage_options: Optional[dict] = 
None, **download_and_prepare_kwargs, ): """Downloads and prepares dataset for reading. Args: + output_dir (:obj:`str`, optional): output directory for the dataset. + Default to this builder's ``cache_dir``, which is inside ~/.cache/huggingface/datasets by default. + + download_config (:class:`DownloadConfig`, optional): specific download configuration parameters. download_mode (:class:`DownloadMode`, optional): select the download/generate mode - Default to ``REUSE_DATASET_IF_EXISTS`` ignore_verifications (:obj:`bool`): Ignore the verifications of the downloaded/processed dataset information (checksums/size/splits/...) @@ -591,6 +606,13 @@ def download_and_prepare( If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead. use_auth_token (:obj:`Union[str, bool]`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. If True, will get token from ~/.huggingface. + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. + + + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. + + **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. 
Example: @@ -601,9 +623,18 @@ def download_and_prepare( >>> ds = builder.download_and_prepare() ``` """ + self._output_dir = output_dir if output_dir is not None else self._cache_dir + # output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) + fs_token_paths = fsspec.get_fs_token_paths(self._output_dir, storage_options=storage_options) + self._fs = fs_token_paths[0] + download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications base_path = base_path if base_path is not None else self.base_path + is_local = not is_remote_filesystem(self._fs) + + if file_format is not None and file_format not in ["arrow", "parquet"]: + raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") if dl_manager is None: if download_config is None: @@ -625,28 +656,35 @@ def download_and_prepare( else False, ) - elif isinstance(dl_manager, MockDownloadManager): + elif isinstance(dl_manager, MockDownloadManager) or not is_local: try_from_hf_gcs = False self.dl_manager = dl_manager - # Prevent parallel disk operations - is_local = not is_remote_url(self._cache_dir_root) + # Prevent parallel local disk operations if is_local: - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + # Create parent directory of the output_dir to put the lock file in there + Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True) + lock_path = self._output_dir + "_builder.lock" + # File locking only with local paths; no file locking on GCS or S3 with FileLock(lock_path) if is_local else contextlib.nullcontext(): - if is_local: - data_exists = os.path.exists(self._cache_dir) - if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - logger.warning(f"Reusing dataset {self.name} ({self._cache_dir})") - # We need to update the info in case some splits were added in the meantime - # for example when 
calling load_dataset from multiple workers. - self.info = self._load_info() - self.download_post_processing_resources(dl_manager) - return - logger.info(f"Generating dataset {self.name} ({self._cache_dir})") + + # Check if the data already exists + path_join = os.path.join if is_local else posixpath.join + data_exists = self._fs.exists(path_join(self._output_dir, config.DATASET_INFO_FILENAME)) + if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.warning(f"Found cached dataset {self.name} ({self._output_dir})") + # We need to update the info in case some splits were added in the meantime + # for example when calling load_dataset from multiple workers. + self.info = self._load_info() + self.download_post_processing_resources(dl_manager) + return + + logger.info(f"Generating dataset {self.name} ({self._output_dir})") if is_local: # if cache dir is local, check for available space - if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root): + if not has_sufficient_disk_space( + self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent + ): raise OSError( f"Not enough disk space. 
Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})" ) @@ -654,7 +692,8 @@ def download_and_prepare( @contextlib.contextmanager def incomplete_dir(dirname): """Create temporary dir for dirname and rename on exit.""" - if is_remote_url(dirname): + if not is_local: + self._fs.makedirs(dirname, exist_ok=True) yield dirname else: tmp_dir = dirname + ".incomplete" @@ -663,6 +702,7 @@ def incomplete_dir(dirname): yield tmp_dir if os.path.isdir(dirname): shutil.rmtree(dirname) + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory shutil.move(tmp_dir, dirname) finally: if os.path.exists(tmp_dir): @@ -676,20 +716,22 @@ def incomplete_dir(dirname): f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} " f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, " f"post-processed: {size_str(self.info.post_processing_size)}, " - f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..." + f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..." ) else: + _dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir print( - f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {self._cache_dir}..." + f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..." ) self._check_manual_download(dl_manager) - # Create a tmp dir and rename to self._cache_dir on successful exit. - with incomplete_dir(self._cache_dir) as tmp_data_dir: - # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward + # Create a tmp dir and rename to self._output_dir on successful exit. 
+ with incomplete_dir(self._output_dir) as tmp_output_dir: + # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. - with temporary_assignment(self, "_cache_dir", tmp_data_dir): + with temporary_assignment(self, "_output_dir", tmp_output_dir): + # Try to download the already prepared dataset files downloaded_from_gcs = False if try_from_hf_gcs: @@ -702,7 +744,10 @@ def incomplete_dir(dirname): logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: self._download_and_prepare( - dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs + dl_manager=dl_manager, + verify_infos=verify_infos, + file_format=file_format, + **download_and_prepare_kwargs, ) # Sync info self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) @@ -715,7 +760,7 @@ def incomplete_dir(dirname): self.download_post_processing_resources(dl_manager) print( - f"Dataset {self.name} downloaded and prepared to {self._cache_dir}. " + f"Dataset {self.name} downloaded and prepared to {self._output_dir}. " f"Subsequent calls will reuse this data." 
) @@ -734,10 +779,10 @@ def _check_manual_download(self, dl_manager): def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): relative_data_dir = self._relative_data_dir(with_version=True, with_hash=False) - reader = ArrowReader(self._cache_dir, self.info) + reader = ArrowReader(self._output_dir, self.info) # use reader instructions to download the right files reader.download_from_hf_gcs(download_config, relative_data_dir) - downloaded_info = DatasetInfo.from_directory(self._cache_dir) + downloaded_info = DatasetInfo.from_directory(self._output_dir) self.info.update(downloaded_info) # download post processing resources remote_cache_dir = HF_GCP_BASE_URL + "/" + relative_data_dir.replace(os.sep, "/") @@ -747,12 +792,12 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") try: resource_path = cached_path(remote_cache_dir + "/" + resource_file_name) - shutil.move(resource_path, os.path.join(self._cache_dir, resource_file_name)) + shutil.move(resource_path, os.path.join(self._output_dir, resource_file_name)) except ConnectionError: logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") logger.info("Dataset downloaded from Hf google storage.") - def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs): + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **prepare_split_kwargs): """Downloads and prepares dataset for reading. This is the internal implementation to overwrite called when user calls @@ -760,9 +805,10 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs the pre-processed datasets files. Args: - dl_manager: (DownloadManager) `DownloadManager` used to download and cache - data. - verify_infos: bool, if False, do not perform checksums and size tests. 
+ dl_manager: (:obj:`DownloadManager`) `DownloadManager` used to download and cache data. + verify_infos (:obj:`bool`): if False, do not perform checksums and size tests. + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. prepare_split_kwargs: Additional options. """ # Generating data for all splits @@ -790,7 +836,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs try: # Prepare split will record examples associated to the split - self._prepare_split(split_generator, **prepare_split_kwargs) + self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) except OSError as e: raise OSError( "Cannot find data file. " @@ -815,11 +861,13 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs self.info.download_size = dl_manager.downloaded_size def download_post_processing_resources(self, dl_manager): - for split in self.info.splits: + for split in self.info.splits or []: for resource_name, resource_file_name in self._post_processing_resources(split).items(): + if not not is_remote_filesystem(self._fs): + raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") - resource_path = os.path.join(self._cache_dir, resource_file_name) + resource_path = os.path.join(self._output_dir, resource_file_name) if not os.path.exists(resource_path): downloaded_resource_path = self._download_post_processing_resources( split, resource_name, dl_manager @@ -829,16 +877,20 @@ def download_post_processing_resources(self, dl_manager): shutil.move(downloaded_resource_path, resource_path) def _load_info(self) -> DatasetInfo: - return DatasetInfo.from_directory(self._cache_dir) + return DatasetInfo.from_directory(self._output_dir, fs=self._fs) def 
_save_info(self): - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): - self.info.write_to_directory(self._cache_dir) + is_local = not is_remote_filesystem(self._fs) + if is_local: + lock_path = self._output_dir + "_info.lock" + with FileLock(lock_path) if is_local else contextlib.nullcontext(): + self.info.write_to_directory(self._output_dir, fs=self._fs) def _save_infos(self): - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): + is_local = not is_remote_filesystem(self._fs) + if is_local: + lock_path = self._output_dir + "_infos.lock" + with FileLock(lock_path) if is_local else contextlib.nullcontext(): DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir()) def _make_split_generators_kwargs(self, prepare_split_kwargs): @@ -876,14 +928,17 @@ def as_dataset( }) ``` """ - if not os.path.exists(self._cache_dir): + is_local = not is_remote_filesystem(self._fs) + if not is_local: + raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.") + if not os.path.exists(self._output_dir): raise AssertionError( - f"Dataset {self.name}: could not find data in {self._cache_dir_root}. Please make sure to call " - "builder.download_and_prepare(), or pass download=True to " + f"Dataset {self.name}: could not find data in {self._output_dir}. Please make sure to call " + "builder.download_and_prepare(), or use " "datasets.load_dataset() before trying to access the Dataset object." 
) - logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._cache_dir}') + logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._output_dir}') # By default, return all splits if split is None: @@ -930,7 +985,7 @@ def _build_single_dataset( if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") resources_paths = { - resource_name: os.path.join(self._cache_dir, resource_file_name) + resource_name: os.path.join(self._output_dir, resource_file_name) for resource_name, resource_file_name in self._post_processing_resources(split).items() } post_processed = self._post_process(ds, resources_paths) @@ -989,8 +1044,8 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem Returns: `Dataset` """ - - dataset_kwargs = ArrowReader(self._cache_dir, self.info).read( + cache_dir = self._fs._strip_protocol(self._output_dir) + dataset_kwargs = ArrowReader(cache_dir, self.info).read( name=self.name, instructions=split, split_infos=self.info.splits.values(), @@ -1015,6 +1070,12 @@ def as_streaming_dataset( if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)): raise ValueError(f"Builder {self.name} is not streamable.") + is_local = not is_remote_filesystem(self._fs) + if not is_local: + raise NotImplementedError( + f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet." 
+ ) + dl_manager = StreamingDownloadManager( base_path=base_path or self.base_path, download_config=DownloadConfig(use_auth_token=self.use_auth_token), @@ -1113,11 +1174,13 @@ def _split_generators(self, dl_manager: DownloadManager): raise NotImplementedError() @abc.abstractmethod - def _prepare_split(self, split_generator: SplitGenerator, **kwargs): + def _prepare_split(self, split_generator: SplitGenerator, file_format: Optional[str] = None, **kwargs): """Generate the examples and record them on disk. Args: split_generator: `SplitGenerator`, Split generator to process + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. **kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1188,23 +1251,32 @@ def _generate_examples(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, check_duplicate_keys): + def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): + is_local = not is_remote_filesystem(self._fs) + path_join = os.path.join if is_local else posixpath.join + if self.info.splits is not None: split_info = self.info.splits[split_generator.name] else: split_info = split_generator.split_info - fname = f"{self.name}-{split_generator.name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + file_format = file_format or "arrow" + suffix = "-00000-of-00001" if file_format == "parquet" else "" + fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + fpath = path_join(self._output_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) - with ArrowWriter( + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + + # TODO: embed the images/audio files inside parquet files. 
+ with writer_class( features=self.info.features, path=fpath, writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, + storage_options=self._fs.storage_options, ) as writer: try: for key, record in logging.tqdm( @@ -1223,8 +1295,10 @@ def _prepare_split(self, split_generator, check_duplicate_keys): split_generator.split_info.num_examples = num_examples split_generator.split_info.num_bytes = num_bytes - def _download_and_prepare(self, dl_manager, verify_infos): - super()._download_and_prepare(dl_manager, verify_infos, check_duplicate_keys=verify_infos) + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + super()._download_and_prepare( + dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos + ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs) @@ -1265,12 +1339,19 @@ def _generate_tables(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator): - fname = f"{self.name}-{split_generator.name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + def _prepare_split(self, split_generator, file_format=None): + is_local = not is_remote_filesystem(self._fs) + path_join = os.path.join if is_local else posixpath.join + + file_format = file_format or "arrow" + suffix = "-00000-of-00001" if file_format == "parquet" else "" + fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + fpath = path_join(self._output_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) - with ArrowWriter(features=self.info.features, path=fpath) as writer: + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + # TODO: embed the images/audio files inside parquet files. 
+ with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: for key, table in logging.tqdm( generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) ): @@ -1351,7 +1432,7 @@ def _build_pcollection(pipeline, extracted_dir=None): """ raise NotImplementedError() - def _download_and_prepare(self, dl_manager, verify_infos): + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): # Create the Beam pipeline and forward it to _prepare_split import apache_beam as beam @@ -1389,6 +1470,7 @@ def _download_and_prepare(self, dl_manager, verify_infos): dl_manager, verify_infos=False, pipeline=pipeline, + file_format=file_format, ) # TODO handle verify_infos in beam datasets # Run pipeline pipeline_results = pipeline.run() @@ -1404,27 +1486,27 @@ def _download_and_prepare(self, dl_manager, verify_infos): split_info.num_bytes = num_bytes def _save_info(self): - if os.path.exists(self._cache_dir): - super()._save_info() - else: - import apache_beam as beam + import apache_beam as beam - fs = beam.io.filesystems.FileSystems - with fs.create(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f: - self.info._dump_info(f) - if self.info.license: - with fs.create(os.path.join(self._cache_dir, config.LICENSE_FILENAME)) as f: - self.info._dump_license(f) + fs = beam.io.filesystems.FileSystems + path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join + with fs.create(path_join(self._output_dir, config.DATASET_INFO_FILENAME)) as f: + self.info._dump_info(f) + if self.info.license: + with fs.create(path_join(self._output_dir, config.LICENSE_FILENAME)) as f: + self.info._dump_license(f) - def _prepare_split(self, split_generator, pipeline): + def _prepare_split(self, split_generator, pipeline, file_format=None): import apache_beam as beam - # To write examples to disk: + # To write examples in filesystem: split_name = 
split_generator.split_info.name - fname = f"{self.name}-{split_name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + file_format = file_format or "arrow" + fname = f"{self.name}-{split_name}.{file_format}" + path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join + fpath = path_join(self._output_dir, fname) beam_writer = BeamWriter( - features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._cache_dir + features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._output_dir ) self._beam_writers[split_name] = beam_writer diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index 697f5122b44..c61531c4915 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -526,7 +526,8 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool] # - If there is "**" in the pattern, `fs.glob` must be called anyway. 
inner_path = main_hop.split("://")[1] globbed_paths = fs.glob(inner_path) - return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths] + protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1] + return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths] def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None): @@ -558,8 +559,9 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None): inner_path = main_hop.split("://")[1] if inner_path.strip("/") and not fs.isdir(inner_path): return [] + protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1] for dirpath, dirnames, filenames in fs.walk(inner_path): - yield "::".join([f"{fs.protocol}://{dirpath}"] + rest_hops), dirnames, filenames + yield "::".join([f"{protocol}://{dirpath}"] + rest_hops), dirnames, filenames class xPath(type(Path())): diff --git a/src/datasets/info.py b/src/datasets/info.py index 565f893e5be..2b5acfdde92 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -32,11 +32,15 @@ import dataclasses import json import os +import posixpath from dataclasses import dataclass, field from typing import Dict, List, Optional, Union +from fsspec.implementations.local import LocalFileSystem + from . 
import config from .features import Features, Value +from .filesystems import is_remote_filesystem from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version @@ -176,15 +180,16 @@ def __post_init__(self): template.align_with_features(self.features) for template in (self.task_templates) ] - def _license_path(self, dataset_info_dir): - return os.path.join(dataset_info_dir, config.LICENSE_FILENAME) - - def write_to_directory(self, dataset_info_dir, pretty_print=False): + def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None): """Write `DatasetInfo` and license (if present) as JSON files to `dataset_info_dir`. Args: dataset_info_dir (str): Destination directory. pretty_print (bool, default ``False``): If True, the JSON will be pretty-printed with the indent level of 4. + fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): + Instance of the remote filesystem used to download the files from. + + Example: @@ -194,10 +199,14 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False): >>> ds.info.write_to_directory("/path/to/directory/") ``` """ - with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: + fs = fs or LocalFileSystem() + is_local = not is_remote_filesystem(fs) + path_join = os.path.join if is_local else posixpath.join + + with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: self._dump_info(f, pretty_print=pretty_print) if self.license: - with open(os.path.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f: + with fs.open(path_join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f: self._dump_license(f) def _dump_info(self, file, pretty_print=False): @@ -239,7 +248,7 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]): ) @classmethod - def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo": + def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo": 
"""Create DatasetInfo from the JSON file in `dataset_info_dir`. This function updates all the dynamically generated fields (num_examples, @@ -250,6 +259,10 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo": Args: dataset_info_dir (`str`): The directory containing the metadata file. This should be the root directory of a specific dataset version. + fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): + Instance of the remote filesystem used to download the files from. + + Example: @@ -258,11 +271,15 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo": >>> ds_info = DatasetInfo.from_directory("/path/to/directory/") ``` """ + fs = fs or LocalFileSystem() logger.info(f"Loading Dataset info from {dataset_info_dir}") if not dataset_info_dir: raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.") - with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), encoding="utf-8") as f: + is_local = not is_remote_filesystem(fs) + path_join = os.path.join if is_local else posixpath.join + + with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f: dataset_info_dict = json.load(f) return cls.from_dict(dataset_info_dict) diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 07436d016f3..f56b765a7c2 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -17,7 +17,6 @@ """ -import contextlib import copy import functools import itertools @@ -189,7 +188,7 @@ def _asdict_inner(obj): return _asdict_inner(obj) -@contextlib.contextmanager +@contextmanager def temporary_assignment(obj, attr, value): """Temporarily assign obj.attr to value.""" original = getattr(obj, attr, None) @@ -601,7 +600,7 @@ def dump(obj, file): return -@contextlib.contextmanager +@contextmanager def _no_cache_fields(obj): try: if ( diff --git a/tests/conftest.py b/tests/conftest.py index 04f37225581..c1a1af48e57 100644 --- 
a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ # Import fixture modules as plugins -pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.s3"] +pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.s3", "tests.fixtures.fsspec"] def pytest_collection_modifyitems(config, items): diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py new file mode 100644 index 00000000000..7a301116ea8 --- /dev/null +++ b/tests/fixtures/fsspec.py @@ -0,0 +1,84 @@ +import posixpath +from pathlib import Path + +import fsspec +import pytest +from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem, stringify_path + + +class MockFileSystem(AbstractFileSystem): + protocol = "mock" + + def __init__(self, *args, local_root_dir, **kwargs): + super().__init__() + self._fs = LocalFileSystem(*args, **kwargs) + self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/" + + def mkdir(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.mkdir(path, *args, **kwargs) + + def makedirs(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.makedirs(path, *args, **kwargs) + + def rmdir(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rmdir(path) + + def ls(self, path, detail=True, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + out = self._fs.ls(path, detail=detail, *args, **kwargs) + if detail: + return [{**info, "name": info["name"][len(self.local_root_dir) :]} for info in out] + else: + return [name[len(self.local_root_dir) :] for name in out] + + def info(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + out = dict(self._fs.info(path, *args, **kwargs)) + out["name"] = out["name"][len(self.local_root_dir) :] + 
return out + + def cp_file(self, path1, path2, *args, **kwargs): + path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) + return self._fs.cp_file(path1, path2, *args, **kwargs) + + def rm_file(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm_file(path, *args, **kwargs) + + def rm(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm(path, *args, **kwargs) + + def _open(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs._open(path, *args, **kwargs) + + def created(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.created(path) + + def modified(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.modified(path) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("mock://"): + path = path[7:] + return path + + +@pytest.fixture +def mock_fsspec(monkeypatch): + monkeypatch.setitem(fsspec.registry.target, "mock", MockFileSystem) + + +@pytest.fixture +def mockfs(tmp_path_factory, mock_fsspec): + local_fs_dir = tmp_path_factory.mktemp("mockfs") + return MockFileSystem(local_root_dir=local_fs_dir) diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index b5542755db5..794b03be25b 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -6,9 +6,10 @@ import numpy as np import pyarrow as pa +import pyarrow.parquet as pq import pytest -from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence +from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence from datasets.features import Array2D, ClassLabel, Features, Image, 
Value from datasets.features.features import Array2DExtensionType, cast_to_python_objects from datasets.keyhash import DuplicatedKeysError, InvalidKeyError @@ -299,3 +300,29 @@ def test_arrow_writer_closes_stream(raise_exception, tmp_path): pass finally: assert writer.stream.closed + + +def test_arrow_writer_with_filesystem(mockfs): + path = "mock://dataset-train.arrow" + with ArrowWriter(path=path, storage_options=mockfs.storage_options) as writer: + assert isinstance(writer._fs, type(mockfs)) + assert writer._fs.storage_options == mockfs.storage_options + writer.write({"col_1": "foo", "col_2": 1}) + writer.write({"col_1": "bar", "col_2": 2}) + num_examples, num_bytes = writer.finalize() + assert num_examples == 2 + assert num_bytes > 0 + assert mockfs.exists(path) + + +def test_parquet_writer_write(): + output = pa.BufferOutputStream() + with ParquetWriter(stream=output) as writer: + writer.write({"col_1": "foo", "col_2": 1}) + writer.write({"col_1": "bar", "col_2": 2}) + num_examples, num_bytes = writer.finalize() + assert num_examples == 2 + assert num_bytes > 0 + stream = pa.BufferReader(output.getvalue()) + pa_table: pa.Table = pq.read_table(stream) + assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]} diff --git a/tests/test_builder.py b/tests/test_builder.py index 01e39e6cabd..04ea1c87ccc 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -8,12 +8,14 @@ from unittest.mock import patch import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq import pytest from multiprocess.pool import Pool from datasets.arrow_dataset import Dataset from datasets.arrow_writer import ArrowWriter -from datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder +from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_manager import DownloadMode from 
datasets.features import Features, Value @@ -21,8 +23,9 @@ from datasets.iterable_dataset import IterableDataset from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo from datasets.streaming import xjoin +from datasets.utils.file_utils import is_local_path -from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss +from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss class DummyBuilder(DatasetBuilder): @@ -34,7 +37,7 @@ def _split_generators(self, dl_manager): def _prepare_split(self, split_generator, **kwargs): fname = f"{self.name}-{split_generator.name}.arrow" - with ArrowWriter(features=self.info.features, path=os.path.join(self._cache_dir, fname)) as writer: + with ArrowWriter(features=self.info.features, path=os.path.join(self._output_dir, fname)) as writer: writer.write_batch({"text": ["foo"] * 100}) num_examples, num_bytes = writer.finalize() split_generator.split_info.num_examples = num_examples @@ -57,6 +60,35 @@ def _generate_examples(self): yield i, {"text": "foo"} +class DummyArrowBasedBuilder(ArrowBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"text": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN)] + + def _generate_tables(self): + for i in range(10): + yield i, pa.table({"text": ["foo"] * 10}) + + +class DummyBeamBasedBuilder(BeamBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"text": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN)] + + def _build_pcollection(self, pipeline): + import apache_beam as beam + + def _process(item): + for i in range(10): + yield f"{i}_{item}", {"text": "foo"} + + return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process) + + class 
DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder): def _info(self): return DatasetInfo(features=Features({"id": Value("int8")})) @@ -690,6 +722,41 @@ def test_cache_dir_for_data_dir(self): self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) +def test_arrow_based_download_and_prepare(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare() + assert os.path.exists( + os.path.join( + tmp_path, + builder.name, + "default", + "0.0.0", + f"{builder.name}-train.arrow", + ) + ) + assert builder.info.features, Features({"text": Value("string")}) + assert builder.info.splits["train"].num_examples, 100 + assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) + + +@require_beam +def test_beam_based_download_and_prepare(tmp_path): + builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") + builder.download_and_prepare() + assert os.path.exists( + os.path.join( + tmp_path, + builder.name, + "default", + "0.0.0", + f"{builder.name}-train.arrow", + ) + ) + assert builder.info.features, Features({"text": Value("string")}) + assert builder.info.splits["train"].num_examples, 100 + assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) + + @pytest.mark.parametrize( "split, expected_dataset_class, expected_dataset_length", [ @@ -846,3 +913,58 @@ def test_builder_config_version(builder_class, kwargs, tmp_path): cache_dir = str(tmp_path) builder = builder_class(cache_dir=cache_dir, **kwargs) assert builder.config.version == "2.0.0" + + +def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) + assert builder._output_dir.startswith("mock://my_dataset") + assert is_local_path(builder._cache_downloaded_dir) + assert isinstance(builder._fs, 
type(mockfs)) + assert builder._fs.storage_options == mockfs.storage_options + assert mockfs.exists("my_dataset/dataset_info.json") + assert mockfs.exists(f"my_dataset/{builder.name}-train.arrow") + assert not mockfs.exists("my_dataset.incomplete") + + +def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, caplog): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + mockfs.makedirs("my_dataset") + DatasetInfo().write_to_directory("my_dataset", fs=mockfs) + mockfs.touch(f"my_dataset/{builder.name}-train.arrow") + caplog.clear() + builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) + assert "Found cached dataset" in caplog.text + + +def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None + + +def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None + + +def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", 
f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None diff --git a/tests/test_load.py b/tests/test_load.py index c47ea967b00..4436f8bd88c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -908,7 +908,7 @@ def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir os.rename(cache_dir1, cache_dir2) caplog.clear() dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir2) - assert "Reusing dataset" in caplog.text + assert "Found cached dataset" in caplog.text assert dataset._fingerprint == fingerprint1, "for the caching mechanism to work, fingerprint should stay the same" dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="test", cache_dir=cache_dir2) assert dataset._fingerprint != fingerprint1