diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx
index 048dcafcd36..f2da377568f 100644
--- a/docs/source/filesystems.mdx
+++ b/docs/source/filesystems.mdx
@@ -1,6 +1,8 @@
# Cloud storage
-🤗 Datasets supports access to cloud storage providers through a S3 filesystem implementation: [`filesystems.S3FileSystem`]. You can save and load datasets from your Amazon S3 bucket in a Pythonic way. Take a look at the following table for other supported cloud storage providers:
+🤗 Datasets supports access to cloud storage providers through a `fsspec` FileSystem implementations.
+You can save and load datasets from any cloud storage in a Pythonic way.
+Take a look at the following table for some example of supported cloud storage providers:
| Storage provider | Filesystem implementation |
|----------------------|---------------------------------------------------------------|
@@ -10,11 +12,12 @@
| Dropbox | [dropboxdrivefs](https://github.com/MarineChap/dropboxdrivefs)|
| Google Drive | [gdrivefs](https://github.com/intake/gdrivefs) |
-This guide will show you how to save and load datasets with **s3fs** to a S3 bucket, but other filesystem implementations can be used similarly. An example is shown also for Google Cloud Storage and Azure Blob Storage.
+This guide will show you how to save and load datasets with any cloud storage.
+Here are examples for S3, Google Cloud Storage and Azure Blob Storage.
-## Amazon S3
+## Set up your cloud storage FileSystem
-### Listing datasets
+### Amazon S3
1. Install the S3 dependency with 🤗 Datasets:
@@ -22,163 +25,178 @@ This guide will show you how to save and load datasets with **s3fs** to a S3 buc
>>> pip install datasets[s3]
```
-2. List files from a public S3 bucket with `s3.ls`:
+2. Define your credentials
+
+To use an anonymous connection, use `anon=True`.
+Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket.
```py
->>> import datasets
->>> s3 = datasets.filesystems.S3FileSystem(anon=True)
->>> s3.ls('public-datasets/imdb/train')
-['dataset_info.json.json','dataset.arrow','state.json']
+>>> storage_options = {"anon": True} # for anonymous connection
+# or use your credentials
+>>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} # for private buckets
+# or use a botocore session
+>>> import botocore
+>>> s3_session = botocore.session.Session(profile="my_profile_name")
+>>> storage_options = {"session": s3_session}
```
-Access a private S3 bucket by entering your `aws_access_key_id` and `aws_secret_access_key`:
+3. Create your FileSystem instance
```py
->>> import datasets
->>> s3 = datasets.filesystems.S3FileSystem(key=aws_access_key_id, secret=aws_secret_access_key)
->>> s3.ls('my-private-datasets/imdb/train')
-['dataset_info.json.json','dataset.arrow','state.json']
+>>> import s3fs
+>>> fs = s3fs.S3FileSystem(**storage_options)
```
-### Saving datasets
+### Google Cloud Storage
-After you have processed your dataset, you can save it to S3 with [`Dataset.save_to_disk`]:
+1. Install the Google Cloud Storage implementation:
+
+```
+>>> conda install -c conda-forge gcsfs
+# or install with pip
+>>> pip install gcsfs
+```
+
+2. Define your credentials
```py
->>> from datasets.filesystems import S3FileSystem
+>>> storage_options={"token": "anon"} # for anonymous connection
+# or use your credentials of your default gcloud credentials or from the google metadata service
+>>> storage_options={"project": "my-google-project"}
+# or use your credentials from elsewhere, see the documentation at https://gcsfs.readthedocs.io/
+>>> storage_options={"project": "my-google-project", "token": TOKEN}
+```
-# create S3FileSystem instance
->>> s3 = S3FileSystem(anon=True)
+3. Create your FileSystem instance
-# saves encoded_dataset to your s3 bucket
->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train', fs=s3)
+```py
+>>> import gcsfs
+>>> fs = gcsfs.GCSFileSystem(**storage_options)
```
-
+### Azure Blob Storage
-Remember to include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket.
+1. Install the Azure Blob Storage implementation:
-
+```
+>>> conda install -c conda-forge adlfs
+# or install with pip
+>>> pip install adlfs
+```
-Save your dataset with `botocore.session.Session` and a custom AWS profile:
+2. Define your credentials
```py
->>> import botocore
->>> from datasets.filesystems import S3FileSystem
-
-# creates a botocore session with the provided AWS profile
->>> s3_session = botocore.session.Session(profile='my_profile_name')
+>>> storage_options = {"anon": True} # for anonymous connection
+# or use your credentials
+>>> storage_options = {"account_name": ACCOUNT_NAME, "account_key": ACCOUNT_KEY) # gen 2 filesystem
+# or use your credentials with the gen 1 filesystem
+>>> storage_options={"tenant_id": TENANT_ID, "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET}
+```
-# create S3FileSystem instance with s3_session
->>> s3 = S3FileSystem(session=s3_session)
+3. Create your FileSystem instance
-# saves encoded_dataset to your s3 bucket
->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train',fs=s3)
+```py
+>>> import adlfs
+>>> fs = adlfs.AzureBlobFileSystem(**storage_options)
```
-### Loading datasets
+## Load and Save your datasets using your cloud storage FileSystem
-When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]:
+### Download and prepare a dataset into a cloud storage
-```py
->>> from datasets import load_from_disk
->>> from datasets.filesystems import S3FileSystem
+You can download and prepare a dataset into your cloud storage by specifying a remote `output_dir` in `download_and_prepare`.
+Don't forget to use the previously defined `storage_options` containing your credentials to write into a private cloud storage.
-# create S3FileSystem without credentials
->>> s3 = S3FileSystem(anon=True)
+The `download_and_prepare` method works in two steps:
+1. it first downloads the raw data files (if any) in your local cache. You can set your cache directory by passing `cache_dir` to [`load_dataset_builder`]
+2. then it generates the dataset in Arrow or Parquet format in your cloud storage by iterating over the raw data files.
-# load encoded_dataset to from s3 bucket
->>> dataset = load_from_disk('s3://a-public-datasets/imdb/train',fs=s3)
+Load a dataset builder from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)):
->>> print(len(dataset))
->>> # 25000
+```py
+>>> output_dir = "s3://my-bucket/imdb"
+>>> builder = load_dataset_builder("imdb")
+>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet")
```
-Load with `botocore.session.Session` and custom AWS profile:
+Load a dataset builder using a loading script (see [how to load a local loading script](./loading#local-loading-script)):
```py
->>> import botocore
->>> from datasets.filesystems import S3FileSystem
-
-# create S3FileSystem instance with aws_access_key_id and aws_secret_access_key
->>> s3_session = botocore.session.Session(profile='my_profile_name')
-
-# create S3FileSystem instance with s3_session
->>> s3 = S3FileSystem(session=s3_session)
+>>> output_dir = "s3://my-bucket/imdb"
+>>> builder = load_dataset_builder("path/to/local/loading_script/loading_script.py")
+>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet")
+```
-# load encoded_dataset to from s3 bucket
->>> dataset = load_from_disk('s3://my-private-datasets/imdb/train',fs=s3)
+Use your own data files (see [how to load local and remote files](./loading#local-and-remote-files)):
->>> print(len(dataset))
->>> # 25000
+```py
+>>> data_files = {"train": ["path/to/train.csv"]}
+>>> output_dir = "s3://my-bucket/imdb"
+>>> builder = load_dataset_builder("csv", data_files=data_files)
+>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet")
```
-## Google Cloud Storage
+It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`.
+Otherwise the dataset is saved as an uncompressed Arrow file.
-1. Install the Google Cloud Storage implementation:
+#### Dask
-```
->>> conda install -c conda-forge gcsfs
-# or install with pip
->>> pip install gcsfs
-```
+Dask is a parallel computing library and it has a pandas-like API for working with larger than memory Parquet datasets in parallel.
+Dask can use multiple threads or processes on a single machine, or a cluster of machines to process data in parallel.
+Dask supports local data but also data from a cloud storage.
-2. Save your dataset:
+Therefore you can load a dataset saved as sharded Parquet files in Dask with
```py
->>> import gcsfs
+import dask.dataframe as dd
-# create GCSFileSystem instance using default gcloud credentials with project
->>> gcs = gcsfs.GCSFileSystem(project='my-google-project')
+df = dd.read_parquet(output_dir, storage_options=storage_options)
-# saves encoded_dataset to your gcs bucket
->>> encoded_dataset.save_to_disk('gcs://my-private-datasets/imdb/train', fs=gcs)
+# or if your dataset is split into train/valid/test
+df_train = dd.read_parquet(output_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options)
+df_valid = dd.read_parquet(output_dir + f"/{builder.name}-validation-*.parquet", storage_options=storage_options)
+df_test = dd.read_parquet(output_dir + f"/{builder.name}-test-*.parquet", storage_options=storage_options)
```
-3. Load your dataset:
+You can find more about dask dataframes in their [documentation](https://docs.dask.org/en/stable/dataframe.html).
-```py
->>> import gcsfs
->>> from datasets import load_from_disk
+## Saving serialized datasets
-# create GCSFileSystem instance using default gcloud credentials with project
->>> gcs = gcsfs.GCSFileSystem(project='my-google-project')
+After you have processed your dataset, you can save it to your cloud storage with [`Dataset.save_to_disk`]:
-# loads encoded_dataset from your gcs bucket
->>> dataset = load_from_disk('gcs://my-private-datasets/imdb/train', fs=gcs)
+```py
+# saves encoded_dataset to amazon s3
+>>> encoded_dataset.save_to_disk("s3://my-private-datasets/imdb/train", fs=fs)
+# saves encoded_dataset to google cloud storage
+>>> encoded_dataset.save_to_disk("gcs://my-private-datasets/imdb/train", fs=fs)
+# saves encoded_dataset to microsoft azure blob/datalake
+>>> encoded_dataset.save_to_disk("adl://my-private-datasets/imdb/train", fs=fs)
```
-## Azure Blob Storage
-
-1. Install the Azure Blob Storage implementation:
+
-```
->>> conda install -c conda-forge adlfs
-# or install with pip
->>> pip install adlfs
-```
+Remember to define your credentials in your [FileSystem instance](#set-up-your-cloud-storage-filesystem) `fs` whenever you are interacting with a private cloud storage.
-2. Save your dataset:
+
-```py
->>> import adlfs
+## Listing serialized datasets
-# create AzureBlobFileSystem instance with account_name and account_key
->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX")
+List files from a cloud storage with your FileSystem instance `fs`, using `fs.ls`:
-# saves encoded_dataset to your azure container
->>> encoded_dataset.save_to_disk('abfs://my-private-datasets/imdb/train', fs=abfs)
+```py
+>>> fs.ls("my-private-datasets/imdb/train")
+["dataset_info.json.json","dataset.arrow","state.json"]
```
-3. Load your dataset:
+### Load serialized datasets
+
+When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]:
```py
->>> import adlfs
>>> from datasets import load_from_disk
-
-# create AzureBlobFileSystem instance with account_name and account_key
->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX")
-
-# loads encoded_dataset from your azure container
->>> dataset = load_from_disk('abfs://my-private-datasets/imdb/train', fs=abfs)
+# load encoded_dataset from cloud storage
+>>> dataset = load_from_disk("s3://a-public-datasets/imdb/train", fs=fs)
+>>> print(len(dataset))
+25000
```
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index bc2c5ab37ba..54b5b7726d7 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1134,7 +1134,7 @@ def save_to_disk(self, dataset_path: str, fs=None):
fs.makedirs(dataset_path, exist_ok=True)
with fs.open(Path(dataset_path, config.DATASET_ARROW_FILENAME).as_posix(), "wb") as dataset_file:
with ArrowWriter(stream=dataset_file) as writer:
- writer.write_table(dataset._data)
+ writer.write_table(dataset._data.table)
writer.finalize()
with fs.open(
Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8"
diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py
index 6ffdb0f715d..0230b840faf 100644
--- a/src/datasets/arrow_writer.py
+++ b/src/datasets/arrow_writer.py
@@ -18,8 +18,10 @@
import sys
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+import fsspec
import numpy as np
import pyarrow as pa
+import pyarrow.parquet as pq
from . import config
from .features import Features, Image, Value
@@ -33,6 +35,7 @@
numpy_to_pyarrow_listarray,
to_pyarrow_listarray,
)
+from .filesystems import is_remote_filesystem
from .info import DatasetInfo
from .keyhash import DuplicatedKeysError, KeyHasher
from .table import array_cast, cast_array_to_feature, table_cast
@@ -268,6 +271,8 @@ def __init__(
class ArrowWriter:
"""Shuffles and writes Examples to Arrow files."""
+ _WRITER_CLASS = pa.RecordBatchStreamWriter
+
def __init__(
self,
schema: Optional[pa.Schema] = None,
@@ -282,6 +287,7 @@ def __init__(
update_features: bool = False,
with_metadata: bool = True,
unit: str = "examples",
+ storage_options: Optional[dict] = None,
):
if path is None and stream is None:
raise ValueError("At least one of path and stream must be provided.")
@@ -304,11 +310,19 @@ def __init__(
self._check_duplicates = check_duplicates
self._disable_nullable = disable_nullable
- self._path = path
if stream is None:
- self.stream = pa.OSFile(self._path, "wb")
+ fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options)
+ self._fs: fsspec.AbstractFileSystem = fs_token_paths[0]
+ self._path = (
+ fs_token_paths[2][0]
+ if not is_remote_filesystem(self._fs)
+ else self._fs.unstrip_protocol(fs_token_paths[2][0])
+ )
+ self.stream = self._fs.open(fs_token_paths[2][0], "wb")
self._closable_stream = True
else:
+ self._fs = None
+ self._path = None
self.stream = stream
self._closable_stream = False
@@ -367,7 +381,7 @@ def _build_writer(self, inferred_schema: pa.Schema):
if self.with_metadata:
schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint))
self._schema = schema
- self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema)
+ self.pa_writer = self._WRITER_CLASS(self.stream, schema)
@property
def schema(self):
@@ -522,11 +536,9 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
if self.pa_writer is None:
self._build_writer(inferred_schema=pa_table.schema)
pa_table = table_cast(pa_table, self._schema)
- batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size)
- self._num_bytes += sum(batch.nbytes for batch in batches)
+ self._num_bytes += pa_table.nbytes
self._num_examples += pa_table.num_rows
- for batch in batches:
- self.pa_writer.write_batch(batch)
+ self.pa_writer.write_table(pa_table, writer_batch_size)
def finalize(self, close_stream=True):
self.write_rows_on_file()
@@ -542,6 +554,7 @@ def finalize(self, close_stream=True):
else:
raise ValueError("Please pass `features` or at least one example when writing data")
self.pa_writer.close()
+ self.pa_writer = None
if close_stream:
self.stream.close()
logger.debug(
@@ -550,6 +563,10 @@ def finalize(self, close_stream=True):
return self._num_examples, self._num_bytes
+class ParquetWriter(ArrowWriter):
+ _WRITER_CLASS = pq.ParquetWriter
+
+
class BeamWriter:
"""
Shuffles and writes Examples to Arrow files.
@@ -616,35 +633,50 @@ def finalize(self, metrics_query_result: dict):
from .utils import beam_utils
# Convert to arrow
- logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}")
- shards = [
- metadata.path
- for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list
- ]
- try: # stream conversion
- sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards]
- with beam.io.filesystems.FileSystems.create(self._path) as dest:
- parquet_to_arrow(sources, dest)
- except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead
- if e.errno != errno.EPIPE: # not a broken pipe
- raise
- logger.warning("Broken Pipe during stream conversion from parquet to arrow. Using local convert instead")
- local_convert_dir = os.path.join(self._cache_dir, "beam_convert")
- os.makedirs(local_convert_dir, exist_ok=True)
- local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow")
- local_shards = []
- for shard in shards:
- local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
- local_shards.append(local_parquet_path)
- beam_utils.download_remote_to_local(shard, local_parquet_path)
- parquet_to_arrow(local_shards, local_arrow_path)
- beam_utils.upload_local_to_remote(local_arrow_path, self._path)
+ if self._path.endswith(".arrow"):
+ logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}")
+ shards = [
+ metadata.path
+ for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[
+ 0
+ ].metadata_list
+ ]
+ try: # stream conversion
+ sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards]
+ with beam.io.filesystems.FileSystems.create(self._path) as dest:
+ parquet_to_arrow(sources, dest)
+ except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead
+ if e.errno != errno.EPIPE: # not a broken pipe
+ raise
+ logger.warning(
+ "Broken Pipe during stream conversion from parquet to arrow. Using local convert instead"
+ )
+ local_convert_dir = os.path.join(self._cache_dir, "beam_convert")
+ os.makedirs(local_convert_dir, exist_ok=True)
+ local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow")
+ local_shards = []
+ for shard in shards:
+ local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
+ local_shards.append(local_parquet_path)
+ beam_utils.download_remote_to_local(shard, local_parquet_path)
+ parquet_to_arrow(local_shards, local_arrow_path)
+ beam_utils.upload_local_to_remote(local_arrow_path, self._path)
+ output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0]
+ num_bytes = output_file_metadata.size_in_bytes
+ else:
+ num_bytes = sum(
+ [
+ metadata.size_in_bytes
+ for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[
+ 0
+ ].metadata_list
+ ]
+ )
# Save metrics
counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]}
self._num_examples = counters_dict["num_examples"]
- output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0]
- self._num_bytes = output_file_metadata.size_in_bytes
+ self._num_bytes = num_bytes
return self._num_examples, self._num_bytes
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 380ed031bce..955078dff27 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -27,8 +27,11 @@
import warnings
from dataclasses import dataclass
from functools import partial
+from pathlib import Path
from typing import Dict, Mapping, Optional, Tuple, Union
+import fsspec
+
from . import config, utils
from .arrow_dataset import Dataset
from .arrow_reader import (
@@ -38,7 +41,7 @@
MissingFilesOnHfGcsError,
ReadInstruction,
)
-from .arrow_writer import ArrowWriter, BeamWriter
+from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter
from .data_files import DataFilesDict, sanitize_patterns
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download.download_config import DownloadConfig
@@ -46,6 +49,7 @@
from .download.mock_download_manager import MockDownloadManager
from .download.streaming_download_manager import StreamingDownloadManager
from .features import Features
+from .filesystems import is_remote_filesystem
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
@@ -345,6 +349,10 @@ def __init__(
)
os.rmdir(self._cache_dir)
+ # Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare
+ self._output_dir = self._cache_dir
+ self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file")
+
# Set download manager
self.dl_manager = None
@@ -536,7 +544,7 @@ def _other_versions_on_disk():
version_dirnames.sort(reverse=True)
return version_dirnames
- # Check and warn if other versions exist on disk
+ # Check and warn if other versions exist
if not is_remote_url(builder_data_dir):
version_dirs = _other_versions_on_disk()
if version_dirs:
@@ -570,6 +578,7 @@ def get_imported_module_dir(cls):
def download_and_prepare(
self,
+ output_dir: Optional[str] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[DownloadMode] = None,
ignore_verifications: bool = False,
@@ -577,11 +586,17 @@ def download_and_prepare(
dl_manager: Optional[DownloadManager] = None,
base_path: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
+ file_format: Optional[str] = None,
+ storage_options: Optional[dict] = None,
**download_and_prepare_kwargs,
):
"""Downloads and prepares dataset for reading.
Args:
+ output_dir (:obj:`str`, optional): output directory for the dataset.
+ Default to this builder's ``cache_dir``, which is inside ~/.cache/huggingface/datasets by default.
+
+
download_config (:class:`DownloadConfig`, optional): specific download configuration parameters.
download_mode (:class:`DownloadMode`, optional): select the download/generate mode - Default to ``REUSE_DATASET_IF_EXISTS``
ignore_verifications (:obj:`bool`): Ignore the verifications of the downloaded/processed dataset information (checksums/size/splits/...)
@@ -591,6 +606,13 @@ def download_and_prepare(
If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead.
use_auth_token (:obj:`Union[str, bool]`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
If True, will get token from ~/.huggingface.
+ file_format (:obj:`str`, optional): format of the data files in which the dataset will be written.
+ Supported formats: "arrow", "parquet". Default to "arrow" format.
+
+
+ storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any.
+
+
**download_and_prepare_kwargs (additional keyword arguments): Keyword arguments.
Example:
@@ -601,9 +623,18 @@ def download_and_prepare(
>>> ds = builder.download_and_prepare()
```
"""
+ self._output_dir = output_dir if output_dir is not None else self._cache_dir
+ # output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing)
+ fs_token_paths = fsspec.get_fs_token_paths(self._output_dir, storage_options=storage_options)
+ self._fs = fs_token_paths[0]
+
download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS)
verify_infos = not ignore_verifications
base_path = base_path if base_path is not None else self.base_path
+ is_local = not is_remote_filesystem(self._fs)
+
+ if file_format is not None and file_format not in ["arrow", "parquet"]:
+ raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'")
if dl_manager is None:
if download_config is None:
@@ -625,28 +656,35 @@ def download_and_prepare(
else False,
)
- elif isinstance(dl_manager, MockDownloadManager):
+ elif isinstance(dl_manager, MockDownloadManager) or not is_local:
try_from_hf_gcs = False
self.dl_manager = dl_manager
- # Prevent parallel disk operations
- is_local = not is_remote_url(self._cache_dir_root)
+ # Prevent parallel local disk operations
if is_local:
- lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
+ # Create parent directory of the output_dir to put the lock file in there
+ Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True)
+ lock_path = self._output_dir + "_builder.lock"
+
# File locking only with local paths; no file locking on GCS or S3
with FileLock(lock_path) if is_local else contextlib.nullcontext():
- if is_local:
- data_exists = os.path.exists(self._cache_dir)
- if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
- logger.warning(f"Reusing dataset {self.name} ({self._cache_dir})")
- # We need to update the info in case some splits were added in the meantime
- # for example when calling load_dataset from multiple workers.
- self.info = self._load_info()
- self.download_post_processing_resources(dl_manager)
- return
- logger.info(f"Generating dataset {self.name} ({self._cache_dir})")
+
+ # Check if the data already exists
+ path_join = os.path.join if is_local else posixpath.join
+ data_exists = self._fs.exists(path_join(self._output_dir, config.DATASET_INFO_FILENAME))
+ if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
+ logger.warning(f"Found cached dataset {self.name} ({self._output_dir})")
+ # We need to update the info in case some splits were added in the meantime
+ # for example when calling load_dataset from multiple workers.
+ self.info = self._load_info()
+ self.download_post_processing_resources(dl_manager)
+ return
+
+ logger.info(f"Generating dataset {self.name} ({self._output_dir})")
if is_local: # if cache dir is local, check for available space
- if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
+ if not has_sufficient_disk_space(
+ self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent
+ ):
raise OSError(
f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"
)
@@ -654,7 +692,8 @@ def download_and_prepare(
@contextlib.contextmanager
def incomplete_dir(dirname):
"""Create temporary dir for dirname and rename on exit."""
- if is_remote_url(dirname):
+ if not is_local:
+ self._fs.makedirs(dirname, exist_ok=True)
yield dirname
else:
tmp_dir = dirname + ".incomplete"
@@ -663,6 +702,7 @@ def incomplete_dir(dirname):
yield tmp_dir
if os.path.isdir(dirname):
shutil.rmtree(dirname)
+ # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory
shutil.move(tmp_dir, dirname)
finally:
if os.path.exists(tmp_dir):
@@ -676,20 +716,22 @@ def incomplete_dir(dirname):
f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} "
f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, "
f"post-processed: {size_str(self.info.post_processing_size)}, "
- f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..."
+ f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..."
)
else:
+ _dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir
print(
- f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {self._cache_dir}..."
+ f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..."
)
self._check_manual_download(dl_manager)
- # Create a tmp dir and rename to self._cache_dir on successful exit.
- with incomplete_dir(self._cache_dir) as tmp_data_dir:
- # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward
+ # Create a tmp dir and rename to self._output_dir on successful exit.
+ with incomplete_dir(self._output_dir) as tmp_output_dir:
+ # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward
# it to every sub function.
- with temporary_assignment(self, "_cache_dir", tmp_data_dir):
+ with temporary_assignment(self, "_output_dir", tmp_output_dir):
+
# Try to download the already prepared dataset files
downloaded_from_gcs = False
if try_from_hf_gcs:
@@ -702,7 +744,10 @@ def incomplete_dir(dirname):
logger.warning("HF google storage unreachable. Downloading and preparing it from source")
if not downloaded_from_gcs:
self._download_and_prepare(
- dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs
+ dl_manager=dl_manager,
+ verify_infos=verify_infos,
+ file_format=file_format,
+ **download_and_prepare_kwargs,
)
# Sync info
self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values())
@@ -715,7 +760,7 @@ def incomplete_dir(dirname):
self.download_post_processing_resources(dl_manager)
print(
- f"Dataset {self.name} downloaded and prepared to {self._cache_dir}. "
+ f"Dataset {self.name} downloaded and prepared to {self._output_dir}. "
f"Subsequent calls will reuse this data."
)
@@ -734,10 +779,10 @@ def _check_manual_download(self, dl_manager):
def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig):
relative_data_dir = self._relative_data_dir(with_version=True, with_hash=False)
- reader = ArrowReader(self._cache_dir, self.info)
+ reader = ArrowReader(self._output_dir, self.info)
# use reader instructions to download the right files
reader.download_from_hf_gcs(download_config, relative_data_dir)
- downloaded_info = DatasetInfo.from_directory(self._cache_dir)
+ downloaded_info = DatasetInfo.from_directory(self._output_dir)
self.info.update(downloaded_info)
# download post processing resources
remote_cache_dir = HF_GCP_BASE_URL + "/" + relative_data_dir.replace(os.sep, "/")
@@ -747,12 +792,12 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig):
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
try:
resource_path = cached_path(remote_cache_dir + "/" + resource_file_name)
- shutil.move(resource_path, os.path.join(self._cache_dir, resource_file_name))
+ shutil.move(resource_path, os.path.join(self._output_dir, resource_file_name))
except ConnectionError:
logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.")
logger.info("Dataset downloaded from Hf google storage.")
- def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs):
+ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **prepare_split_kwargs):
"""Downloads and prepares dataset for reading.
This is the internal implementation to overwrite called when user calls
@@ -760,9 +805,10 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs
the pre-processed datasets files.
Args:
- dl_manager: (DownloadManager) `DownloadManager` used to download and cache
- data.
- verify_infos: bool, if False, do not perform checksums and size tests.
+ dl_manager: (:obj:`DownloadManager`) `DownloadManager` used to download and cache data.
+ verify_infos (:obj:`bool`): if False, do not perform checksums and size tests.
+ file_format (:obj:`str`, optional): format of the data files in which the dataset will be written.
+ Supported formats: "arrow", "parquet". Default to "arrow" format.
prepare_split_kwargs: Additional options.
"""
# Generating data for all splits
@@ -790,7 +836,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs
try:
# Prepare split will record examples associated to the split
- self._prepare_split(split_generator, **prepare_split_kwargs)
+ self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs)
except OSError as e:
raise OSError(
"Cannot find data file. "
@@ -815,11 +861,13 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs
self.info.download_size = dl_manager.downloaded_size
def download_post_processing_resources(self, dl_manager):
- for split in self.info.splits:
+ for split in self.info.splits or []:
for resource_name, resource_file_name in self._post_processing_resources(split).items():
+ if not not is_remote_filesystem(self._fs):
+ raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}")
if os.sep in resource_file_name:
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
- resource_path = os.path.join(self._cache_dir, resource_file_name)
+ resource_path = os.path.join(self._output_dir, resource_file_name)
if not os.path.exists(resource_path):
downloaded_resource_path = self._download_post_processing_resources(
split, resource_name, dl_manager
@@ -829,16 +877,20 @@ def download_post_processing_resources(self, dl_manager):
shutil.move(downloaded_resource_path, resource_path)
def _load_info(self) -> DatasetInfo:
- return DatasetInfo.from_directory(self._cache_dir)
+ return DatasetInfo.from_directory(self._output_dir, fs=self._fs)
def _save_info(self):
- lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
- with FileLock(lock_path):
- self.info.write_to_directory(self._cache_dir)
+ is_local = not is_remote_filesystem(self._fs)
+ if is_local:
+ lock_path = self._output_dir + "_info.lock"
+ with FileLock(lock_path) if is_local else contextlib.nullcontext():
+ self.info.write_to_directory(self._output_dir, fs=self._fs)
def _save_infos(self):
- lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
- with FileLock(lock_path):
+ is_local = not is_remote_filesystem(self._fs)
+ if is_local:
+ lock_path = self._output_dir + "_infos.lock"
+ with FileLock(lock_path) if is_local else contextlib.nullcontext():
DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir())
def _make_split_generators_kwargs(self, prepare_split_kwargs):
@@ -876,14 +928,17 @@ def as_dataset(
})
```
"""
- if not os.path.exists(self._cache_dir):
+ is_local = not is_remote_filesystem(self._fs)
+ if not is_local:
+ raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.")
+ if not os.path.exists(self._output_dir):
raise AssertionError(
- f"Dataset {self.name}: could not find data in {self._cache_dir_root}. Please make sure to call "
- "builder.download_and_prepare(), or pass download=True to "
+ f"Dataset {self.name}: could not find data in {self._output_dir}. Please make sure to call "
+ "builder.download_and_prepare(), or use "
"datasets.load_dataset() before trying to access the Dataset object."
)
- logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._cache_dir}')
+ logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._output_dir}')
# By default, return all splits
if split is None:
@@ -930,7 +985,7 @@ def _build_single_dataset(
if os.sep in resource_file_name:
raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}")
resources_paths = {
- resource_name: os.path.join(self._cache_dir, resource_file_name)
+ resource_name: os.path.join(self._output_dir, resource_file_name)
for resource_name, resource_file_name in self._post_processing_resources(split).items()
}
post_processed = self._post_process(ds, resources_paths)
@@ -989,8 +1044,8 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem
Returns:
`Dataset`
"""
-
- dataset_kwargs = ArrowReader(self._cache_dir, self.info).read(
+ cache_dir = self._fs._strip_protocol(self._output_dir)
+ dataset_kwargs = ArrowReader(cache_dir, self.info).read(
name=self.name,
instructions=split,
split_infos=self.info.splits.values(),
@@ -1015,6 +1070,12 @@ def as_streaming_dataset(
if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)):
raise ValueError(f"Builder {self.name} is not streamable.")
+ is_local = not is_remote_filesystem(self._fs)
+ if not is_local:
+ raise NotImplementedError(
+ f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet."
+ )
+
dl_manager = StreamingDownloadManager(
base_path=base_path or self.base_path,
download_config=DownloadConfig(use_auth_token=self.use_auth_token),
@@ -1113,11 +1174,13 @@ def _split_generators(self, dl_manager: DownloadManager):
raise NotImplementedError()
@abc.abstractmethod
- def _prepare_split(self, split_generator: SplitGenerator, **kwargs):
+ def _prepare_split(self, split_generator: SplitGenerator, file_format: Optional[str] = None, **kwargs):
"""Generate the examples and record them on disk.
Args:
split_generator: `SplitGenerator`, Split generator to process
+ file_format (:obj:`str`, optional): format of the data files in which the dataset will be written.
+ Supported formats: "arrow", "parquet". Default to "arrow" format.
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex:
beam pipeline)
"""
@@ -1188,23 +1251,32 @@ def _generate_examples(self, **kwargs):
"""
raise NotImplementedError()
- def _prepare_split(self, split_generator, check_duplicate_keys):
+ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None):
+ is_local = not is_remote_filesystem(self._fs)
+ path_join = os.path.join if is_local else posixpath.join
+
if self.info.splits is not None:
split_info = self.info.splits[split_generator.name]
else:
split_info = split_generator.split_info
- fname = f"{self.name}-{split_generator.name}.arrow"
- fpath = os.path.join(self._cache_dir, fname)
+ file_format = file_format or "arrow"
+ suffix = "-00000-of-00001" if file_format == "parquet" else ""
+ fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
+ fpath = path_join(self._output_dir, fname)
generator = self._generate_examples(**split_generator.gen_kwargs)
- with ArrowWriter(
+ writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
+
+ # TODO: embed the images/audio files inside parquet files.
+ with writer_class(
features=self.info.features,
path=fpath,
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
+ storage_options=self._fs.storage_options,
) as writer:
try:
for key, record in logging.tqdm(
@@ -1223,8 +1295,10 @@ def _prepare_split(self, split_generator, check_duplicate_keys):
split_generator.split_info.num_examples = num_examples
split_generator.split_info.num_bytes = num_bytes
- def _download_and_prepare(self, dl_manager, verify_infos):
- super()._download_and_prepare(dl_manager, verify_infos, check_duplicate_keys=verify_infos)
+ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None):
+ super()._download_and_prepare(
+ dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos
+ )
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)
@@ -1265,12 +1339,19 @@ def _generate_tables(self, **kwargs):
"""
raise NotImplementedError()
- def _prepare_split(self, split_generator):
- fname = f"{self.name}-{split_generator.name}.arrow"
- fpath = os.path.join(self._cache_dir, fname)
+ def _prepare_split(self, split_generator, file_format=None):
+ is_local = not is_remote_filesystem(self._fs)
+ path_join = os.path.join if is_local else posixpath.join
+
+ file_format = file_format or "arrow"
+ suffix = "-00000-of-00001" if file_format == "parquet" else ""
+ fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
+ fpath = path_join(self._output_dir, fname)
generator = self._generate_tables(**split_generator.gen_kwargs)
- with ArrowWriter(features=self.info.features, path=fpath) as writer:
+ writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
+ # TODO: embed the images/audio files inside parquet files.
+ with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer:
for key, table in logging.tqdm(
generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled())
):
@@ -1351,7 +1432,7 @@ def _build_pcollection(pipeline, extracted_dir=None):
"""
raise NotImplementedError()
- def _download_and_prepare(self, dl_manager, verify_infos):
+ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None):
# Create the Beam pipeline and forward it to _prepare_split
import apache_beam as beam
@@ -1389,6 +1470,7 @@ def _download_and_prepare(self, dl_manager, verify_infos):
dl_manager,
verify_infos=False,
pipeline=pipeline,
+ file_format=file_format,
) # TODO handle verify_infos in beam datasets
# Run pipeline
pipeline_results = pipeline.run()
@@ -1404,27 +1486,27 @@ def _download_and_prepare(self, dl_manager, verify_infos):
split_info.num_bytes = num_bytes
def _save_info(self):
- if os.path.exists(self._cache_dir):
- super()._save_info()
- else:
- import apache_beam as beam
+ import apache_beam as beam
- fs = beam.io.filesystems.FileSystems
- with fs.create(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f:
- self.info._dump_info(f)
- if self.info.license:
- with fs.create(os.path.join(self._cache_dir, config.LICENSE_FILENAME)) as f:
- self.info._dump_license(f)
+ fs = beam.io.filesystems.FileSystems
+ path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join
+ with fs.create(path_join(self._output_dir, config.DATASET_INFO_FILENAME)) as f:
+ self.info._dump_info(f)
+ if self.info.license:
+ with fs.create(path_join(self._output_dir, config.LICENSE_FILENAME)) as f:
+ self.info._dump_license(f)
- def _prepare_split(self, split_generator, pipeline):
+ def _prepare_split(self, split_generator, pipeline, file_format=None):
import apache_beam as beam
- # To write examples to disk:
+ # To write examples in filesystem:
split_name = split_generator.split_info.name
- fname = f"{self.name}-{split_name}.arrow"
- fpath = os.path.join(self._cache_dir, fname)
+ file_format = file_format or "arrow"
+ fname = f"{self.name}-{split_name}.{file_format}"
+ path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join
+ fpath = path_join(self._output_dir, fname)
beam_writer = BeamWriter(
- features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._cache_dir
+ features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._output_dir
)
self._beam_writers[split_name] = beam_writer
diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index 697f5122b44..c61531c4915 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -526,7 +526,8 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool]
# - If there is "**" in the pattern, `fs.glob` must be called anyway.
inner_path = main_hop.split("://")[1]
globbed_paths = fs.glob(inner_path)
- return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]
+ protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
+ return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]
def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None):
@@ -558,8 +559,9 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None):
inner_path = main_hop.split("://")[1]
if inner_path.strip("/") and not fs.isdir(inner_path):
return []
+ protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1]
for dirpath, dirnames, filenames in fs.walk(inner_path):
- yield "::".join([f"{fs.protocol}://{dirpath}"] + rest_hops), dirnames, filenames
+ yield "::".join([f"{protocol}://{dirpath}"] + rest_hops), dirnames, filenames
class xPath(type(Path())):
diff --git a/src/datasets/info.py b/src/datasets/info.py
index 565f893e5be..2b5acfdde92 100644
--- a/src/datasets/info.py
+++ b/src/datasets/info.py
@@ -32,11 +32,15 @@
import dataclasses
import json
import os
+import posixpath
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
+from fsspec.implementations.local import LocalFileSystem
+
from . import config
from .features import Features, Value
+from .filesystems import is_remote_filesystem
from .splits import SplitDict
from .tasks import TaskTemplate, task_template_from_dict
from .utils import Version
@@ -176,15 +180,16 @@ def __post_init__(self):
template.align_with_features(self.features) for template in (self.task_templates)
]
- def _license_path(self, dataset_info_dir):
- return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)
-
- def write_to_directory(self, dataset_info_dir, pretty_print=False):
+ def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None):
"""Write `DatasetInfo` and license (if present) as JSON files to `dataset_info_dir`.
Args:
dataset_info_dir (str): Destination directory.
pretty_print (bool, default ``False``): If True, the JSON will be pretty-printed with the indent level of 4.
+ fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``):
+ Instance of the remote filesystem used to download the files from.
+
+
Example:
@@ -194,10 +199,14 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False):
>>> ds.info.write_to_directory("/path/to/directory/")
```
"""
- with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f:
+ fs = fs or LocalFileSystem()
+ is_local = not is_remote_filesystem(fs)
+ path_join = os.path.join if is_local else posixpath.join
+
+ with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f:
self._dump_info(f, pretty_print=pretty_print)
if self.license:
- with open(os.path.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f:
+ with fs.open(path_join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f:
self._dump_license(f)
def _dump_info(self, file, pretty_print=False):
@@ -239,7 +248,7 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]):
)
@classmethod
- def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
+ def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo":
"""Create DatasetInfo from the JSON file in `dataset_info_dir`.
This function updates all the dynamically generated fields (num_examples,
@@ -250,6 +259,10 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
Args:
dataset_info_dir (`str`): The directory containing the metadata file. This
should be the root directory of a specific dataset version.
+ fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``):
+ Instance of the remote filesystem used to download the files from.
+
+
Example:
@@ -258,11 +271,15 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
>>> ds_info = DatasetInfo.from_directory("/path/to/directory/")
```
"""
+ fs = fs or LocalFileSystem()
logger.info(f"Loading Dataset info from {dataset_info_dir}")
if not dataset_info_dir:
raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.")
- with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), encoding="utf-8") as f:
+ is_local = not is_remote_filesystem(fs)
+ path_join = os.path.join if is_local else posixpath.join
+
+ with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
dataset_info_dict = json.load(f)
return cls.from_dict(dataset_info_dict)
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index 07436d016f3..f56b765a7c2 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -17,7 +17,6 @@
"""
-import contextlib
import copy
import functools
import itertools
@@ -189,7 +188,7 @@ def _asdict_inner(obj):
return _asdict_inner(obj)
-@contextlib.contextmanager
+@contextmanager
def temporary_assignment(obj, attr, value):
"""Temporarily assign obj.attr to value."""
original = getattr(obj, attr, None)
@@ -601,7 +600,7 @@ def dump(obj, file):
return
-@contextlib.contextmanager
+@contextmanager
def _no_cache_fields(obj):
try:
if (
diff --git a/tests/conftest.py b/tests/conftest.py
index 04f37225581..c1a1af48e57 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -4,7 +4,7 @@
# Import fixture modules as plugins
-pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.s3"]
+pytest_plugins = ["tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.s3", "tests.fixtures.fsspec"]
def pytest_collection_modifyitems(config, items):
diff --git a/tests/fixtures/fsspec.py b/tests/fixtures/fsspec.py
new file mode 100644
index 00000000000..7a301116ea8
--- /dev/null
+++ b/tests/fixtures/fsspec.py
@@ -0,0 +1,84 @@
+import posixpath
+from pathlib import Path
+
+import fsspec
+import pytest
+from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem, stringify_path
+
+
+class MockFileSystem(AbstractFileSystem):
+ protocol = "mock"
+
+ def __init__(self, *args, local_root_dir, **kwargs):
+ super().__init__()
+ self._fs = LocalFileSystem(*args, **kwargs)
+ self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/"
+
+ def mkdir(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.mkdir(path, *args, **kwargs)
+
+ def makedirs(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.makedirs(path, *args, **kwargs)
+
+ def rmdir(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rmdir(path)
+
+ def ls(self, path, detail=True, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ out = self._fs.ls(path, detail=detail, *args, **kwargs)
+ if detail:
+ return [{**info, "name": info["name"][len(self.local_root_dir) :]} for info in out]
+ else:
+ return [name[len(self.local_root_dir) :] for name in out]
+
+ def info(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ out = dict(self._fs.info(path, *args, **kwargs))
+ out["name"] = out["name"][len(self.local_root_dir) :]
+ return out
+
+ def cp_file(self, path1, path2, *args, **kwargs):
+ path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
+ path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))
+ return self._fs.cp_file(path1, path2, *args, **kwargs)
+
+ def rm_file(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rm_file(path, *args, **kwargs)
+
+ def rm(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.rm(path, *args, **kwargs)
+
+ def _open(self, path, *args, **kwargs):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs._open(path, *args, **kwargs)
+
+ def created(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.created(path)
+
+ def modified(self, path):
+ path = posixpath.join(self.local_root_dir, self._strip_protocol(path))
+ return self._fs.modified(path)
+
+ @classmethod
+ def _strip_protocol(cls, path):
+ path = stringify_path(path)
+ if path.startswith("mock://"):
+ path = path[7:]
+ return path
+
+
+@pytest.fixture
+def mock_fsspec(monkeypatch):
+ monkeypatch.setitem(fsspec.registry.target, "mock", MockFileSystem)
+
+
+@pytest.fixture
+def mockfs(tmp_path_factory, mock_fsspec):
+ local_fs_dir = tmp_path_factory.mktemp("mockfs")
+ return MockFileSystem(local_root_dir=local_fs_dir)
diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py
index b5542755db5..794b03be25b 100644
--- a/tests/test_arrow_writer.py
+++ b/tests/test_arrow_writer.py
@@ -6,9 +6,10 @@
import numpy as np
import pyarrow as pa
+import pyarrow.parquet as pq
import pytest
-from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence
+from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence
from datasets.features import Array2D, ClassLabel, Features, Image, Value
from datasets.features.features import Array2DExtensionType, cast_to_python_objects
from datasets.keyhash import DuplicatedKeysError, InvalidKeyError
@@ -299,3 +300,29 @@ def test_arrow_writer_closes_stream(raise_exception, tmp_path):
pass
finally:
assert writer.stream.closed
+
+
+def test_arrow_writer_with_filesystem(mockfs):
+ path = "mock://dataset-train.arrow"
+ with ArrowWriter(path=path, storage_options=mockfs.storage_options) as writer:
+ assert isinstance(writer._fs, type(mockfs))
+ assert writer._fs.storage_options == mockfs.storage_options
+ writer.write({"col_1": "foo", "col_2": 1})
+ writer.write({"col_1": "bar", "col_2": 2})
+ num_examples, num_bytes = writer.finalize()
+ assert num_examples == 2
+ assert num_bytes > 0
+ assert mockfs.exists(path)
+
+
+def test_parquet_writer_write():
+ output = pa.BufferOutputStream()
+ with ParquetWriter(stream=output) as writer:
+ writer.write({"col_1": "foo", "col_2": 1})
+ writer.write({"col_1": "bar", "col_2": 2})
+ num_examples, num_bytes = writer.finalize()
+ assert num_examples == 2
+ assert num_bytes > 0
+ stream = pa.BufferReader(output.getvalue())
+ pa_table: pa.Table = pq.read_table(stream)
+ assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]}
diff --git a/tests/test_builder.py b/tests/test_builder.py
index 01e39e6cabd..04ea1c87ccc 100644
--- a/tests/test_builder.py
+++ b/tests/test_builder.py
@@ -8,12 +8,14 @@
from unittest.mock import patch
import numpy as np
+import pyarrow as pa
+import pyarrow.parquet as pq
import pytest
from multiprocess.pool import Pool
from datasets.arrow_dataset import Dataset
from datasets.arrow_writer import ArrowWriter
-from datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
+from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from datasets.dataset_dict import DatasetDict, IterableDatasetDict
from datasets.download.download_manager import DownloadMode
from datasets.features import Features, Value
@@ -21,8 +23,9 @@
from datasets.iterable_dataset import IterableDataset
from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
from datasets.streaming import xjoin
+from datasets.utils.file_utils import is_local_path
-from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss
+from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss
class DummyBuilder(DatasetBuilder):
@@ -34,7 +37,7 @@ def _split_generators(self, dl_manager):
def _prepare_split(self, split_generator, **kwargs):
fname = f"{self.name}-{split_generator.name}.arrow"
- with ArrowWriter(features=self.info.features, path=os.path.join(self._cache_dir, fname)) as writer:
+ with ArrowWriter(features=self.info.features, path=os.path.join(self._output_dir, fname)) as writer:
writer.write_batch({"text": ["foo"] * 100})
num_examples, num_bytes = writer.finalize()
split_generator.split_info.num_examples = num_examples
@@ -57,6 +60,35 @@ def _generate_examples(self):
yield i, {"text": "foo"}
+class DummyArrowBasedBuilder(ArrowBasedBuilder):
+ def _info(self):
+ return DatasetInfo(features=Features({"text": Value("string")}))
+
+ def _split_generators(self, dl_manager):
+ return [SplitGenerator(name=Split.TRAIN)]
+
+ def _generate_tables(self):
+ for i in range(10):
+ yield i, pa.table({"text": ["foo"] * 10})
+
+
+class DummyBeamBasedBuilder(BeamBasedBuilder):
+ def _info(self):
+ return DatasetInfo(features=Features({"text": Value("string")}))
+
+ def _split_generators(self, dl_manager):
+ return [SplitGenerator(name=Split.TRAIN)]
+
+ def _build_pcollection(self, pipeline):
+ import apache_beam as beam
+
+ def _process(item):
+ for i in range(10):
+ yield f"{i}_{item}", {"text": "foo"}
+
+ return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process)
+
+
class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder):
def _info(self):
return DatasetInfo(features=Features({"id": Value("int8")}))
@@ -690,6 +722,41 @@ def test_cache_dir_for_data_dir(self):
self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
+def test_arrow_based_download_and_prepare(tmp_path):
+ builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
+ builder.download_and_prepare()
+ assert os.path.exists(
+ os.path.join(
+ tmp_path,
+ builder.name,
+ "default",
+ "0.0.0",
+ f"{builder.name}-train.arrow",
+ )
+ )
+ assert builder.info.features, Features({"text": Value("string")})
+ assert builder.info.splits["train"].num_examples, 100
+ assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
+
+
+@require_beam
+def test_beam_based_download_and_prepare(tmp_path):
+ builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
+ builder.download_and_prepare()
+ assert os.path.exists(
+ os.path.join(
+ tmp_path,
+ builder.name,
+ "default",
+ "0.0.0",
+ f"{builder.name}-train.arrow",
+ )
+ )
+ assert builder.info.features, Features({"text": Value("string")})
+ assert builder.info.splits["train"].num_examples, 100
+ assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json"))
+
+
@pytest.mark.parametrize(
"split, expected_dataset_class, expected_dataset_length",
[
@@ -846,3 +913,58 @@ def test_builder_config_version(builder_class, kwargs, tmp_path):
cache_dir = str(tmp_path)
builder = builder_class(cache_dir=cache_dir, **kwargs)
assert builder.config.version == "2.0.0"
+
+
+def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs):
+ builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
+ builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
+ assert builder._output_dir.startswith("mock://my_dataset")
+ assert is_local_path(builder._cache_downloaded_dir)
+ assert isinstance(builder._fs, type(mockfs))
+ assert builder._fs.storage_options == mockfs.storage_options
+ assert mockfs.exists("my_dataset/dataset_info.json")
+ assert mockfs.exists(f"my_dataset/{builder.name}-train.arrow")
+ assert not mockfs.exists("my_dataset.incomplete")
+
+
+def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, caplog):
+ builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
+ mockfs.makedirs("my_dataset")
+ DatasetInfo().write_to_directory("my_dataset", fs=mockfs)
+ mockfs.touch(f"my_dataset/{builder.name}-train.arrow")
+ caplog.clear()
+ builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options)
+ assert "Found cached dataset" in caplog.text
+
+
+def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
+ builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path)
+ builder.download_and_prepare(file_format="parquet")
+ assert builder.info.splits["train"].num_examples, 100
+ parquet_path = os.path.join(
+ tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+ )
+ assert os.path.exists(parquet_path)
+ assert pq.ParquetFile(parquet_path) is not None
+
+
+def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
+ builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
+ builder.download_and_prepare(file_format="parquet")
+ assert builder.info.splits["train"].num_examples, 100
+ parquet_path = os.path.join(
+ tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+ )
+ assert os.path.exists(parquet_path)
+ assert pq.ParquetFile(parquet_path) is not None
+
+
+def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
+ builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
+ builder.download_and_prepare(file_format="parquet")
+ assert builder.info.splits["train"].num_examples, 100
+ parquet_path = os.path.join(
+ tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
+ )
+ assert os.path.exists(parquet_path)
+ assert pq.ParquetFile(parquet_path) is not None
diff --git a/tests/test_load.py b/tests/test_load.py
index c47ea967b00..4436f8bd88c 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -908,7 +908,7 @@ def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir
os.rename(cache_dir1, cache_dir2)
caplog.clear()
dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir2)
- assert "Reusing dataset" in caplog.text
+ assert "Found cached dataset" in caplog.text
assert dataset._fingerprint == fingerprint1, "for the caching mechanism to work, fingerprint should stay the same"
dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="test", cache_dir=cache_dir2)
assert dataset._fingerprint != fingerprint1