From 6d89cfbde168848a9bb4107e86988e2325e0a1f3 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 14:03:25 +0200 Subject: [PATCH 01/51] use fsspec for caching --- src/datasets/arrow_writer.py | 15 ++- src/datasets/builder.py | 195 ++++++++++++++++++++--------------- src/datasets/info.py | 23 +++-- src/datasets/load.py | 3 + 4 files changed, 142 insertions(+), 94 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 4f616110ac2..ac8ca10d0f1 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -19,8 +19,10 @@ from dataclasses import asdict from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import fsspec import numpy as np import pyarrow as pa +from fsspec.implementations.local import LocalFileSystem from . import config from .features import Features, Image, Value @@ -283,6 +285,7 @@ def __init__( update_features: bool = False, with_metadata: bool = True, unit: str = "examples", + storage_options: Optional[dict] = None, ): if path is None and stream is None: raise ValueError("At least one of path and stream must be provided.") @@ -305,11 +308,19 @@ def __init__( self._check_duplicates = check_duplicates self._disable_nullable = disable_nullable - self._path = path if stream is None: - self.stream = pa.OSFile(self._path, "wb") + fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options) + self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + self._path = ( + fs_token_paths[2][0] + if isinstance(self._fs, LocalFileSystem) + else self._fs.protocol + "://" + fs_token_paths[2][0] + ) + self.stream = self._fs.open(fs_token_paths[2][0], "wb") self._closable_stream = True else: + self._fs = None + self._path = None self.stream = stream self._closable_stream = False diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 1ca46e78eb7..13844647745 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -29,6 +29,9 @@ from functools import partial from typing import Dict, Mapping, Optional, Tuple, Union +import fsspec +from fsspec.implementations.local import LocalFileSystem + from . import config, utils from .arrow_dataset import Dataset from .arrow_reader import ( @@ -54,7 +57,7 @@ from .splits import Split, SplitDict, SplitGenerator from .streaming import extend_dataset_builder_for_streaming from .utils import logging -from .utils.file_utils import cached_path, is_remote_url +from .utils.file_utils import cached_path from .utils.filelock import FileLock from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( @@ -226,6 +229,7 @@ class DatasetBuilder: ``os.path.join(data_dir, "**")`` as `data_files`. For builders that require manual download, it must be the path to the local directory containing the manually downloaded data. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. name (`str`): Configuration name for the dataset. 
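The hunk that follows routes the builder's cache directory through fsspec.get_fs_token_paths instead of treating cache_dir as a plain local path. A minimal sketch of what that call returns, assuming s3fs is installed; the bucket name and the anon option are placeholders and not taken from this patch:

import fsspec

fs, _, paths = fsspec.get_fs_token_paths("s3://my-bucket/cache", storage_options={"anon": True})
# fs       -> an fsspec AbstractFileSystem instance (here an s3fs filesystem)
# paths[0] -> "my-bucket/cache", i.e. the path with the protocol stripped; the builder
#             re-prefixes it with fs.protocol whenever the filesystem is not local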
@@ -266,6 +270,7 @@ def __init__( repo_id: Optional[str] = None, data_files: Optional[Union[str, list, dict, DataFilesDict]] = None, data_dir: Optional[str] = None, + storage_options: Optional[dict] = None, name="deprecated", **config_kwargs, ): @@ -315,35 +320,36 @@ def __init__( # Prepare data dirs: # cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) - self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE) - self._cache_dir_root = ( - self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root) + + fs_token_paths = fsspec.get_fs_token_paths( + cache_dir or os.path.expanduser(config.HF_DATASETS_CACHE), storage_options=storage_options ) - path_join = posixpath.join if is_remote_url(self._cache_dir_root) else os.path.join + self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + + is_local = isinstance(self._fs, LocalFileSystem) + path_join = os.path.join if is_local else os.path.join + + self._cache_dir_root = fs_token_paths[2][0] if is_local else self._fs.protocol + "://" + fs_token_paths[2][0] + self._cache_dir = self._build_cache_dir() self._cache_downloaded_dir = ( path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) if cache_dir - else str(config.DOWNLOADED_DATASETS_PATH) - ) - self._cache_downloaded_dir = ( - self._cache_downloaded_dir - if is_remote_url(self._cache_downloaded_dir) - else os.path.expanduser(self._cache_downloaded_dir) + else os.path.expanduser(config.DOWNLOADED_DATASETS_PATH) ) - self._cache_dir = self._build_cache_dir() - if not is_remote_url(self._cache_dir_root): + + if is_local: os.makedirs(self._cache_dir_root, exist_ok=True) lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): - if os.path.exists(self._cache_dir): # check if data exist - if len(os.listdir(self._cache_dir)) > 0: - logger.info("Overwrite dataset info from restored data version.") - self.info = DatasetInfo.from_directory(self._cache_dir) - else: # dir exists but no data, remove the empty dir as data aren't available anymore - logger.warning( - f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. " - ) - os.rmdir(self._cache_dir) + with FileLock(lock_path) if is_local else contextlib.nullcontext(): + if self._fs.exists(self._cache_dir): # check if data exist + if len(self._fs.listdir(self._cache_dir)) > 0: + logger.info("Overwrite dataset info from restored data version.") + self.info = DatasetInfo.from_directory(self._cache_dir, fs=self._fs) + else: # dir exists but no data, remove the empty dir as data aren't available anymore + logger.warning( + f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. 
" + ) + self._fs.rmdir(self._cache_dir) # Set download manager self.dl_manager = None @@ -513,7 +519,7 @@ def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) - def _build_cache_dir(self): """Return the data directory for the current version.""" - is_local = not is_remote_url(self._cache_dir_root) + is_local = isinstance(self._fs, LocalFileSystem) path_join = os.path.join if is_local else posixpath.join builder_data_dir = path_join( self._cache_dir_root, self._relative_data_dir(with_version=False, is_local=is_local) @@ -522,13 +528,13 @@ def _build_cache_dir(self): self._cache_dir_root, self._relative_data_dir(with_version=True, is_local=is_local) ) - def _other_versions_on_disk(): + def _other_versions(): """Returns previous versions on disk.""" - if not os.path.exists(builder_data_dir): + if not self._fs.exists(builder_data_dir): return [] version_dirnames = [] - for dir_name in os.listdir(builder_data_dir): + for dir_name in self._fs.listdir(builder_data_dir, detail=False): try: version_dirnames.append((utils.Version(dir_name), dir_name)) except ValueError: # Invalid version (ex: incomplete data dir) @@ -536,18 +542,17 @@ def _other_versions_on_disk(): version_dirnames.sort(reverse=True) return version_dirnames - # Check and warn if other versions exist on disk - if not is_remote_url(builder_data_dir): - version_dirs = _other_versions_on_disk() - if version_dirs: - other_version = version_dirs[0][0] - if other_version != self.config.version: - warn_msg = ( - f"Found a different version {str(other_version)} of dataset {self.name} in " - f"cache_dir {self._cache_dir_root}. Using currently defined version " - f"{str(self.config.version)}." - ) - logger.warning(warn_msg) + # Check and warn if other versions exist + version_dirs = _other_versions() + if version_dirs: + other_version = version_dirs[0][0] + if other_version != self.config.version: + warn_msg = ( + f"Found a different version {str(other_version)} of dataset {self.name} in " + f"cache_dir {self._cache_dir_root}. Using currently defined version " + f"{str(self.config.version)}." + ) + logger.warning(warn_msg) return version_data_dir @@ -604,6 +609,7 @@ def download_and_prepare( download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications base_path = base_path if base_path is not None else self.base_path + is_local = isinstance(self._fs, LocalFileSystem) if dl_manager is None: if download_config is None: @@ -625,25 +631,23 @@ def download_and_prepare( else False, ) - elif isinstance(dl_manager, MockDownloadManager): + elif isinstance(dl_manager, MockDownloadManager) or not is_local: try_from_hf_gcs = False self.dl_manager = dl_manager # Prevent parallel disk operations - is_local = not is_remote_url(self._cache_dir_root) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") # File locking only with local paths; no file locking on GCS or S3 with FileLock(lock_path) if is_local else contextlib.nullcontext(): - if is_local: - data_exists = os.path.exists(self._cache_dir) - if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - logger.warning(f"Reusing dataset {self.name} ({self._cache_dir})") - # We need to update the info in case some splits were added in the meantime - # for example when calling load_dataset from multiple workers. 
- self.info = self._load_info() - self.download_post_processing_resources(dl_manager) - return + data_exists = self._fs.exists(self._cache_dir) + if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: + logger.warning(f"Found cached dataset {self.name} ({self._cache_dir})") + # We need to update the info in case some splits were added in the meantime + # for example when calling load_dataset from multiple workers. + self.info = self._load_info() + self.download_post_processing_resources(dl_manager) + return logger.info(f"Generating dataset {self.name} ({self._cache_dir})") if is_local: # if cache dir is local, check for available space if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root): @@ -654,19 +658,20 @@ def download_and_prepare( @contextlib.contextmanager def incomplete_dir(dirname): """Create temporary dir for dirname and rename on exit.""" - if is_remote_url(dirname): - yield dirname - else: - tmp_dir = dirname + ".incomplete" - os.makedirs(tmp_dir, exist_ok=True) - try: - yield tmp_dir - if os.path.isdir(dirname): - shutil.rmtree(dirname) + tmp_dir = dirname + ".incomplete" + self._fs.makedirs(tmp_dir, exist_ok=True) + try: + yield tmp_dir + if self._fs.isdir(dirname): + self._fs.rm(dirname, recursive=True) + if is_local: + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory os.rename(tmp_dir, dirname) - finally: - if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir) + else: + self._fs.mv(tmp_dir, dirname, recursive=True) + finally: + if self._fs.exists(tmp_dir): + self._fs.rm(tmp_dir, recursive=True) # Print is intentional: we want this to always go to stdout so user has # information needed to cancel download/preparation if needed. @@ -679,8 +684,9 @@ def incomplete_dir(dirname): f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..." ) else: + _dest = self._cache_dir if is_local else self._fs.protocol + "://" + self._cache_dir print( - f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {self._cache_dir}..." + f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..." 
) self._check_manual_download(dl_manager) @@ -817,6 +823,8 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs def download_post_processing_resources(self, dl_manager): for split in self.info.splits: for resource_name, resource_file_name in self._post_processing_resources(split).items(): + if not isinstance(self._fs, LocalFileSystem): + raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") resource_path = os.path.join(self._cache_dir, resource_file_name) @@ -829,16 +837,20 @@ def download_post_processing_resources(self, dl_manager): shutil.move(downloaded_resource_path, resource_path) def _load_info(self) -> DatasetInfo: - return DatasetInfo.from_directory(self._cache_dir) + return DatasetInfo.from_directory(self._cache_dir, fs=self._fs) def _save_info(self): - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): - self.info.write_to_directory(self._cache_dir) + is_local = isinstance(self._fs, LocalFileSystem) + if is_local: + lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + with FileLock(lock_path) if is_local else contextlib.nullcontext(): + self.info.write_to_directory(self._cache_dir, fs=self._fs) def _save_infos(self): - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): + is_local = isinstance(self._fs, LocalFileSystem) + if is_local: + lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + with FileLock(lock_path) if is_local else contextlib.nullcontext(): DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir()) def _make_split_generators_kwargs(self, prepare_split_kwargs): @@ -876,6 +888,9 @@ def as_dataset( }) ``` """ + is_local = isinstance(self._fs, LocalFileSystem) + if not is_local: + raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.") if not os.path.exists(self._cache_dir): raise AssertionError( f"Dataset {self.name}: could not find data in {self._cache_dir_root}. Please make sure to call " @@ -1015,6 +1030,12 @@ def as_streaming_dataset( if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)): raise ValueError(f"Builder {self.name} is not streamable.") + is_local = isinstance(self._fs, LocalFileSystem) + if not is_local: + raise NotImplementedError( + f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet." 
+ ) + dl_manager = StreamingDownloadManager( base_path=base_path or self.base_path, download_config=DownloadConfig(use_auth_token=self.use_auth_token), @@ -1189,13 +1210,16 @@ def _generate_examples(self, **kwargs): raise NotImplementedError() def _prepare_split(self, split_generator, check_duplicate_keys): + is_local = isinstance(self._fs, LocalFileSystem) + path_join = os.path.join if is_local else posixpath.join + if self.info.splits is not None: split_info = self.info.splits[split_generator.name] else: split_info = split_generator.split_info fname = f"{self.name}-{split_generator.name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) @@ -1205,6 +1229,7 @@ def _prepare_split(self, split_generator, check_duplicate_keys): writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, + storage_options=self._fs.storage_options, ) as writer: try: for key, record in logging.tqdm( @@ -1266,11 +1291,14 @@ def _generate_tables(self, **kwargs): raise NotImplementedError() def _prepare_split(self, split_generator): + is_local = isinstance(self._fs, LocalFileSystem) + path_join = os.path.join if is_local else posixpath.join + fname = f"{self.name}-{split_generator.name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) - with ArrowWriter(features=self.info.features, path=fpath) as writer: + with ArrowWriter(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: for key, table in logging.tqdm( generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) ): @@ -1404,25 +1432,24 @@ def _download_and_prepare(self, dl_manager, verify_infos): split_info.num_bytes = num_bytes def _save_info(self): - if os.path.exists(self._cache_dir): - super()._save_info() - else: - import apache_beam as beam + import apache_beam as beam - fs = beam.io.filesystems.FileSystems - with fs.create(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f: - self.info._dump_info(f) - if self.info.license: - with fs.create(os.path.join(self._cache_dir, config.LICENSE_FILENAME)) as f: - self.info._dump_license(f) + fs = beam.io.filesystems.FileSystems + path_join = os.path.join if isinstance(self._fs, LocalFileSystem) else posixpath.join + with fs.create(path_join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f: + self.info._dump_info(f) + if self.info.license: + with fs.create(path_join(self._cache_dir, config.LICENSE_FILENAME)) as f: + self.info._dump_license(f) def _prepare_split(self, split_generator, pipeline): import apache_beam as beam - # To write examples to disk: + # To write examples in filesystem: split_name = split_generator.split_info.name fname = f"{self.name}-{split_name}.arrow" - fpath = os.path.join(self._cache_dir, fname) + path_join = os.path.join if isinstance(self._fs, LocalFileSystem) else posixpath.join + fpath = path_join(self._cache_dir, fname) beam_writer = BeamWriter( features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._cache_dir ) diff --git a/src/datasets/info.py b/src/datasets/info.py index a3a6fa3a8c4..721fd18285f 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -35,6 +35,8 @@ from dataclasses import asdict, dataclass, field from typing import Dict, List, 
Optional, Union +from fsspec.implementations.local import LocalFileSystem + from . import config from .features import Features, Value from .splits import SplitDict @@ -176,10 +178,7 @@ def __post_init__(self): template.align_with_features(self.features) for template in (self.task_templates) ] - def _license_path(self, dataset_info_dir): - return os.path.join(dataset_info_dir, config.LICENSE_FILENAME) - - def write_to_directory(self, dataset_info_dir, pretty_print=False): + def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None): """Write `DatasetInfo` and license (if present) as JSON files to `dataset_info_dir`. Args: @@ -194,10 +193,14 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False): >>> ds.info.write_to_directory("/path/to/directory/") ``` """ - with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: + fs = fs or LocalFileSystem() + is_local = isinstance(fs, LocalFileSystem) + path_join = os.path.join if is_local else os.path.join + + with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: self._dump_info(f, pretty_print=pretty_print) if self.license: - with open(os.path.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f: + with fs.open(path_join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f: self._dump_license(f) def _dump_info(self, file, pretty_print=False): @@ -239,7 +242,7 @@ def from_merge(cls, dataset_infos: List["DatasetInfo"]): ) @classmethod - def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo": + def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo": """Create DatasetInfo from the JSON file in `dataset_info_dir`. This function updates all the dynamically generated fields (num_examples, @@ -258,11 +261,15 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo": >>> ds_info = DatasetInfo.from_directory("/path/to/directory/") ``` """ + fs = fs or LocalFileSystem() logger.info(f"Loading Dataset info from {dataset_info_dir}") if not dataset_info_dir: raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.") - with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), encoding="utf-8") as f: + is_local = isinstance(fs, LocalFileSystem) + path_join = os.path.join if is_local else os.path.join + + with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f: dataset_info_dict = json.load(f) return cls.from_dict(dataset_info_dict) diff --git a/src/datasets/load.py b/src/datasets/load.py index 567587583fd..7f785cf03e7 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1417,6 +1417,7 @@ def load_dataset_builder( download_mode: Optional[DownloadMode] = None, revision: Optional[Union[str, Version]] = None, use_auth_token: Optional[Union[bool, str]] = None, + storage_options: Optional[dict] = None, **config_kwargs, ) -> DatasetBuilder: """Load a dataset builder from the Hugging Face Hub, or a local dataset. A dataset builder can be used to inspect general information that is required to build a dataset (cache directory, config, dataset info, etc.) @@ -1470,6 +1471,7 @@ def load_dataset_builder( You can specify a different version that the default "main" by using a commit sha or a git tag of the dataset repository. use_auth_token (``str`` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. If True, will get token from `"~/.huggingface"`. 
+ storage_options (:obj:`dict`, optional): Key/value pairs to be passed on to the caching file-system backend, if any. **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`. @@ -1531,6 +1533,7 @@ def load_dataset_builder( hash=hash, features=features, use_auth_token=use_auth_token, + storage_options=storage_options, **builder_kwargs, **config_kwargs, ) From 606a48f0c7876805e3e7f3dd0ce2e1242e20b107 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 15:10:37 +0200 Subject: [PATCH 02/51] add parquet writer --- src/datasets/arrow_writer.py | 65 +++++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index ac8ca10d0f1..e855d81d401 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -22,6 +22,7 @@ import fsspec import numpy as np import pyarrow as pa +import pyarrow.parquet as pq from fsspec.implementations.local import LocalFileSystem from . import config @@ -270,6 +271,7 @@ def __init__( class ArrowWriter: """Shuffles and writes Examples to Arrow files.""" + _WRITER_CLASS = pa.RecordBatchStreamWriter def __init__( self, @@ -379,7 +381,7 @@ def _build_writer(self, inferred_schema: pa.Schema): if self.with_metadata: schema = schema.with_metadata(self._build_metadata(DatasetInfo(features=self._features), self.fingerprint)) self._schema = schema - self.pa_writer = pa.RecordBatchStreamWriter(self.stream, schema) + self.pa_writer = self._WRITER_CLASS(self.stream, schema) @property def schema(self): @@ -562,6 +564,10 @@ def finalize(self, close_stream=True): return self._num_examples, self._num_bytes +class ParquetWriter(ArrowWriter): + _WRITER_CLASS = pq.ParquetWriter + + class BeamWriter: """ Shuffles and writes Examples to Arrow files. @@ -628,35 +634,42 @@ def finalize(self, metrics_query_result: dict): from .utils import beam_utils # Convert to arrow - logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") - shards = [ - metadata.path - for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list - ] - try: # stream conversion - sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] - with beam.io.filesystems.FileSystems.create(self._path) as dest: - parquet_to_arrow(sources, dest) - except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead - if e.errno != errno.EPIPE: # not a broken pipe - raise - logger.warning("Broken Pipe during stream conversion from parquet to arrow. 
Using local convert instead") - local_convert_dir = os.path.join(self._cache_dir, "beam_convert") - os.makedirs(local_convert_dir, exist_ok=True) - local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") - local_shards = [] - for shard in shards: - local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") - local_shards.append(local_parquet_path) - beam_utils.download_remote_to_local(shard, local_parquet_path) - parquet_to_arrow(local_shards, local_arrow_path) - beam_utils.upload_local_to_remote(local_arrow_path, self._path) + if self._path.endswith(".arrow"): + logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") + shards = [ + metadata.path + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list + ] + try: # stream conversion + sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] + with beam.io.filesystems.FileSystems.create(self._path) as dest: + parquet_to_arrow(sources, dest) + except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead + if e.errno != errno.EPIPE: # not a broken pipe + raise + logger.warning("Broken Pipe during stream conversion from parquet to arrow. Using local convert instead") + local_convert_dir = os.path.join(self._cache_dir, "beam_convert") + os.makedirs(local_convert_dir, exist_ok=True) + local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") + local_shards = [] + for shard in shards: + local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") + local_shards.append(local_parquet_path) + beam_utils.download_remote_to_local(shard, local_parquet_path) + parquet_to_arrow(local_shards, local_arrow_path) + beam_utils.upload_local_to_remote(local_arrow_path, self._path) + output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] + num_bytes = output_file_metadata.size_in_bytes + else: + num_bytes = sum([ + metadata.size_in_bytes + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list + ]) # Save metrics counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]} self._num_examples = counters_dict["num_examples"] - output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] - self._num_bytes = output_file_metadata.size_in_bytes + self._num_bytes = num_bytes return self._num_examples, self._num_bytes From cdf8dcd962ddbaffe62749f0cb3ac0687ca5427e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 15:10:45 +0200 Subject: [PATCH 03/51] add file_format argument --- src/datasets/builder.py | 53 ++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 13844647745..71cb9a7fafb 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -41,7 +41,7 @@ MissingFilesOnHfGcsError, ReadInstruction, ) -from .arrow_writer import ArrowWriter, BeamWriter +from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter from .data_files import DataFilesDict, sanitize_patterns from .dataset_dict import DatasetDict, IterableDatasetDict from .download.download_config import DownloadConfig @@ -582,6 +582,7 @@ def download_and_prepare( dl_manager: 
Optional[DownloadManager] = None, base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, + file_format: Optional[str] = None, **download_and_prepare_kwargs, ): """Downloads and prepares dataset for reading. @@ -596,6 +597,8 @@ def download_and_prepare( If not specified, the value of the `base_path` attribute (`self.base_path`) will be used instead. use_auth_token (:obj:`Union[str, bool]`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. If True, will get token from ~/.huggingface. + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. Example: @@ -611,6 +614,9 @@ def download_and_prepare( base_path = base_path if base_path is not None else self.base_path is_local = isinstance(self._fs, LocalFileSystem) + if file_format is not None and file_format not in ["arrow", "parquet"]: + raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") + if dl_manager is None: if download_config is None: download_config = DownloadConfig( @@ -708,7 +714,7 @@ def incomplete_dir(dirname): logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: self._download_and_prepare( - dl_manager=dl_manager, verify_infos=verify_infos, **download_and_prepare_kwargs + dl_manager=dl_manager, verify_infos=verify_infos, file_format=file_format, **download_and_prepare_kwargs ) # Sync info self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) @@ -758,7 +764,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") logger.info("Dataset downloaded from Hf google storage.") - def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs): + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **prepare_split_kwargs): """Downloads and prepares dataset for reading. This is the internal implementation to overwrite called when user calls @@ -766,9 +772,10 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs the pre-processed datasets files. Args: - dl_manager: (DownloadManager) `DownloadManager` used to download and cache - data. - verify_infos: bool, if False, do not perform checksums and size tests. + dl_manager: (:obj:`DownloadManager`) `DownloadManager` used to download and cache data. + verify_infos (:obj:`bool`): if False, do not perform checksums and size tests. + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. prepare_split_kwargs: Additional options. """ # Generating data for all splits @@ -796,7 +803,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs try: # Prepare split will record examples associated to the split - self._prepare_split(split_generator, **prepare_split_kwargs) + self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) except OSError as e: raise OSError( "Cannot find data file. 
" @@ -1134,11 +1141,13 @@ def _split_generators(self, dl_manager: DownloadManager): raise NotImplementedError() @abc.abstractmethod - def _prepare_split(self, split_generator: SplitGenerator, **kwargs): + def _prepare_split(self, split_generator: SplitGenerator, file_format: Optional[str] = None, **kwargs): """Generate the examples and record them on disk. Args: split_generator: `SplitGenerator`, Split generator to process + file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. + Supported formats: "arrow", "parquet". Default to "arrow" format. **kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1209,7 +1218,7 @@ def _generate_examples(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, check_duplicate_keys): + def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): is_local = isinstance(self._fs, LocalFileSystem) path_join = os.path.join if is_local else posixpath.join @@ -1218,12 +1227,14 @@ def _prepare_split(self, split_generator, check_duplicate_keys): else: split_info = split_generator.split_info - fname = f"{self.name}-{split_generator.name}.arrow" + file_format = file_format or "arroq" + fname = f"{self.name}-{split_generator.name}.{file_format}" fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) - with ArrowWriter( + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + with writer_class( features=self.info.features, path=fpath, writer_batch_size=self._writer_batch_size, @@ -1248,8 +1259,8 @@ def _prepare_split(self, split_generator, check_duplicate_keys): split_generator.split_info.num_examples = num_examples split_generator.split_info.num_bytes = num_bytes - def _download_and_prepare(self, dl_manager, verify_infos): - super()._download_and_prepare(dl_manager, verify_infos, check_duplicate_keys=verify_infos) + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + super()._download_and_prepare(dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs) @@ -1290,15 +1301,17 @@ def _generate_tables(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator): + def _prepare_split(self, split_generator, file_format=None): is_local = isinstance(self._fs, LocalFileSystem) path_join = os.path.join if is_local else posixpath.join - fname = f"{self.name}-{split_generator.name}.arrow" + file_format = file_format or "arrow" + fname = f"{self.name}-{split_generator.name}.{file_format}" fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) - with ArrowWriter(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: for key, table in logging.tqdm( generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) ): @@ -1379,7 +1392,7 @@ def _build_pcollection(pipeline, extracted_dir=None): """ raise NotImplementedError() - def _download_and_prepare(self, 
dl_manager, verify_infos): + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): # Create the Beam pipeline and forward it to _prepare_split import apache_beam as beam @@ -1417,6 +1430,7 @@ def _download_and_prepare(self, dl_manager, verify_infos): dl_manager, verify_infos=False, pipeline=pipeline, + file_format=file_format, ) # TODO handle verify_infos in beam datasets # Run pipeline pipeline_results = pipeline.run() @@ -1442,12 +1456,13 @@ def _save_info(self): with fs.create(path_join(self._cache_dir, config.LICENSE_FILENAME)) as f: self.info._dump_license(f) - def _prepare_split(self, split_generator, pipeline): + def _prepare_split(self, split_generator, pipeline, file_format=None): import apache_beam as beam # To write examples in filesystem: split_name = split_generator.split_info.name - fname = f"{self.name}-{split_name}.arrow" + file_format = file_format or "arrow" + fname = f"{self.name}-{split_name}.{file_format}" path_join = os.path.join if isinstance(self._fs, LocalFileSystem) else posixpath.join fpath = path_join(self._cache_dir, fname) beam_writer = BeamWriter( From ad9127067bc014ad7175b0660f8f87a819546d0b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 15:11:00 +0200 Subject: [PATCH 04/51] style --- src/datasets/arrow_writer.py | 21 +++++++++++++++------ src/datasets/builder.py | 9 +++++++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index e855d81d401..cd71fc8ba93 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -271,6 +271,7 @@ def __init__( class ArrowWriter: """Shuffles and writes Examples to Arrow files.""" + _WRITER_CLASS = pa.RecordBatchStreamWriter def __init__( @@ -638,7 +639,9 @@ def finalize(self, metrics_query_result: dict): logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") shards = [ metadata.path - for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ + 0 + ].metadata_list ] try: # stream conversion sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] @@ -647,7 +650,9 @@ def finalize(self, metrics_query_result: dict): except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead if e.errno != errno.EPIPE: # not a broken pipe raise - logger.warning("Broken Pipe during stream conversion from parquet to arrow. Using local convert instead") + logger.warning( + "Broken Pipe during stream conversion from parquet to arrow. 
Using local convert instead" + ) local_convert_dir = os.path.join(self._cache_dir, "beam_convert") os.makedirs(local_convert_dir, exist_ok=True) local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") @@ -661,10 +666,14 @@ def finalize(self, metrics_query_result: dict): output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] num_bytes = output_file_metadata.size_in_bytes else: - num_bytes = sum([ - metadata.size_in_bytes - for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list - ]) + num_bytes = sum( + [ + metadata.size_in_bytes + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ + 0 + ].metadata_list + ] + ) # Save metrics counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]} diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 71cb9a7fafb..765643b920d 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -714,7 +714,10 @@ def incomplete_dir(dirname): logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: self._download_and_prepare( - dl_manager=dl_manager, verify_infos=verify_infos, file_format=file_format, **download_and_prepare_kwargs + dl_manager=dl_manager, + verify_infos=verify_infos, + file_format=file_format, + **download_and_prepare_kwargs, ) # Sync info self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) @@ -1260,7 +1263,9 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None split_generator.split_info.num_bytes = num_bytes def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): - super()._download_and_prepare(dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos) + super()._download_and_prepare( + dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos + ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs) From 742d2a937c2fa11e906b052b98301cd48d0e3299 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 17:04:48 +0200 Subject: [PATCH 05/51] use "gs" instead of "gcs" for apache beam + use is_remote_filesystem --- src/datasets/arrow_writer.py | 7 ++-- src/datasets/builder.py | 41 +++++++++++-------- .../download/streaming_download_manager.py | 6 ++- src/datasets/info.py | 5 ++- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index cd71fc8ba93..ad4c5097d8c 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -23,7 +23,6 @@ import numpy as np import pyarrow as pa import pyarrow.parquet as pq -from fsspec.implementations.local import LocalFileSystem from . 
import config from .features import Features, Image, Value @@ -37,6 +36,7 @@ numpy_to_pyarrow_listarray, to_pyarrow_listarray, ) +from .filesystems import is_remote_filesystem from .info import DatasetInfo from .keyhash import DuplicatedKeysError, KeyHasher from .table import array_cast, cast_array_to_feature, table_cast @@ -314,10 +314,9 @@ def __init__( if stream is None: fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options) self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] self._path = ( - fs_token_paths[2][0] - if isinstance(self._fs, LocalFileSystem) - else self._fs.protocol + "://" + fs_token_paths[2][0] + fs_token_paths[2][0] if not is_remote_filesystem(self._fs) else protocol + "://" + fs_token_paths[2][0] ) self.stream = self._fs.open(fs_token_paths[2][0], "wb") self._closable_stream = True diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 765643b920d..e120faf75ca 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -30,7 +30,6 @@ from typing import Dict, Mapping, Optional, Tuple, Union import fsspec -from fsspec.implementations.local import LocalFileSystem from . import config, utils from .arrow_dataset import Dataset @@ -49,6 +48,7 @@ from .download.mock_download_manager import MockDownloadManager from .download.streaming_download_manager import StreamingDownloadManager from .features import Features +from .filesystems import is_remote_filesystem from .fingerprint import Hasher from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper @@ -326,14 +326,15 @@ def __init__( ) self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else os.path.join - self._cache_dir_root = fs_token_paths[2][0] if is_local else self._fs.protocol + "://" + fs_token_paths[2][0] + protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] + self._cache_dir_root = fs_token_paths[2][0] if is_local else protocol + "://" + fs_token_paths[2][0] self._cache_dir = self._build_cache_dir() self._cache_downloaded_dir = ( path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) - if cache_dir + if cache_dir and is_local else os.path.expanduser(config.DOWNLOADED_DATASETS_PATH) ) @@ -519,7 +520,7 @@ def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) - def _build_cache_dir(self): """Return the data directory for the current version.""" - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join builder_data_dir = path_join( self._cache_dir_root, self._relative_data_dir(with_version=False, is_local=is_local) @@ -612,7 +613,7 @@ def download_and_prepare( download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications base_path = base_path if base_path is not None else self.base_path - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. 
Expected 'arrow' or 'parquet'") @@ -690,7 +691,8 @@ def incomplete_dir(dirname): f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..." ) else: - _dest = self._cache_dir if is_local else self._fs.protocol + "://" + self._cache_dir + _protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] + _dest = self._cache_dir if is_local else _protocol + "://" + self._cache_dir print( f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..." ) @@ -808,6 +810,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr # Prepare split will record examples associated to the split self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) except OSError as e: + raise raise OSError( "Cannot find data file. " + (self.manual_download_instructions or "") @@ -833,7 +836,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr def download_post_processing_resources(self, dl_manager): for split in self.info.splits: for resource_name, resource_file_name in self._post_processing_resources(split).items(): - if not isinstance(self._fs, LocalFileSystem): + if not not is_remote_filesystem(self._fs): raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") @@ -850,14 +853,14 @@ def _load_info(self) -> DatasetInfo: return DatasetInfo.from_directory(self._cache_dir, fs=self._fs) def _save_info(self): - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") with FileLock(lock_path) if is_local else contextlib.nullcontext(): self.info.write_to_directory(self._cache_dir, fs=self._fs) def _save_infos(self): - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") with FileLock(lock_path) if is_local else contextlib.nullcontext(): @@ -898,7 +901,7 @@ def as_dataset( }) ``` """ - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) if not is_local: raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.") if not os.path.exists(self._cache_dir): @@ -1040,7 +1043,7 @@ def as_streaming_dataset( if not isinstance(self, (GeneratorBasedBuilder, ArrowBasedBuilder)): raise ValueError(f"Builder {self.name} is not streamable.") - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) if not is_local: raise NotImplementedError( f"Loading a streaming dataset cached in a {type(self._fs).__name__} is not supported yet." 
@@ -1222,7 +1225,7 @@ def _generate_examples(self, **kwargs): raise NotImplementedError() def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join if self.info.splits is not None: @@ -1232,7 +1235,8 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None file_format = file_format or "arroq" fname = f"{self.name}-{split_generator.name}.{file_format}" - fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) + protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] + fpath = protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) @@ -1307,12 +1311,13 @@ def _generate_tables(self, **kwargs): raise NotImplementedError() def _prepare_split(self, split_generator, file_format=None): - is_local = isinstance(self._fs, LocalFileSystem) + is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join file_format = file_format or "arrow" fname = f"{self.name}-{split_generator.name}.{file_format}" - fpath = self._fs.protocol + "://" + path_join(self._cache_dir, fname) + protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] + fpath = protocol + "://" + path_join(self._cache_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -1454,7 +1459,7 @@ def _save_info(self): import apache_beam as beam fs = beam.io.filesystems.FileSystems - path_join = os.path.join if isinstance(self._fs, LocalFileSystem) else posixpath.join + path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join with fs.create(path_join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f: self.info._dump_info(f) if self.info.license: @@ -1468,7 +1473,7 @@ def _prepare_split(self, split_generator, pipeline, file_format=None): split_name = split_generator.split_info.name file_format = file_format or "arrow" fname = f"{self.name}-{split_name}.{file_format}" - path_join = os.path.join if isinstance(self._fs, LocalFileSystem) else posixpath.join + path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join fpath = path_join(self._cache_dir, fname) beam_writer = BeamWriter( features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._cache_dir diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py index 697f5122b44..c61531c4915 100644 --- a/src/datasets/download/streaming_download_manager.py +++ b/src/datasets/download/streaming_download_manager.py @@ -526,7 +526,8 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool] # - If there is "**" in the pattern, `fs.glob` must be called anyway. 
inner_path = main_hop.split("://")[1] globbed_paths = fs.glob(inner_path) - return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths] + protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1] + return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths] def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None): @@ -558,8 +559,9 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None): inner_path = main_hop.split("://")[1] if inner_path.strip("/") and not fs.isdir(inner_path): return [] + protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[-1] for dirpath, dirnames, filenames in fs.walk(inner_path): - yield "::".join([f"{fs.protocol}://{dirpath}"] + rest_hops), dirnames, filenames + yield "::".join([f"{protocol}://{dirpath}"] + rest_hops), dirnames, filenames class xPath(type(Path())): diff --git a/src/datasets/info.py b/src/datasets/info.py index 721fd18285f..e4750eecc47 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -39,6 +39,7 @@ from . import config from .features import Features, Value +from .filesystems import is_remote_filesystem from .splits import SplitDict from .tasks import TaskTemplate, task_template_from_dict from .utils import Version @@ -194,7 +195,7 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None): ``` """ fs = fs or LocalFileSystem() - is_local = isinstance(fs, LocalFileSystem) + is_local = not is_remote_filesystem(fs) path_join = os.path.join if is_local else os.path.join with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: @@ -266,7 +267,7 @@ def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo": if not dataset_info_dir: raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.") - is_local = isinstance(fs, LocalFileSystem) + is_local = not is_remote_filesystem(fs) path_join = os.path.join if is_local else os.path.join with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f: From aed8ce69e8d74d6efcfa76a95fadaee17a8ff2f3 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 17:12:28 +0200 Subject: [PATCH 06/51] typo --- src/datasets/builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e120faf75ca..4e2095425fb 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1017,8 +1017,8 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem Returns: `Dataset` """ - - dataset_kwargs = ArrowReader(self._cache_dir, self.info).read( + cache_dir = self._fs._strip_protocol(self._cache_dir) + dataset_kwargs = ArrowReader(cache_dir, self.info).read( name=self.name, instructions=split, split_infos=self.info.splits.values(), @@ -1233,7 +1233,7 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None else: split_info = split_generator.split_info - file_format = file_format or "arroq" + file_format = file_format or "arrow" fname = f"{self.name}-{split_generator.name}.{file_format}" protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] fpath = protocol + "://" + path_join(self._cache_dir, fname) From 93d5660b223decfcd7d9d89266526c11404f4537 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 17:35:44 +0200 Subject: [PATCH 07/51] fix test --- 
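Note: the assertion change below tracks the log-message change in builder.py ("Reusing dataset" -> "Found cached dataset"). Putting the series together so far, a minimal usage sketch of how the new arguments are meant to be combined; the dataset name, bucket URL and s3fs-style credentials are placeholders, not part of this patch:

from datasets import load_dataset_builder

builder = load_dataset_builder(
    "squad",                                               # placeholder dataset
    cache_dir="s3://my-bucket/datasets-cache",             # any fsspec-supported URL
    storage_options={"key": "<aws-key>", "secret": "<aws-secret>"},  # assumed s3fs options
)
builder.download_and_prepare(file_format="parquet")        # file_format added in PATCH 03
# A second call with the same cache_dir reuses the prepared files and logs
# "Found cached dataset ...", which is what the updated test asserts.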
tests/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_load.py b/tests/test_load.py index 712f9b55e93..0662ceff24c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -895,7 +895,7 @@ def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir os.rename(cache_dir1, cache_dir2) caplog.clear() dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="train", cache_dir=cache_dir2) - assert "Reusing dataset" in caplog.text + assert "Found cached dataset" in caplog.text assert dataset._fingerprint == fingerprint1, "for the caching mechanism to work, fingerprint should stay the same" dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, split="test", cache_dir=cache_dir2) assert dataset._fingerprint != fingerprint1 From 65c203733ac1f7109b205ab8c4aa401c9304a1ad Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 20 Jul 2022 19:22:22 +0200 Subject: [PATCH 08/51] test ArrowWriter with filesystem --- tests/conftest.py | 1 + tests/fsspec_fixtures.py | 117 +++++++++++++++++++++++++++++++++++++ tests/test_arrow_writer.py | 11 ++++ 3 files changed, 129 insertions(+) create mode 100644 tests/fsspec_fixtures.py diff --git a/tests/conftest.py b/tests/conftest.py index f8d5e97e9a1..fe2bce47178 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from datasets.arrow_dataset import Dataset from datasets.features import ClassLabel, Features, Sequence, Value +from .fsspec_fixtures import * # noqa: load fsspec fixtures from .hub_fixtures import * # noqa: load hub fixtures from .s3_fixtures import * # noqa: load s3 fixtures diff --git a/tests/fsspec_fixtures.py b/tests/fsspec_fixtures.py new file mode 100644 index 00000000000..8e160ed83ba --- /dev/null +++ b/tests/fsspec_fixtures.py @@ -0,0 +1,117 @@ +import posixpath +from pathlib import Path + +import fsspec +import pytest +from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem, stringify_path + + +class MockFileSystem(AbstractFileSystem): + protocol = "mock" + + def __init__(self, *args, local_root_dir, **kwargs): + super().__init__() + self._fs = LocalFileSystem(*args, **kwargs) + self.local_root_dir = Path(local_root_dir).as_posix() + + def mkdir(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.mkdir(path, *args, **kwargs) + + def makedirs(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.makedirs(path, *args, **kwargs) + + def rmdir(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rmdir(path) + + def ls(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.ls(path, *args, **kwargs) + + def glob(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.glob(path, *args, **kwargs) + + def info(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.info(path, *args, **kwargs) + + def lexists(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.lexists(path, *args, **kwargs) + + def cp_file(self, path1, path2, *args, **kwargs): + path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, 
self._strip_protocol(path2)) + return self._fs.cp_file(path1, path2, *args, **kwargs) + + def get_file(self, path1, path2, *args, **kwargs): + path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) + return self._fs.get_file(path1, path2, *args, **kwargs) + + def put_file(self, path1, path2, *args, **kwargs): + path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) + return self._fs.put_file(path1, path2, *args, **kwargs) + + def mv_file(self, path1, path2, *args, **kwargs): + path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) + path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) + return self._fs.mv_file(path1, path2, *args, **kwargs) + + def rm_file(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm_file(path) + + def rm(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.rm(path, *args, **kwargs) + + def _open(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs._open(path, *args, **kwargs) + + def open(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.open(path, *args, **kwargs) + + def touch(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.touch(path, *args, **kwargs) + + def created(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.created(path) + + def modified(self, path): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.modified(path) + + @classmethod + def _parent(cls, path): + return LocalFileSystem._parent(path) + + @classmethod + def _strip_protocol(cls, path): + path = stringify_path(path) + if path.startswith("mock://"): + path = path[7:] + return path + + def chmod(self, path, *args, **kwargs): + path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) + return self._fs.mkdir(path, *args, **kwargs) + + +@pytest.fixture +def mock_fsspec(monkeypatch): + monkeypatch.setitem(fsspec.registry.target, "mock", MockFileSystem) + + +@pytest.fixture +def mockfs(tmp_path_factory, mock_fsspec): + local_fs_dir = tmp_path_factory.mktemp("mockfs") + return MockFileSystem(local_root_dir=local_fs_dir) diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index b5542755db5..01e0a4a028c 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -299,3 +299,14 @@ def test_arrow_writer_closes_stream(raise_exception, tmp_path): pass finally: assert writer.stream.closed + + +def test_arrow_writer_with_filesystem(mockfs): + path = "mock://dataset-train.arrow" + with ArrowWriter(path=path, storage_options=mockfs.storage_options) as writer: + writer.write({"col_1": "foo", "col_2": 1}) + writer.write({"col_1": "bar", "col_2": 2}) + num_examples, num_bytes = writer.finalize() + assert num_examples == 2 + assert num_bytes > 0 + assert mockfs.exists(path) From 84d839724b08ac70f8b5da75336930ebb26ae567 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 15:42:16 +0200 Subject: [PATCH 09/51] test parquet writer --- tests/test_arrow_writer.py | 18 +++++++++++++++++- 1 file changed, 17 
insertions(+), 1 deletion(-) diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index 01e0a4a028c..606f02d9742 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -7,8 +7,9 @@ import numpy as np import pyarrow as pa import pytest +import pyarrow.parquet as pq -from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, TypedSequence +from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence from datasets.features import Array2D, ClassLabel, Features, Image, Value from datasets.features.features import Array2DExtensionType, cast_to_python_objects from datasets.keyhash import DuplicatedKeysError, InvalidKeyError @@ -304,9 +305,24 @@ def test_arrow_writer_closes_stream(raise_exception, tmp_path): def test_arrow_writer_with_filesystem(mockfs): path = "mock://dataset-train.arrow" with ArrowWriter(path=path, storage_options=mockfs.storage_options) as writer: + assert isinstance(writer._fs, type(mockfs)) + assert writer._fs.storage_options == mockfs.storage_options writer.write({"col_1": "foo", "col_2": 1}) writer.write({"col_1": "bar", "col_2": 2}) num_examples, num_bytes = writer.finalize() assert num_examples == 2 assert num_bytes > 0 assert mockfs.exists(path) + + +def test_parquet_writer_write(): + output = pa.BufferOutputStream() + with ParquetWriter(stream=output) as writer: + writer.write({"col_1": "foo", "col_2": 1}) + writer.write({"col_1": "bar", "col_2": 2}) + num_examples, num_bytes = writer.finalize() + assert num_examples == 2 + assert num_bytes > 0 + stream = pa.BufferReader(output.getvalue()) + pa_table: pa.Table = pq.read_table(stream) + assert pa_table.to_pydict() == {"col_1": ["foo", "bar"], "col_2": [1, 2]} From 4c46349da67edbb735b9132f93f0b471bb0faeb8 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 17:58:34 +0200 Subject: [PATCH 10/51] more tests --- src/datasets/builder.py | 19 +++--- tests/fsspec_fixtures.py | 57 ++++------------ tests/test_arrow_writer.py | 2 +- tests/test_builder.py | 130 ++++++++++++++++++++++++++++++++++++- 4 files changed, 150 insertions(+), 58 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 4e2095425fb..42440dc117b 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -673,7 +673,7 @@ def incomplete_dir(dirname): self._fs.rm(dirname, recursive=True) if is_local: # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - os.rename(tmp_dir, dirname) + os.rename(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) else: self._fs.mv(tmp_dir, dirname, recursive=True) finally: @@ -691,8 +691,7 @@ def incomplete_dir(dirname): f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..." ) else: - _protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] - _dest = self._cache_dir if is_local else _protocol + "://" + self._cache_dir + _dest = self._fs._strip_protocol(self._cache_dir) if is_local else self._cache_dir print( f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..." 
) @@ -834,7 +833,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr self.info.download_size = dl_manager.downloaded_size def download_post_processing_resources(self, dl_manager): - for split in self.info.splits: + for split in self.info.splits or []: for resource_name, resource_file_name in self._post_processing_resources(split).items(): if not not is_remote_filesystem(self._fs): raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") @@ -1234,9 +1233,9 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None split_info = split_generator.split_info file_format = file_format or "arrow" - fname = f"{self.name}-{split_generator.name}.{file_format}" - protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] - fpath = protocol + "://" + path_join(self._cache_dir, fname) + suffix = "-00000-of-00001" if file_format == "parquet" else "" + fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + fpath = path_join(self._cache_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) @@ -1315,9 +1314,9 @@ def _prepare_split(self, split_generator, file_format=None): path_join = os.path.join if is_local else posixpath.join file_format = file_format or "arrow" - fname = f"{self.name}-{split_generator.name}.{file_format}" - protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] - fpath = protocol + "://" + path_join(self._cache_dir, fname) + suffix = "-00000-of-00001" if file_format == "parquet" else "" + fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + fpath = path_join(self._cache_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter diff --git a/tests/fsspec_fixtures.py b/tests/fsspec_fixtures.py index 8e160ed83ba..7a301116ea8 100644 --- a/tests/fsspec_fixtures.py +++ b/tests/fsspec_fixtures.py @@ -12,7 +12,7 @@ class MockFileSystem(AbstractFileSystem): def __init__(self, *args, local_root_dir, **kwargs): super().__init__() self._fs = LocalFileSystem(*args, **kwargs) - self.local_root_dir = Path(local_root_dir).as_posix() + self.local_root_dir = Path(local_root_dir).resolve().as_posix() + "/" def mkdir(self, path, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) @@ -26,45 +26,28 @@ def rmdir(self, path): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) return self._fs.rmdir(path) - def ls(self, path, *args, **kwargs): + def ls(self, path, detail=True, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.ls(path, *args, **kwargs) - - def glob(self, path, *args, **kwargs): - path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.glob(path, *args, **kwargs) + out = self._fs.ls(path, detail=detail, *args, **kwargs) + if detail: + return [{**info, "name": info["name"][len(self.local_root_dir) :]} for info in out] + else: + return [name[len(self.local_root_dir) :] for name in out] def info(self, path, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.info(path, *args, **kwargs) - - def lexists(self, path, *args, **kwargs): - path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.lexists(path, *args, **kwargs) + out = dict(self._fs.info(path, *args, 
**kwargs)) + out["name"] = out["name"][len(self.local_root_dir) :] + return out def cp_file(self, path1, path2, *args, **kwargs): path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) return self._fs.cp_file(path1, path2, *args, **kwargs) - def get_file(self, path1, path2, *args, **kwargs): - path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) - path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) - return self._fs.get_file(path1, path2, *args, **kwargs) - - def put_file(self, path1, path2, *args, **kwargs): - path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) - path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) - return self._fs.put_file(path1, path2, *args, **kwargs) - - def mv_file(self, path1, path2, *args, **kwargs): - path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1)) - path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2)) - return self._fs.mv_file(path1, path2, *args, **kwargs) - - def rm_file(self, path): + def rm_file(self, path, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.rm_file(path) + return self._fs.rm_file(path, *args, **kwargs) def rm(self, path, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) @@ -74,14 +57,6 @@ def _open(self, path, *args, **kwargs): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) return self._fs._open(path, *args, **kwargs) - def open(self, path, *args, **kwargs): - path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.open(path, *args, **kwargs) - - def touch(self, path, *args, **kwargs): - path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.touch(path, *args, **kwargs) - def created(self, path): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) return self._fs.created(path) @@ -90,10 +65,6 @@ def modified(self, path): path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) return self._fs.modified(path) - @classmethod - def _parent(cls, path): - return LocalFileSystem._parent(path) - @classmethod def _strip_protocol(cls, path): path = stringify_path(path) @@ -101,10 +72,6 @@ def _strip_protocol(cls, path): path = path[7:] return path - def chmod(self, path, *args, **kwargs): - path = posixpath.join(self.local_root_dir, self._strip_protocol(path)) - return self._fs.mkdir(path, *args, **kwargs) - @pytest.fixture def mock_fsspec(monkeypatch): diff --git a/tests/test_arrow_writer.py b/tests/test_arrow_writer.py index 606f02d9742..794b03be25b 100644 --- a/tests/test_arrow_writer.py +++ b/tests/test_arrow_writer.py @@ -6,8 +6,8 @@ import numpy as np import pyarrow as pa -import pytest import pyarrow.parquet as pq +import pytest from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence from datasets.features import Array2D, ClassLabel, Features, Image, Value diff --git a/tests/test_builder.py b/tests/test_builder.py index 01e39e6cabd..8e2a13b4ae4 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -8,12 +8,14 @@ from unittest.mock import patch import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq import pytest from multiprocess.pool import Pool from datasets.arrow_dataset import Dataset from datasets.arrow_writer import ArrowWriter -from 
datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder +from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_manager import DownloadMode from datasets.features import Features, Value @@ -21,8 +23,9 @@ from datasets.iterable_dataset import IterableDataset from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo from datasets.streaming import xjoin +from datasets.utils.file_utils import is_local_path -from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss +from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss class DummyBuilder(DatasetBuilder): @@ -57,6 +60,35 @@ def _generate_examples(self): yield i, {"text": "foo"} +class DummyArrowBasedBuilder(ArrowBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"text": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN)] + + def _generate_tables(self): + for i in range(10): + yield i, pa.table({"text": ["foo"] * 10}) + + +class DummyBeamBasedBuilder(BeamBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"text": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN)] + + def _build_pcollection(self, pipeline): + import apache_beam as beam + + def _process(item): + for i in range(10): + yield f"{i}_{item}", {"text": "foo"} + + return pipeline | "Initialize" >> beam.Create(range(10)) | "Extract content" >> beam.FlatMap(_process) + + class DummyGeneratorBasedBuilderWithIntegers(GeneratorBasedBuilder): def _info(self): return DatasetInfo(features=Features({"id": Value("int8")})) @@ -690,6 +722,41 @@ def test_cache_dir_for_data_dir(self): self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) +def test_arrow_based_download_and_prepare(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare() + assert os.path.exists( + os.path.join( + tmp_path, + builder.name, + "default", + "0.0.0", + f"{builder.name}-train.arrow", + ) + ) + assert builder.info.features, Features({"text": Value("string")}) + assert builder.info.splits["train"].num_examples, 100 + assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) + + +@require_beam +def test_beam_based_download_and_prepare(tmp_path): + builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") + builder.download_and_prepare() + assert os.path.exists( + os.path.join( + tmp_path, + builder.name, + "default", + "0.0.0", + f"{builder.name}-train.arrow", + ) + ) + assert builder.info.features, Features({"text": Value("string")}) + assert builder.info.splits["train"].num_examples, 100 + assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) + + @pytest.mark.parametrize( "split, expected_dataset_class, expected_dataset_length", [ @@ -846,3 +913,62 @@ def test_builder_config_version(builder_class, kwargs, tmp_path): cache_dir = str(tmp_path) builder = builder_class(cache_dir=cache_dir, **kwargs) assert builder.config.version == "2.0.0" + + +def test_builder_with_filesystem(mockfs): + builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) + assert 
builder.cache_dir.startswith("mock://") + assert is_local_path(builder._cache_downloaded_dir) + assert isinstance(builder._fs, type(mockfs)) + assert builder._fs.storage_options == mockfs.storage_options + + +def test_builder_with_filesystem_download_and_prepare(mockfs): + builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) + builder.download_and_prepare() + assert mockfs.exists(f"{builder.name}/default/0.0.0/dataset_info.json") + assert mockfs.exists(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow") + assert not mockfs.exists(f"{builder.name}/default/0.0.0.incomplete") + + +def test_builder_with_filesystem_download_and_prepare_reload(mockfs, caplog): + builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) + mockfs.makedirs(f"{builder.name}/default/0.0.0") + DatasetInfo().write_to_directory(f"{builder.name}/default/0.0.0", fs=mockfs) + mockfs.touch(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow") + caplog.clear() + builder.download_and_prepare() + assert "Found cached dataset" in caplog.text + + +def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None + + +def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None + + +def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path): + builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") + builder.download_and_prepare(file_format="parquet") + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" + ) + assert os.path.exists(parquet_path) + assert pq.ParquetFile(parquet_path) is not None From 033a3b83afd88182a4cf5093176efef79238fd64 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 18:15:27 +0200 Subject: [PATCH 11/51] more tests --- tests/test_load.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_load.py b/tests/test_load.py index 0662ceff24c..945215dca55 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -684,6 +684,15 @@ def test_load_dataset_builder_fail(): datasets.load_dataset_builder("blabla") +def test_load_dataset_builder_with_filesystem(dataset_loading_script_dir, data_dir, mockfs): + builder = datasets.load_dataset_builder( + dataset_loading_script_dir, data_dir=data_dir, cache_dir="mock://", storage_options=mockfs.storage_options + ) + assert builder.cache_dir.startswith("mock://") + assert isinstance(builder._fs, type(mockfs)) + assert builder._fs.storage_options == mockfs.storage_options + + @pytest.mark.parametrize("keep_in_memory", [False, True]) def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory, 
caplog): with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): From ce8d7f94a8d347c6878177ce1eb6a61f847be5da Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 19:05:04 +0200 Subject: [PATCH 12/51] fix nullcontext on 3.6 --- src/datasets/builder.py | 9 +++++---- src/datasets/utils/py_utils.py | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 42440dc117b..0d233d30534 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -65,6 +65,7 @@ has_sufficient_disk_space, map_nested, memoize, + nullcontext, size_str, temporary_assignment, ) @@ -341,7 +342,7 @@ def __init__( if is_local: os.makedirs(self._cache_dir_root, exist_ok=True) lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else contextlib.nullcontext(): + with FileLock(lock_path) if is_local else nullcontext(): if self._fs.exists(self._cache_dir): # check if data exist if len(self._fs.listdir(self._cache_dir)) > 0: logger.info("Overwrite dataset info from restored data version.") @@ -646,7 +647,7 @@ def download_and_prepare( if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") # File locking only with local paths; no file locking on GCS or S3 - with FileLock(lock_path) if is_local else contextlib.nullcontext(): + with FileLock(lock_path) if is_local else nullcontext(): data_exists = self._fs.exists(self._cache_dir) if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: logger.warning(f"Found cached dataset {self.name} ({self._cache_dir})") @@ -855,14 +856,14 @@ def _save_info(self): is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else contextlib.nullcontext(): + with FileLock(lock_path) if is_local else nullcontext(): self.info.write_to_directory(self._cache_dir, fs=self._fs) def _save_infos(self): is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else contextlib.nullcontext(): + with FileLock(lock_path) if is_local else nullcontext(): DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir()) def _make_split_generators_kwargs(self, prepare_split_kwargs): diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index b6656b7f6bf..87482bb64ac 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -17,7 +17,6 @@ """ -import contextlib import functools import itertools import os @@ -151,7 +150,19 @@ def string_to_dict(string: str, pattern: str) -> Dict[str, str]: return _dict -@contextlib.contextmanager +@contextmanager +def nullcontext(): + """Context manager that does no additional processing. 
+ Used as a stand-in for a normal context manager, when a particular + block of code is only sometimes used with a normal context manager: + cm = optional_cm if condition else nullcontext() + with cm: + # Perform operation, using optional_cm if condition is True + """ + yield + + +@contextmanager def temporary_assignment(obj, attr, value): """Temporarily assign obj.attr to value.""" original = getattr(obj, attr, None) @@ -540,7 +551,7 @@ def dump(obj, file): return -@contextlib.contextmanager +@contextmanager def _no_cache_fields(obj): try: if ( From 15dccf91972a258bc6cd24ec42bb997a34fd781c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 19:33:46 +0200 Subject: [PATCH 13/51] parquet_writer.write_batch is not available in pyarrow 6 --- src/datasets/arrow_writer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index ad4c5097d8c..15cd2222a9b 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -536,11 +536,9 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non if self.pa_writer is None: self._build_writer(inferred_schema=pa_table.schema) pa_table = table_cast(pa_table, self._schema) - batches: List[pa.RecordBatch] = pa_table.to_batches(max_chunksize=writer_batch_size) - self._num_bytes += sum(batch.nbytes for batch in batches) + self._num_bytes += pa_table.nbytes self._num_examples += pa_table.num_rows - for batch in batches: - self.pa_writer.write_batch(batch) + self.pa_writer.write_table(pa_table, writer_batch_size) def finalize(self, close_stream=True): self.write_rows_on_file() From 3a3d784e6052e97003dad2dd13ef1821f87fb2f9 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 21 Jul 2022 19:45:23 +0200 Subject: [PATCH 14/51] remove reference to open file --- src/datasets/arrow_writer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 15cd2222a9b..47c226ad3a7 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -554,6 +554,7 @@ def finalize(self, close_stream=True): else: raise ValueError("Please pass `features` or at least one example when writing data") self.pa_writer.close() + self.pa_writer = None if close_stream: self.stream.close() logger.debug( From 3eef46dabc18d718870f32bab26068fea2498e42 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 22 Jul 2022 11:17:23 +0200 Subject: [PATCH 15/51] fix test --- src/datasets/arrow_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 929766662e2..2ca75c7f208 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1130,7 +1130,7 @@ def save_to_disk(self, dataset_path: str, fs=None): fs.makedirs(dataset_path, exist_ok=True) with fs.open(Path(dataset_path, config.DATASET_ARROW_FILENAME).as_posix(), "wb") as dataset_file: with ArrowWriter(stream=dataset_file) as writer: - writer.write_table(dataset._data) + writer.write_table(dataset._data.table) writer.finalize() with fs.open( Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8" From b480549e0e30d8651101b23e77b55040af0ed159 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 22 Jul 2022 17:30:49 +0200 Subject: [PATCH 16/51] docs --- docs/source/filesystems.mdx | 209 +++++++++++++++++------------------- 1 file changed, 101 insertions(+), 108 deletions(-) diff --git 
a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 048dcafcd36..90bac715605 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -1,6 +1,8 @@ # Cloud storage -🤗 Datasets supports access to cloud storage providers through a S3 filesystem implementation: [`filesystems.S3FileSystem`]. You can save and load datasets from your Amazon S3 bucket in a Pythonic way. Take a look at the following table for other supported cloud storage providers: +🤗 Datasets supports access to cloud storage providers through `fsspec` FileSystem implementations. +You can save and load datasets from any cloud storage in a Pythonic way. +Take a look at the following table for some examples of supported cloud storage providers: | Storage provider | Filesystem implementation | |----------------------|---------------------------------------------------------------| @@ -10,11 +12,12 @@ | Dropbox | [dropboxdrivefs](https://github.com/MarineChap/dropboxdrivefs)| | Google Drive | [gdrivefs](https://github.com/intake/gdrivefs) | -This guide will show you how to save and load datasets with **s3fs** to a S3 bucket, but other filesystem implementations can be used similarly. An example is shown also for Google Cloud Storage and Azure Blob Storage. +This guide will show you how to save and load datasets with any cloud storage. +Here are examples for S3, Google Cloud Storage and Azure Blob Storage. -## Amazon S3 +## Set up your cloud storage FileSystem -### Listing datasets +### Amazon S3 1. Install the S3 dependency with 🤗 Datasets: ``` >>> pip install datasets[s3] ``` -2. List files from a public S3 bucket with `s3.ls`: +2. Define your credentials -```py ->>> import datasets ->>> s3 = datasets.filesystems.S3FileSystem(anon=True) ->>> s3.ls('public-datasets/imdb/train') -['dataset_info.json.json','dataset.arrow','state.json'] -``` - -Access a private S3 bucket by entering your `aws_access_key_id` and `aws_secret_access_key`: +To use an anonymous connection, use `anon=True`. +Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket. ```py ->>> import datasets ->>> s3 = datasets.filesystems.S3FileSystem(key=aws_access_key_id, secret=aws_secret_access_key) ->>> s3.ls('my-private-datasets/imdb/train') -['dataset_info.json.json','dataset.arrow','state.json'] +>>> storage_options = {"anon": True} # for anonymous connection +# or use your credentials +>>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} # for private buckets +# or use a botocore session +>>> import botocore +>>> s3_session = botocore.session.Session(profile="my_profile_name") +>>> storage_options = {"session": s3_session} ``` -### Saving datasets - -After you have processed your dataset, you can save it to S3 with [`Dataset.save_to_disk`]: +3. Load your FileSystem instance ```py ->>> from datasets.filesystems import S3FileSystem - -# create S3FileSystem instance ->>> s3 = S3FileSystem(anon=True) - -# saves encoded_dataset to your s3 bucket ->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train', fs=s3) +>>> import s3fs +>>> fs = s3fs.S3FileSystem(**storage_options) ``` - +### Google Cloud Storage -Remember to include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket. - +1. 
Install the Google Cloud Storage implementation: - +``` +>>> conda install -c conda-forge gcsfs +# or install with pip +>>> pip install gcsfs +``` -Save your dataset with `botocore.session.Session` and a custom AWS profile: +2. Define your credentials ```py ->>> import botocore ->>> from datasets.filesystems import S3FileSystem - -# creates a botocore session with the provided AWS profile ->>> s3_session = botocore.session.Session(profile='my_profile_name') - -# create S3FileSystem instance with s3_session ->>> s3 = S3FileSystem(session=s3_session) - -# saves encoded_dataset to your s3 bucket ->>> encoded_dataset.save_to_disk('s3://my-private-datasets/imdb/train',fs=s3) +>>> storage_options={"token": "anon"} # for anonymous connection +# or use your default gcloud credentials or credentials from the google metadata service +>>> storage_options={"project": "my-google-project"} +# or use your credentials from elsewhere, see the documentation at https://gcsfs.readthedocs.io/ +>>> storage_options={"project": "my-google-project", "token": TOKEN} ``` -### Loading datasets - -When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]: +3. Load your FileSystem instance ```py ->>> from datasets import load_from_disk ->>> from datasets.filesystems import S3FileSystem +>>> import gcsfs +>>> fs = gcsfs.GCSFileSystem(**storage_options) +``` -# create S3FileSystem without credentials ->>> s3 = S3FileSystem(anon=True) +### Azure Blob Storage -# load encoded_dataset to from s3 bucket ->>> dataset = load_from_disk('s3://a-public-datasets/imdb/train',fs=s3) +1. Install the Azure Blob Storage implementation: ->>> print(len(dataset)) ->>> # 25000 +``` +>>> conda install -c conda-forge adlfs +# or install with pip +>>> pip install adlfs ``` -Load with `botocore.session.Session` and custom AWS profile: +2. Define your credentials ```py ->>> import botocore ->>> from datasets.filesystems import S3FileSystem +>>> storage_options = {"anon": True} # for anonymous connection +# or use your credentials +>>> storage_options = {"account_name": ACCOUNT_NAME, "account_key": ACCOUNT_KEY} # gen 2 filesystem +# or use your credentials with the gen 1 filesystem +>>> storage_options={"tenant_id": TENANT_ID, "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET} +``` -# create S3FileSystem instance with aws_access_key_id and aws_secret_access_key ->>> s3_session = botocore.session.Session(profile='my_profile_name') +3. Load your FileSystem instance -# create S3FileSystem instance with s3_session ->>> s3 = S3FileSystem(session=s3_session) +```py +>>> import adlfs +>>> fs = adlfs.AzureBlobFileSystem(**storage_options) +``` -# load encoded_dataset to from s3 bucket ->>> dataset = load_from_disk('s3://my-private-datasets/imdb/train',fs=s3) +## Load and Save your datasets using your cloud storage FileSystem ->>> print(len(dataset)) ->>> # 25000 -``` +### Load datasets into a cloud storage -## Google Cloud Storage +You can load and cache a dataset into your cloud storage by specifying a remote `cache_dir` in `load_dataset`. +Don't forget to use the previously defined `storage_options` containing your credentials to write into a private cloud storage. -1. 
Install the Google Cloud Storage implementation: +Load a dataset from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)): -``` ->>> conda install -c conda-forge gcsfs -# or install with pip ->>> pip install gcsfs +```py +>>> cache_dir = "s3://my-bucket/datasets-cache" +>>> builder = load_dataset_builder("imdb", cache_dir=cache_dir, storage_options=storage_options) +>>> builder.download_and_prepare(file_format="parquet") ``` -2. Save your dataset: +Load a dataset using a loading script (see [how to load a local loading script](./loading#local-loading-script)): ```py ->>> import gcsfs +>>> cache_dir = "s3://my-bucket/datasets-cache" +>>> builder = load_dataset_builder("path/to/local/loading_script/loading_script.py", cache_dir=cache_dir, storage_options=storage_options) +>>> builder.download_and_prepare(file_format="parquet") +``` -# create GCSFileSystem instance using default gcloud credentials with project ->>> gcs = gcsfs.GCSFileSystem(project='my-google-project') +Load your own data files (see [how to load local and remote files](./loading#local-and-remote-files)): -# saves encoded_dataset to your gcs bucket ->>> encoded_dataset.save_to_disk('gcs://my-private-datasets/imdb/train', fs=gcs) +```py +>>> data_files = {"train": ["path/to/train.csv"]} +>>> cache_dir "s3://my-bucket/datasets-cache" +>>> builder = load_dataset_builder("csv", data_files=data_files, cache_dir=cache_dir, storage_options=storage_options) +>>> builder.download_and_prepare(file_format="parquet") ``` -3. Load your dataset: +It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. +Otherwize the dataset is saved as an uncompressed Arrow file. -```py ->>> import gcsfs ->>> from datasets import load_from_disk +## Saving serialized datasets -# create GCSFileSystem instance using default gcloud credentials with project ->>> gcs = gcsfs.GCSFileSystem(project='my-google-project') +After you have processed your dataset, you can save it to your cloud storage with [`Dataset.save_to_disk`]: -# loads encoded_dataset from your gcs bucket ->>> dataset = load_from_disk('gcs://my-private-datasets/imdb/train', fs=gcs) +```py +# saves encoded_dataset to amazon s3 +>>> encoded_dataset.save_to_disk("s3://my-private-datasets/imdb/train", fs=fs) +# saves encoded_dataset to google cloud storage +>>> encoded_dataset.save_to_disk("gcs://my-private-datasets/imdb/train", fs=fs) +# saves encoded_dataset to microsoft azure blob/datalake +>>> encoded_dataset.save_to_disk("adl://my-private-datasets/imdb/train", fs=fs) ``` -## Azure Blob Storage - -1. Install the Azure Blob Storage implementation: + -``` ->>> conda install -c conda-forge adlfs -# or install with pip ->>> pip install adlfs -``` +Remember to define your credentials in your [FileSystem instance](#set-up-your-cloud-storage-filesystem) `fs` whenever you are interacting with a private cloud storage. -2. Save your dataset: + -```py ->>> import adlfs +## Listing serialized datasets -# create AzureBlobFileSystem instance with account_name and account_key ->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX") +List files from a cloud storage with your FileSystem instance `fs`, using `fs.ls`: -# saves encoded_dataset to your azure container ->>> encoded_dataset.save_to_disk('abfs://my-private-datasets/imdb/train', fs=abfs) +```py +>>> fs.ls("my-private-datasets/imdb/train") +["dataset_info.json.json","dataset.arrow","state.json"] ``` -3. 
Load your dataset: +### Load serialized datasets + +When you are ready to use your dataset again, reload it with [`Dataset.load_from_disk`]: ```py ->>> import adlfs >>> from datasets import load_from_disk - -# create AzureBlobFileSystem instance with account_name and account_key ->>> abfs = adlfs.AzureBlobFileSystem(account_name="XXXX", account_key="XXXX") - -# loads encoded_dataset from your azure container ->>> dataset = load_from_disk('abfs://my-private-datasets/imdb/train', fs=abfs) +# load encoded_dataset from cloud storage +>>> dataset = load_from_disk("s3://a-public-datasets/imdb/train", fs=fs) +>>> print(len(dataset)) +25000 ``` From 4757930d2a8a2d3187aa38ca5960de1fc7ffdf01 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 26 Jul 2022 19:59:26 +0200 Subject: [PATCH 17/51] shard parquet in download_and_prepare --- src/datasets/builder.py | 161 +++++++++++++++++++++++++++++++--------- tests/test_builder.py | 35 +++++++++ 2 files changed, 160 insertions(+), 36 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 0d233d30534..5f8b87836d3 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -30,6 +30,7 @@ from typing import Dict, Mapping, Optional, Tuple, Union import fsspec +from tqdm.contrib.concurrent import thread_map from . import config, utils from .arrow_dataset import Dataset @@ -62,6 +63,7 @@ from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( classproperty, + convert_file_size_to_int, has_sufficient_disk_space, map_nested, memoize, @@ -575,6 +577,14 @@ def get_imported_module_dir(cls): """Return the path of the module of this class or subclass.""" return os.path.dirname(inspect.getfile(inspect.getmodule(cls))) + def _rename(self, src: str, dst: str): + is_local = not is_remote_filesystem(self._fs) + if is_local: + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory + os.rename(self._fs._strip_protocol(src), self._fs._strip_protocol(dst)) + else: + self._fs.mv(src, dst, recursive=True) + def download_and_prepare( self, download_config: Optional[DownloadConfig] = None, @@ -672,11 +682,7 @@ def incomplete_dir(dirname): yield tmp_dir if self._fs.isdir(dirname): self._fs.rm(dirname, recursive=True) - if is_local: - # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - os.rename(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) - else: - self._fs.mv(tmp_dir, dirname, recursive=True) + self._rename(tmp_dir, dirname) finally: if self._fs.exists(tmp_dir): self._fs.rm(tmp_dir, recursive=True) @@ -1224,51 +1230,90 @@ def _generate_examples(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): + def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None, max_shard_size=None): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join + file_format = file_format or "arrow" + + if max_shard_size is not None: + max_shard_size = convert_file_size_to_int(max_shard_size) + if file_format == "arrow": + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." 
+ ) if self.info.splits is not None: split_info = self.info.splits[split_generator.name] else: split_info = split_generator.split_info - file_format = file_format or "arrow" - suffix = "-00000-of-00001" if file_format == "parquet" else "" + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" fpath = path_join(self._cache_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter - with writer_class( + + shard_id = 0 + writer = writer_class( features=self.info.features, - path=fpath, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, storage_options=self._fs.storage_options, - ) as writer: - try: - for key, record in logging.tqdm( - generator, - unit=" examples", - total=split_info.num_examples, - leave=False, - disable=not logging.is_progress_bar_enabled(), - desc=f"Generating {split_info.name} split", - ): - example = self.info.features.encode_example(record) - writer.write(example, key) - finally: - num_examples, num_bytes = writer.finalize() - - split_generator.split_info.num_examples = num_examples - split_generator.split_info.num_bytes = num_bytes + ) + total_num_examples, total_num_bytes = 0, 0 + try: + for key, record in logging.tqdm( + generator, + unit=" examples", + total=split_info.num_examples, + leave=False, + disable=not logging.is_progress_bar_enabled(), + desc=f"Generating {split_info.name} split", + ): + if max_shard_size is not None and writer._num_bytes > max_shard_size: + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes + shard_id += 1 + writer = writer_class( + features=writer._features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + writer_batch_size=self._writer_batch_size, + hash_salt=split_info.name, + check_duplicates=check_duplicate_keys, + storage_options=self._fs.storage_options, + ) + example = self.info.features.encode_example(record) + writer.write(example, key) + finally: + num_shards = shard_id + 1 + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + if file_format == "parquet": + + def _rename_shard(shard_id: int): + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}"), + fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), + ) + + logger.debug(f"Renaming {num_shards} shards.") + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) + + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes + if self.info.features is None: + self.info.features = writer._features + + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **kwargs): super()._download_and_prepare( - dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos + dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos, **kwargs ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: @@ -1310,26 +1355,70 @@ def _generate_tables(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, file_format=None): + def _prepare_split(self, split_generator, 
file_format=None, max_shard_size=None): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - file_format = file_format or "arrow" - suffix = "-00000-of-00001" if file_format == "parquet" else "" + + if max_shard_size is not None: + if file_format == "arrow": + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." + ) + max_shard_size = convert_file_size_to_int(max_shard_size or "500MB") + + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" fpath = path_join(self._cache_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter - with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: + + shard_id = 0 + writer = writer_class( + features=self.info.features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + storage_options=self._fs.storage_options, + ) + total_num_examples, total_num_bytes = 0, 0 + try: for key, table in logging.tqdm( - generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) + generator, + unit=" tables", + leave=False, + disable=not logging.is_progress_bar_enabled(), ): + if max_shard_size is not None and writer._num_bytes > max_shard_size: + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes + shard_id += 1 + writer = writer_class( + features=writer._features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + storage_options=self._fs.storage_options, + ) writer.write_table(table) + finally: + num_shards = shard_id + 1 num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes + + if file_format == "parquet": + + def _rename_shard(shard_id: int): + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}"), + fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), + ) + + logger.debug(f"Renaming {num_shards} shards.") + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) - split_generator.split_info.num_examples = num_examples - split_generator.split_info.num_bytes = num_bytes + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes if self.info.features is None: self.info.features = writer._features diff --git a/tests/test_builder.py b/tests/test_builder.py index 8e2a13b4ae4..dac177c1678 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -952,6 +952,24 @@ def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): assert pq.ParquetFile(parquet_path) is not None +def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + writer_batch_size = 25 + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) + builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one batch per shard + expected_num_shards = 100 // writer_batch_size + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in 
Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare(file_format="parquet") @@ -963,6 +981,23 @@ def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): assert pq.ParquetFile(parquet_path) is not None +def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one table per shard + expected_num_shards = 10 + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") builder.download_and_prepare(file_format="parquet") From 32f5bf8b375a63f84dd9fc78c06ffd1dfe8e818b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 27 Jul 2022 11:55:03 +0200 Subject: [PATCH 18/51] typing, docs, docstrings --- docs/source/filesystems.mdx | 8 ++- src/datasets/builder.py | 98 +++++++++++++++++++++++++++++-------- 2 files changed, 85 insertions(+), 21 deletions(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 90bac715605..3bf04de35a7 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -133,9 +133,15 @@ Load your own data files (see [how to load local and remote files](./loading#loc >>> builder.download_and_prepare(file_format="parquet") ``` -It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. +It is highly recommended to save the files as compressed sharded Parquet files to optimize I/O by specifying `file_format="parquet"`. Otherwize the dataset is saved as an uncompressed Arrow file. +You can also specify the size of the Parquet shard using `max_shard_size` (default is 500MB): + +```py +>>> builder.download_and_prepare(file_format="parquet", max_shard_size="1GB") +``` + ## Saving serialized datasets After you have processed your dataset, you can save it to your cloud storage with [`Dataset.save_to_disk`]: diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 5f8b87836d3..61fa1d648f3 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -594,7 +594,8 @@ def download_and_prepare( dl_manager: Optional[DownloadManager] = None, base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, - file_format: Optional[str] = None, + file_format: str = "arrow", + max_shard_size: Optional[int] = None, **download_and_prepare_kwargs, ): """Downloads and prepares dataset for reading. @@ -611,15 +612,37 @@ def download_and_prepare( If True, will get token from ~/.huggingface. 
file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + max_shard_size (:obj:`Union[str, int]`, optional): Maximum number of bytes written per shard. + Supports only the "parquet" format with a default of "500MB". The size is based on uncompressed data size, + so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. Example: + Download and prepare the dataset as Arrow files that can be loaded as a Dataset using `builder.as_dataset()` + ```py >>> from datasets import load_dataset_builder - >>> builder = load_dataset_builder('rotten_tomatoes') + >>> builder = load_dataset_builder("rotten_tomatoes") >>> ds = builder.download_and_prepare() ``` + + Download and prepare the dataset as sharded Parquet files locally + + ```py + >>> from datasets import load_dataset_builder + >>> builder = load_dataset_builder("rotten_tomatoes", cache_dir="path/to/local/datasets-cache") + >>> ds = builder.download_and_prepare(file_format="parquet") + ``` + + Download and prepare the dataset as sharded Parquet files in a cloud storage + + ```py + >>> from datasets import load_dataset_builder + >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} + >>> builder = load_dataset_builder("rotten_tomatoes", cache_dir="s3://my-bucket/datasets-cache", storage_options=storage_options) + >>> ds = builder.download_and_prepare(file_format="parquet") + ``` """ download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications @@ -629,6 +652,11 @@ def download_and_prepare( if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") + if file_format == "arrow" and max_shard_size is not None: + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." + ) + if dl_manager is None: if download_config is None: download_config = DownloadConfig( @@ -649,7 +677,12 @@ def download_and_prepare( else False, ) - elif isinstance(dl_manager, MockDownloadManager) or not is_local: + if ( + isinstance(dl_manager, MockDownloadManager) + or not is_local + or file_format != "arrow" + or max_shard_size is not None + ): try_from_hf_gcs = False self.dl_manager = dl_manager @@ -721,11 +754,15 @@ def incomplete_dir(dirname): except ConnectionError: logger.warning("HF google storage unreachable. 
Downloading and preparing it from source") if not downloaded_from_gcs: + prepare_split_kwargs = { + "file_format": file_format, + "max_shard_size": max_shard_size, + **download_and_prepare_kwargs, + } self._download_and_prepare( dl_manager=dl_manager, verify_infos=verify_infos, - file_format=file_format, - **download_and_prepare_kwargs, + **prepare_split_kwargs, ) # Sync info self.info.dataset_size = sum(split.num_bytes for split in self.info.splits.values()) @@ -775,7 +812,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") logger.info("Dataset downloaded from Hf google storage.") - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **prepare_split_kwargs): + def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs): """Downloads and prepares dataset for reading. This is the internal implementation to overwrite called when user calls @@ -785,9 +822,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr Args: dl_manager: (:obj:`DownloadManager`) `DownloadManager` used to download and cache data. verify_infos (:obj:`bool`): if False, do not perform checksums and size tests. - file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. - Supported formats: "arrow", "parquet". Default to "arrow" format. - prepare_split_kwargs: Additional options. + prepare_split_kwargs: Additional options, such as file_format, max_shard_size. """ # Generating data for all splits split_dict = SplitDict(dataset_name=self.name) @@ -814,7 +849,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr try: # Prepare split will record examples associated to the split - self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) + self._prepare_split(split_generator, **prepare_split_kwargs) except OSError as e: raise raise OSError( @@ -1153,13 +1188,22 @@ def _split_generators(self, dl_manager: DownloadManager): raise NotImplementedError() @abc.abstractmethod - def _prepare_split(self, split_generator: SplitGenerator, file_format: Optional[str] = None, **kwargs): + def _prepare_split( + self, + split_generator: SplitGenerator, + file_format: str = "arrow", + max_shard_size: Union[None, str, int] = None, + **kwargs, + ): """Generate the examples and record them on disk. Args: split_generator: `SplitGenerator`, Split generator to process file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + max_shard_size (:obj:`Union[str, int]`, optional): Approximate maximum number of bytes written per shard. + Supports only the "parquet" format with a default of "500MB". The size is computed using the uncompressed data, + so in practice your shard files may be smaller than `max_shard_size` thanks to compression. 
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1230,10 +1274,15 @@ def _generate_examples(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None, max_shard_size=None): + def _prepare_split( + self, + split_generator: SplitGenerator, + check_duplicate_keys: bool, + file_format="arrow", + max_shard_size: Union[None, int, str] = None, + ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - file_format = file_format or "arrow" if max_shard_size is not None: max_shard_size = convert_file_size_to_int(max_shard_size) @@ -1311,9 +1360,9 @@ def _rename_shard(shard_id: int): if self.info.features is None: self.info.features = writer._features - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **kwargs): + def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): super()._download_and_prepare( - dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos, **kwargs + dl_manager, verify_infos, check_duplicate_keys=verify_infos, **prepare_splits_kwargs ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: @@ -1355,10 +1404,11 @@ def _generate_tables(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, file_format=None, max_shard_size=None): + def _prepare_split( + self, split_generator: SplitGenerator, file_format: str = "arrow", max_shard_size: Union[None, str, int] = None + ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - file_format = file_format or "arrow" if max_shard_size is not None: if file_format == "arrow": @@ -1491,7 +1541,7 @@ def _build_pcollection(pipeline, extracted_dir=None): """ raise NotImplementedError() - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): # Create the Beam pipeline and forward it to _prepare_split import apache_beam as beam @@ -1529,7 +1579,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): dl_manager, verify_infos=False, pipeline=pipeline, - file_format=file_format, + **prepare_splits_kwargs, ) # TODO handle verify_infos in beam datasets # Run pipeline pipeline_results = pipeline.run() @@ -1555,9 +1605,17 @@ def _save_info(self): with fs.create(path_join(self._cache_dir, config.LICENSE_FILENAME)) as f: self.info._dump_license(f) - def _prepare_split(self, split_generator, pipeline, file_format=None): + def _prepare_split( + self, split_generator, pipeline, file_format="arrow", max_shard_size: Union[None, str, int] = None + ): import apache_beam as beam + if max_shard_size is not None: + raise NotImplementedError( + "max_shard_size is not supported for Beam datasets, please." + "Set it to None to use the default Apache Beam sharding and get the best performance." 
+ ) + # To write examples in filesystem: split_name = split_generator.split_info.name file_format = file_format or "arrow" From 1db12b93258716c000039f5802eb87f9ef01fa47 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 27 Jul 2022 16:12:22 +0200 Subject: [PATCH 19/51] docs: dask from parquet files --- docs/source/filesystems.mdx | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 90bac715605..2a41d7bc7be 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -128,7 +128,7 @@ Load your own data files (see [how to load local and remote files](./loading#loc ```py >>> data_files = {"train": ["path/to/train.csv"]} ->>> cache_dir "s3://my-bucket/datasets-cache" +>>> cache_dir = "s3://my-bucket/datasets-cache" >>> builder = load_dataset_builder("csv", data_files=data_files, cache_dir=cache_dir, storage_options=storage_options) >>> builder.download_and_prepare(file_format="parquet") ``` @@ -136,6 +136,27 @@ Load your own data files (see [how to load local and remote files](./loading#loc It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. Otherwize the dataset is saved as an uncompressed Arrow file. +#### Dask + +Dask is a parallel computing library and it has a pandas-like API for working with larger than memory Parquet datasets in parallel. +Dask can use multiple threads or processes on a single machine, or a cluster of machines to process data in parallel. +Dask supports local data but also data from a cloud storage. + +Therefore you can load a dataset saved as sharded Parquet files in Dask with + +```py +import dask.dataframe as dd + +df = dd.read_parquet(builder.cache_dir, storage_options=storage_options) + +# or if your dataset is split into train/valid/test +df_train = dd.read_parquet(builder.cache_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options) +df_valid = dd.read_parquet(builder.cache_dir + f"/{builder.name}-validation-*.parquet", storage_options=storage_options) +df_test = dd.read_parquet(builder.cache_dir + f"/{builder.name}-test-*.parquet", storage_options=storage_options) +``` + +You can find more about dask dataframes in their [documentation](https://docs.dask.org/en/stable/dataframe.html). 
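As a quick illustration of what you can do once those shards are opened with Dask (a sketch, not part of the patch: the `text` and `label` column names are assumptions for a typical text-classification dataset):

```py
import dask.dataframe as dd

# Lazily open the train shards written by download_and_prepare
df_train = dd.read_parquet(
    builder.cache_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options
)

# Operations are lazy; data is only read when .compute() (or len()) is called
num_rows = len(df_train)                          # computes the number of rows
avg_chars = df_train["text"].str.len().mean()     # "text" column assumed to exist
label_counts = df_train["label"].value_counts()   # "label" column assumed to exist

print(num_rows, avg_chars.compute(), label_counts.compute())
```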
+ ## Saving serialized datasets After you have processed your dataset, you can save it to your cloud storage with [`Dataset.save_to_disk`]: From 874b2a08f2d0409db640fa246904d3ce5d17c590 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Wed, 27 Jul 2022 17:43:28 +0200 Subject: [PATCH 20/51] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mario Šaško --- src/datasets/builder.py | 2 +- src/datasets/info.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 0d233d30534..317999b4ad3 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -328,7 +328,7 @@ def __init__( self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] is_local = not is_remote_filesystem(self._fs) - path_join = os.path.join if is_local else os.path.join + path_join = os.path.join if is_local else posixpath.join protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] self._cache_dir_root = fs_token_paths[2][0] if is_local else protocol + "://" + fs_token_paths[2][0] diff --git a/src/datasets/info.py b/src/datasets/info.py index 8ee8f8c5b7f..90a66b341bd 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -196,7 +196,7 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None): """ fs = fs or LocalFileSystem() is_local = not is_remote_filesystem(fs) - path_join = os.path.join if is_local else os.path.join + path_join = os.path.join if is_local else posixpath.join with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f: self._dump_info(f, pretty_print=pretty_print) @@ -268,7 +268,7 @@ def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo": raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.") is_local = not is_remote_filesystem(fs) - path_join = os.path.join if is_local else os.path.join + path_join = os.path.join if is_local else posixpath.join with fs.open(path_join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f: dataset_info_dict = json.load(f) From f6ecb64a86cbeecc9066a81d3632efa89d470bb8 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 27 Jul 2022 17:46:02 +0200 Subject: [PATCH 21/51] use contextlib.nullcontext --- src/datasets/builder.py | 9 ++++----- src/datasets/utils/py_utils.py | 12 ------------ 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 317999b4ad3..1e8bab9f9d6 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -65,7 +65,6 @@ has_sufficient_disk_space, map_nested, memoize, - nullcontext, size_str, temporary_assignment, ) @@ -342,7 +341,7 @@ def __init__( if is_local: os.makedirs(self._cache_dir_root, exist_ok=True) lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else nullcontext(): + with FileLock(lock_path) if is_local else contextlib.nullcontext(): if self._fs.exists(self._cache_dir): # check if data exist if len(self._fs.listdir(self._cache_dir)) > 0: logger.info("Overwrite dataset info from restored data version.") @@ -647,7 +646,7 @@ def download_and_prepare( if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") # File locking only with local paths; no file locking on GCS 
or S3 - with FileLock(lock_path) if is_local else nullcontext(): + with FileLock(lock_path) if is_local else contextlib.nullcontext(): data_exists = self._fs.exists(self._cache_dir) if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: logger.warning(f"Found cached dataset {self.name} ({self._cache_dir})") @@ -856,14 +855,14 @@ def _save_info(self): is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else nullcontext(): + with FileLock(lock_path) if is_local else contextlib.nullcontext(): self.info.write_to_directory(self._cache_dir, fs=self._fs) def _save_infos(self): is_local = not is_remote_filesystem(self._fs) if is_local: lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else nullcontext(): + with FileLock(lock_path) if is_local else contextlib.nullcontext(): DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir()) def _make_split_generators_kwargs(self, prepare_split_kwargs): diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index f6ea3262927..c2d413021df 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -187,18 +187,6 @@ def _asdict_inner(obj): return _asdict_inner(obj) -@contextmanager -def nullcontext(): - """Context manager that does no additional processing. - Used as a stand-in for a normal context manager, when a particular - block of code is only sometimes used with a normal context manager: - cm = optional_cm if condition else nullcontext() - with cm: - # Perform operation, using optional_cm if condition is True - """ - yield - - @contextmanager def temporary_assignment(obj, attr, value): """Temporarily assign obj.attr to value.""" From e7f3ac4babe90c4d2714aab9700d0b4e4246c333 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 27 Jul 2022 17:57:23 +0200 Subject: [PATCH 22/51] fix missing import --- src/datasets/info.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/info.py b/src/datasets/info.py index 90a66b341bd..6969e1afc79 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -32,6 +32,7 @@ import dataclasses import json import os +import posixpath from dataclasses import dataclass, field from typing import Dict, List, Optional, Union From df0343a9ca6bb583cfa74d29e16b4ab3e5fd7ac5 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Fri, 29 Jul 2022 13:45:20 +0200 Subject: [PATCH 23/51] Use unstrip_protocol to merge protocol and path --- src/datasets/arrow_writer.py | 5 +++-- src/datasets/builder.py | 3 +-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index 67f6b88eb29..0230b840faf 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -313,9 +313,10 @@ def __init__( if stream is None: fs_token_paths = fsspec.get_fs_token_paths(path, storage_options=storage_options) self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] - protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] self._path = ( - fs_token_paths[2][0] if not is_remote_filesystem(self._fs) else protocol + "://" + fs_token_paths[2][0] + fs_token_paths[2][0] + if not is_remote_filesystem(self._fs) + else self._fs.unstrip_protocol(fs_token_paths[2][0]) ) self.stream = self._fs.open(fs_token_paths[2][0], 
"wb") self._closable_stream = True diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 1e8bab9f9d6..7a0590a9e28 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -329,8 +329,7 @@ def __init__( is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - protocol = self._fs.protocol if isinstance(self._fs.protocol, str) else self._fs.protocol[-1] - self._cache_dir_root = fs_token_paths[2][0] if is_local else protocol + "://" + fs_token_paths[2][0] + self._cache_dir_root = fs_token_paths[2][0] if is_local else self._fs.unstrip_protocol(fs_token_paths[2][0]) self._cache_dir = self._build_cache_dir() self._cache_downloaded_dir = ( path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) From a0f84f4b74ee3dd2e51fecaf634a2cf99cc6fb24 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 29 Jul 2022 16:16:57 +0200 Subject: [PATCH 24/51] remove bad "raise" and add TODOs --- src/datasets/builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 7a0590a9e28..0d9c8d8f266 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -808,7 +808,6 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr # Prepare split will record examples associated to the split self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) except OSError as e: - raise raise OSError( "Cannot find data file. " + (self.manual_download_instructions or "") @@ -1239,6 +1238,8 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None generator = self._generate_examples(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + + # TODO: embed the images/audio files inside parquet files. with writer_class( features=self.info.features, path=fpath, @@ -1319,6 +1320,7 @@ def _prepare_split(self, split_generator, file_format=None): generator = self._generate_tables(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + # TODO: embed the images/audio files inside parquet files. with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: for key, table in logging.tqdm( generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) From 1b02b6606a729510a5f87044451f76bb47777083 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 18:41:51 +0200 Subject: [PATCH 25/51] add output_dir arg to download_and_prepare --- src/datasets/builder.py | 153 ++++++++++++++++++++++------------------ 1 file changed, 83 insertions(+), 70 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 4cfebe61ffa..a5a200be5e0 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -27,6 +27,7 @@ import warnings from dataclasses import dataclass from functools import partial +from pathlib import Path from typing import Dict, Mapping, Optional, Tuple, Union import fsspec @@ -229,7 +230,6 @@ class DatasetBuilder: ``os.path.join(data_dir, "**")`` as `data_files`. For builders that require manual download, it must be the path to the local directory containing the manually downloaded data. - storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. name (`str`): Configuration name for the dataset. 
@@ -270,7 +270,6 @@ def __init__( repo_id: Optional[str] = None, data_files: Optional[Union[str, list, dict, DataFilesDict]] = None, data_dir: Optional[str] = None, - storage_options: Optional[dict] = None, name="deprecated", **config_kwargs, ): @@ -319,37 +318,30 @@ def __init__( self.info.features = features # Prepare data dirs: - # cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) - - fs_token_paths = fsspec.get_fs_token_paths( - cache_dir or os.path.expanduser(config.HF_DATASETS_CACHE), storage_options=storage_options - ) - self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] - - is_local = not is_remote_filesystem(self._fs) - path_join = os.path.join if is_local else posixpath.join - - self._cache_dir_root = fs_token_paths[2][0] if is_local else self._fs.unstrip_protocol(fs_token_paths[2][0]) + self._cache_dir_root = str(cache_dir) or os.path.expanduser(config.HF_DATASETS_CACHE) self._cache_dir = self._build_cache_dir() self._cache_downloaded_dir = ( - path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) - if cache_dir and is_local + os.path.join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) + if cache_dir else os.path.expanduser(config.DOWNLOADED_DATASETS_PATH) ) - if is_local: - os.makedirs(self._cache_dir_root, exist_ok=True) - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path) if is_local else contextlib.nullcontext(): - if self._fs.exists(self._cache_dir): # check if data exist - if len(self._fs.listdir(self._cache_dir)) > 0: + os.makedirs(self._cache_dir_root, exist_ok=True) + lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + with FileLock(lock_path): + if os.path.exists(self._cache_dir): # check if data exist + if len(os.listdir(self._cache_dir)) > 0: logger.info("Overwrite dataset info from restored data version.") - self.info = DatasetInfo.from_directory(self._cache_dir, fs=self._fs) + self.info = DatasetInfo.from_directory(self._cache_dir) else: # dir exists but no data, remove the empty dir as data aren't available anymore logger.warning( f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. 
" ) - self._fs.rmdir(self._cache_dir) + os.rmdir(self._cache_dir) + + # Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare + self._output_dir = self._cache_dir + self._fs: fsspec.AbstractFileSystem = fsspec.filesystem("file") # Set download manager self.dl_manager = None @@ -495,7 +487,7 @@ def builder_configs(cls): def cache_dir(self): return self._cache_dir - def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) -> str: + def _relative_data_dir(self, with_version=True, with_hash=True) -> str: """Relative path of this dataset in cache_dir: Will be: self.name/self.config.version/self.hash/ @@ -507,34 +499,27 @@ def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) - builder_data_dir = self.name if namespace is None else f"{namespace}___{self.name}" builder_config = self.config hash = self.hash - path_join = os.path.join if is_local else posixpath.join if builder_config: # use the enriched name instead of the name to make it unique - builder_data_dir = path_join(builder_data_dir, self.config_id) + builder_data_dir = os.path.join(builder_data_dir, self.config_id) if with_version: - builder_data_dir = path_join(builder_data_dir, str(self.config.version)) + builder_data_dir = os.path.join(builder_data_dir, str(self.config.version)) if with_hash and hash and isinstance(hash, str): - builder_data_dir = path_join(builder_data_dir, hash) + builder_data_dir = os.path.join(builder_data_dir, hash) return builder_data_dir def _build_cache_dir(self): """Return the data directory for the current version.""" - is_local = not is_remote_filesystem(self._fs) - path_join = os.path.join if is_local else posixpath.join - builder_data_dir = path_join( - self._cache_dir_root, self._relative_data_dir(with_version=False, is_local=is_local) - ) - version_data_dir = path_join( - self._cache_dir_root, self._relative_data_dir(with_version=True, is_local=is_local) - ) + builder_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=False)) + version_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=True)) def _other_versions(): """Returns previous versions on disk.""" - if not self._fs.exists(builder_data_dir): + if not os.path.exists(builder_data_dir): return [] version_dirnames = [] - for dir_name in self._fs.listdir(builder_data_dir, detail=False): + for dir_name in os.listdir(builder_data_dir): try: version_dirnames.append((utils.Version(dir_name), dir_name)) except ValueError: # Invalid version (ex: incomplete data dir) @@ -575,6 +560,7 @@ def get_imported_module_dir(cls): def download_and_prepare( self, + output_dir: Optional[str] = None, download_config: Optional[DownloadConfig] = None, download_mode: Optional[DownloadMode] = None, ignore_verifications: bool = False, @@ -583,11 +569,14 @@ def download_and_prepare( base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, file_format: Optional[str] = None, + storage_options: Optional[dict] = None, **download_and_prepare_kwargs, ): """Downloads and prepares dataset for reading. Args: + output_dir (:obj:`str`, optional): output directory for the dataset. + Default to this builder's ``cache_dir``, which is inside ~/.cache/huggingface/datasets by default. download_config (:class:`DownloadConfig`, optional): specific download configuration parameters. 
download_mode (:class:`DownloadMode`, optional): select the download/generate mode - Default to ``REUSE_DATASET_IF_EXISTS`` ignore_verifications (:obj:`bool`): Ignore the verifications of the downloaded/processed dataset information (checksums/size/splits/...) @@ -599,6 +588,7 @@ def download_and_prepare( If True, will get token from ~/.huggingface. file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. Example: @@ -609,6 +599,11 @@ def download_and_prepare( >>> ds = builder.download_and_prepare() ``` """ + self._output_dir = output_dir if output_dir is not None else self._cache_dir + # output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) + fs_token_paths = fsspec.get_fs_token_paths(self._output_dir, storage_options=storage_options) + self._fs = fs_token_paths[0] + download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications base_path = base_path if base_path is not None else self.base_path @@ -617,6 +612,15 @@ def download_and_prepare( if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") + if self._fs._strip_protocol(self._output_dir) == "": + # We don't support the root directory, because it has no dirname, + # and we need a dirname to use a .incomplete directory + # when the dataset is being written + raise RuntimeError( + f"Unable to download and prepare the dataset at the root {self._output_dir}. " + f"Please specify a subdirectory, e.g. '{self._output_dir + self.name}'" + ) + if dl_manager is None: if download_config is None: download_config = DownloadConfig( @@ -643,20 +647,28 @@ def download_and_prepare( # Prevent parallel disk operations if is_local: - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + lock_path = self._output_dir + "_builder.lock" + # File locking only with local paths; no file locking on GCS or S3 + self._fs.makedirs(os.path.dirname(self._output_dir), exist_ok=True) with FileLock(lock_path) if is_local else contextlib.nullcontext(): - data_exists = self._fs.exists(self._cache_dir) + + # Check if the data already exists + path_join = os.path.join if is_local else posixpath.join + data_exists = self._fs.exists(path_join(self._output_dir, config.DATASET_INFO_FILENAME)) if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS: - logger.warning(f"Found cached dataset {self.name} ({self._cache_dir})") + logger.warning(f"Found cached dataset {self.name} ({self._output_dir})") # We need to update the info in case some splits were added in the meantime # for example when calling load_dataset from multiple workers. 
self.info = self._load_info() self.download_post_processing_resources(dl_manager) return - logger.info(f"Generating dataset {self.name} ({self._cache_dir})") + + logger.info(f"Generating dataset {self.name} ({self._output_dir})") if is_local: # if cache dir is local, check for available space - if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root): + if not has_sufficient_disk_space( + self.info.size_in_bytes or 0, directory=Path(self._output_dir).parent + ): raise OSError( f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})" ) @@ -687,21 +699,22 @@ def incomplete_dir(dirname): f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} " f"(download: {size_str(self.info.download_size)}, generated: {size_str(self.info.dataset_size)}, " f"post-processed: {size_str(self.info.post_processing_size)}, " - f"total: {size_str(self.info.size_in_bytes)}) to {self._cache_dir}..." + f"total: {size_str(self.info.size_in_bytes)}) to {self._output_dir}..." ) else: - _dest = self._fs._strip_protocol(self._cache_dir) if is_local else self._cache_dir + _dest = self._fs._strip_protocol(self._output_dir) if is_local else self._output_dir print( f"Downloading and preparing dataset {self.info.builder_name}/{self.info.config_name} to {_dest}..." ) self._check_manual_download(dl_manager) - # Create a tmp dir and rename to self._cache_dir on successful exit. - with incomplete_dir(self._cache_dir) as tmp_data_dir: - # Temporarily assign _cache_dir to tmp_data_dir to avoid having to forward + # Create a tmp dir and rename to self._output_dir on successful exit. + with incomplete_dir(self._output_dir) as tmp_output_dir: + # Temporarily assign _output_dir to tmp_data_dir to avoid having to forward # it to every sub function. - with temporary_assignment(self, "_cache_dir", tmp_data_dir): + with temporary_assignment(self, "_output_dir", tmp_output_dir): + # Try to download the already prepared dataset files downloaded_from_gcs = False if try_from_hf_gcs: @@ -730,7 +743,7 @@ def incomplete_dir(dirname): self.download_post_processing_resources(dl_manager) print( - f"Dataset {self.name} downloaded and prepared to {self._cache_dir}. " + f"Dataset {self.name} downloaded and prepared to {self._output_dir}. " f"Subsequent calls will reuse this data." 
) @@ -749,10 +762,10 @@ def _check_manual_download(self, dl_manager): def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): relative_data_dir = self._relative_data_dir(with_version=True, with_hash=False) - reader = ArrowReader(self._cache_dir, self.info) + reader = ArrowReader(self._output_dir, self.info) # use reader instructions to download the right files reader.download_from_hf_gcs(download_config, relative_data_dir) - downloaded_info = DatasetInfo.from_directory(self._cache_dir) + downloaded_info = DatasetInfo.from_directory(self._output_dir) self.info.update(downloaded_info) # download post processing resources remote_cache_dir = HF_GCP_BASE_URL + "/" + relative_data_dir.replace(os.sep, "/") @@ -762,7 +775,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") try: resource_path = cached_path(remote_cache_dir + "/" + resource_file_name) - shutil.move(resource_path, os.path.join(self._cache_dir, resource_file_name)) + shutil.move(resource_path, os.path.join(self._output_dir, resource_file_name)) except ConnectionError: logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") logger.info("Dataset downloaded from Hf google storage.") @@ -837,7 +850,7 @@ def download_post_processing_resources(self, dl_manager): raise NotImplementedError(f"Post processing is not supported on filesystem {self._fs}") if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") - resource_path = os.path.join(self._cache_dir, resource_file_name) + resource_path = os.path.join(self._output_dir, resource_file_name) if not os.path.exists(resource_path): downloaded_resource_path = self._download_post_processing_resources( split, resource_name, dl_manager @@ -847,19 +860,19 @@ def download_post_processing_resources(self, dl_manager): shutil.move(downloaded_resource_path, resource_path) def _load_info(self) -> DatasetInfo: - return DatasetInfo.from_directory(self._cache_dir, fs=self._fs) + return DatasetInfo.from_directory(self._output_dir, fs=self._fs) def _save_info(self): is_local = not is_remote_filesystem(self._fs) if is_local: - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + lock_path = self._output_dir + "_info.lock" with FileLock(lock_path) if is_local else contextlib.nullcontext(): - self.info.write_to_directory(self._cache_dir, fs=self._fs) + self.info.write_to_directory(self._output_dir, fs=self._fs) def _save_infos(self): is_local = not is_remote_filesystem(self._fs) if is_local: - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + lock_path = self._output_dir + "_infos.lock" with FileLock(lock_path) if is_local else contextlib.nullcontext(): DatasetInfosDict(**{self.config.name: self.info}).write_to_directory(self.get_imported_module_dir()) @@ -901,14 +914,14 @@ def as_dataset( is_local = not is_remote_filesystem(self._fs) if not is_local: raise NotImplementedError(f"Loading a dataset cached in a {type(self._fs).__name__} is not supported.") - if not os.path.exists(self._cache_dir): + if not os.path.exists(self._output_dir): raise AssertionError( - f"Dataset {self.name}: could not find data in {self._cache_dir_root}. Please make sure to call " - "builder.download_and_prepare(), or pass download=True to " + f"Dataset {self.name}: could not find data in {self._output_dir}. 
Please make sure to call " + "builder.download_and_prepare(), or use " "datasets.load_dataset() before trying to access the Dataset object." ) - logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._cache_dir}') + logger.debug(f'Constructing Dataset for split {split or ", ".join(self.info.splits)}, from {self._output_dir}') # By default, return all splits if split is None: @@ -955,7 +968,7 @@ def _build_single_dataset( if os.sep in resource_file_name: raise ValueError(f"Resources shouldn't be in a sub-directory: {resource_file_name}") resources_paths = { - resource_name: os.path.join(self._cache_dir, resource_file_name) + resource_name: os.path.join(self._output_dir, resource_file_name) for resource_name, resource_file_name in self._post_processing_resources(split).items() } post_processed = self._post_process(ds, resources_paths) @@ -1014,7 +1027,7 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem Returns: `Dataset` """ - cache_dir = self._fs._strip_protocol(self._cache_dir) + cache_dir = self._fs._strip_protocol(self._output_dir) dataset_kwargs = ArrowReader(cache_dir, self.info).read( name=self.name, instructions=split, @@ -1233,7 +1246,7 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None file_format = file_format or "arrow" suffix = "-00000-of-00001" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" - fpath = path_join(self._cache_dir, fname) + fpath = path_join(self._output_dir, fname) generator = self._generate_examples(**split_generator.gen_kwargs) @@ -1316,7 +1329,7 @@ def _prepare_split(self, split_generator, file_format=None): file_format = file_format or "arrow" suffix = "-00000-of-00001" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" - fpath = path_join(self._cache_dir, fname) + fpath = path_join(self._output_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter @@ -1460,10 +1473,10 @@ def _save_info(self): fs = beam.io.filesystems.FileSystems path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join - with fs.create(path_join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f: + with fs.create(path_join(self._output_dir, config.DATASET_INFO_FILENAME)) as f: self.info._dump_info(f) if self.info.license: - with fs.create(path_join(self._cache_dir, config.LICENSE_FILENAME)) as f: + with fs.create(path_join(self._output_dir, config.LICENSE_FILENAME)) as f: self.info._dump_license(f) def _prepare_split(self, split_generator, pipeline, file_format=None): @@ -1474,9 +1487,9 @@ def _prepare_split(self, split_generator, pipeline, file_format=None): file_format = file_format or "arrow" fname = f"{self.name}-{split_name}.{file_format}" path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join - fpath = path_join(self._cache_dir, fname) + fpath = path_join(self._output_dir, fname) beam_writer = BeamWriter( - features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._cache_dir + features=self.info.features, path=fpath, namespace=split_name, cache_dir=self._output_dir ) self._beam_writers[split_name] = beam_writer From 2e8521650dd419e2f1513fed861fb85352729062 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 18:41:57 +0200 Subject: [PATCH 26/51] update tests --- tests/test_builder.py | 32 
++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index 8e2a13b4ae4..04ea1c87ccc 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -37,7 +37,7 @@ def _split_generators(self, dl_manager): def _prepare_split(self, split_generator, **kwargs): fname = f"{self.name}-{split_generator.name}.arrow" - with ArrowWriter(features=self.info.features, path=os.path.join(self._cache_dir, fname)) as writer: + with ArrowWriter(features=self.info.features, path=os.path.join(self._output_dir, fname)) as writer: writer.write_batch({"text": ["foo"] * 100}) num_examples, num_bytes = writer.finalize() split_generator.split_info.num_examples = num_examples @@ -915,29 +915,25 @@ def test_builder_config_version(builder_class, kwargs, tmp_path): assert builder.config.version == "2.0.0" -def test_builder_with_filesystem(mockfs): - builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) - assert builder.cache_dir.startswith("mock://") +def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) + assert builder._output_dir.startswith("mock://my_dataset") assert is_local_path(builder._cache_downloaded_dir) assert isinstance(builder._fs, type(mockfs)) assert builder._fs.storage_options == mockfs.storage_options + assert mockfs.exists("my_dataset/dataset_info.json") + assert mockfs.exists(f"my_dataset/{builder.name}-train.arrow") + assert not mockfs.exists("my_dataset.incomplete") -def test_builder_with_filesystem_download_and_prepare(mockfs): - builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) - builder.download_and_prepare() - assert mockfs.exists(f"{builder.name}/default/0.0.0/dataset_info.json") - assert mockfs.exists(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow") - assert not mockfs.exists(f"{builder.name}/default/0.0.0.incomplete") - - -def test_builder_with_filesystem_download_and_prepare_reload(mockfs, caplog): - builder = DummyGeneratorBasedBuilder(cache_dir="mock://", storage_options=mockfs.storage_options) - mockfs.makedirs(f"{builder.name}/default/0.0.0") - DatasetInfo().write_to_directory(f"{builder.name}/default/0.0.0", fs=mockfs) - mockfs.touch(f"{builder.name}/default/0.0.0/{builder.name}-train.arrow") +def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, caplog): + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) + mockfs.makedirs("my_dataset") + DatasetInfo().write_to_directory("my_dataset", fs=mockfs) + mockfs.touch(f"my_dataset/{builder.name}-train.arrow") caplog.clear() - builder.download_and_prepare() + builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) assert "Found cached dataset" in caplog.text From ba167dbd494a4df27138edb68e780553c36bff58 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 18:42:01 +0200 Subject: [PATCH 27/51] update docs --- docs/source/filesystems.mdx | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 2a41d7bc7be..583353f0bc4 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -105,32 +105,36 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever ### Load 
datasets into a cloud storage -You can load and cache a dataset into your cloud storage by specifying a remote `cache_dir` in `load_dataset`. +You can load a dataset into your cloud storage by specifying a remote `output_dir` in `download_and_prepare`. Don't forget to use the previously defined `storage_options` containing your credentials to write into a private cloud storage. +The `download_and_prepare` method works in two steps: +1. it first downloads the raw data files (if any) in your local cache. You can set your cache directory by passing `cache_dir` to [`load_dataset_builder`] +2. then it generates the dataset in Arrow or Parquet format in your cloud storage by iterating over the raw data files. + Load a dataset from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)): ```py ->>> cache_dir = "s3://my-bucket/datasets-cache" ->>> builder = load_dataset_builder("imdb", cache_dir=cache_dir, storage_options=storage_options) ->>> builder.download_and_prepare(file_format="parquet") +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("imdb") +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` Load a dataset using a loading script (see [how to load a local loading script](./loading#local-loading-script)): ```py ->>> cache_dir = "s3://my-bucket/datasets-cache" ->>> builder = load_dataset_builder("path/to/local/loading_script/loading_script.py", cache_dir=cache_dir, storage_options=storage_options) ->>> builder.download_and_prepare(file_format="parquet") +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("path/to/local/loading_script/loading_script.py") +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` Load your own data files (see [how to load local and remote files](./loading#local-and-remote-files)): ```py >>> data_files = {"train": ["path/to/train.csv"]} ->>> cache_dir = "s3://my-bucket/datasets-cache" ->>> builder = load_dataset_builder("csv", data_files=data_files, cache_dir=cache_dir, storage_options=storage_options) ->>> builder.download_and_prepare(file_format="parquet") +>>> output_dir = "s3://my-bucket/imdb" +>>> builder = load_dataset_builder("csv", data_files=data_files) +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. 
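To double-check what was written, you can list the output directory with the same credentials (a sketch; the bucket and file names are illustrative):

```py
>>> import s3fs
>>> fs = s3fs.S3FileSystem(**storage_options)
>>> fs.ls(output_dir)   # the generated Parquet shards plus dataset_info.json
>>> fs.du(output_dir)   # total size in bytes of the prepared dataset
```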
@@ -147,12 +151,12 @@ Therefore you can load a dataset saved as sharded Parquet files in Dask with ```py import dask.dataframe as dd -df = dd.read_parquet(builder.cache_dir, storage_options=storage_options) +df = dd.read_parquet(output_dir, storage_options=storage_options) # or if your dataset is split into train/valid/test -df_train = dd.read_parquet(builder.cache_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options) -df_valid = dd.read_parquet(builder.cache_dir + f"/{builder.name}-validation-*.parquet", storage_options=storage_options) -df_test = dd.read_parquet(builder.cache_dir + f"/{builder.name}-test-*.parquet", storage_options=storage_options) +df_train = dd.read_parquet(output_dir + f"/{builder.name}-train-*.parquet", storage_options=storage_options) +df_valid = dd.read_parquet(output_dir + f"/{builder.name}-validation-*.parquet", storage_options=storage_options) +df_test = dd.read_parquet(output_dir + f"/{builder.name}-test-*.parquet", storage_options=storage_options) ``` You can find more about dask dataframes in their [documentation](https://docs.dask.org/en/stable/dataframe.html). From c9b1ca9435536a22059caf931d59ab517b503d75 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 18:54:00 +0200 Subject: [PATCH 28/51] docs --- docs/source/filesystems.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index cdd3dd7226b..6614cf55cfa 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -143,7 +143,7 @@ Otherwize the dataset is saved as an uncompressed Arrow file. You can also specify the size of the Parquet shard using `max_shard_size` (default is 500MB): ```py ->>> builder.download_and_prepare(file_format="parquet", max_shard_size="1GB") +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet", max_shard_size="1GB") ``` #### Dask From a9379f8ab19ac59d85efc6f97ece435f20f88726 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 18:57:05 +0200 Subject: [PATCH 29/51] fix tests --- src/datasets/load.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 5c5dae0da08..a7cb698b727 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1464,7 +1464,6 @@ def load_dataset_builder( download_mode: Optional[DownloadMode] = None, revision: Optional[Union[str, Version]] = None, use_auth_token: Optional[Union[bool, str]] = None, - storage_options: Optional[dict] = None, **config_kwargs, ) -> DatasetBuilder: """Load a dataset builder from the Hugging Face Hub, or a local dataset. A dataset builder can be used to inspect general information that is required to build a dataset (cache directory, config, dataset info, etc.) @@ -1518,7 +1517,6 @@ def load_dataset_builder( You can specify a different version that the default "main" by using a commit sha or a git tag of the dataset repository. use_auth_token (``str`` or :obj:`bool`, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub. If True, will get token from `"~/.huggingface"`. - storage_options (:obj:`dict`, optional): Key/value pairs to be passed on to the caching file-system backend, if any. **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`. 
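A back-of-the-envelope sketch of how `max_shard_size` relates to the number of shards (the numbers are illustrative, and this assumes decimal megabytes; actual files are usually smaller because the limit is applied to the uncompressed Arrow data):

```py
# Illustrative numbers only
uncompressed_size = 2_500_000_000        # ~2.5 GB of Arrow data for one split
max_shard_size = 500 * 10**6             # "500MB"

num_shards = -(-uncompressed_size // max_shard_size)   # ceiling division
print(num_shards)                                       # 5 shards for this split
```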
@@ -1580,7 +1578,6 @@ def load_dataset_builder( hash=hash, features=features, use_auth_token=use_auth_token, - storage_options=storage_options, **builder_kwargs, **config_kwargs, ) From ec94a4b9f24aa2082062c8bca0f4898bf586c7ae Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 25 Aug 2022 19:08:06 +0200 Subject: [PATCH 30/51] fix tests --- tests/test_load.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_load.py b/tests/test_load.py index 7c49a2ef084..4436f8bd88c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -691,15 +691,6 @@ def test_load_dataset_builder_fail(): datasets.load_dataset_builder("blabla") -def test_load_dataset_builder_with_filesystem(dataset_loading_script_dir, data_dir, mockfs): - builder = datasets.load_dataset_builder( - dataset_loading_script_dir, data_dir=data_dir, cache_dir="mock://", storage_options=mockfs.storage_options - ) - assert builder.cache_dir.startswith("mock://") - assert isinstance(builder._fs, type(mockfs)) - assert builder._fs.storage_options == mockfs.storage_options - - @pytest.mark.parametrize("keep_in_memory", [False, True]) def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory, caplog): with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): From f47871aca0344d0ce66cee0c67b287486604260b Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 12:16:37 +0200 Subject: [PATCH 31/51] fix output parent dir creattion --- src/datasets/builder.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index a5a200be5e0..c12ff19cb50 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -645,12 +645,13 @@ def download_and_prepare( try_from_hf_gcs = False self.dl_manager = dl_manager - # Prevent parallel disk operations + # Prevent parallel local disk operations if is_local: + # Create parent directory of the output_dir to put the lock file in there + Path(self._output_dir).parent.mkdir(parents=True, exist_ok=True) lock_path = self._output_dir + "_builder.lock" # File locking only with local paths; no file locking on GCS or S3 - self._fs.makedirs(os.path.dirname(self._output_dir), exist_ok=True) with FileLock(lock_path) if is_local else contextlib.nullcontext(): # Check if the data already exists From 460e1a6fb2f108f55f54ff703b961fb086379a7c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Fri, 26 Aug 2022 18:10:51 +0200 Subject: [PATCH 32/51] Apply suggestions from code review Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- docs/source/filesystems.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 583353f0bc4..333c167ca16 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -31,7 +31,7 @@ To use an anonymous connection, use `anon=True`. Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever you are interacting with a private S3 bucket. 
```py ->>> storage_options = {"anon": True} # for anynonous connection +>>> storage_options = {"anon": True} # for anonymous connection # or use your credentials >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} # for private buckets # or use a botocore session @@ -138,7 +138,7 @@ Load your own data files (see [how to load local and remote files](./loading#loc ``` It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. -Otherwize the dataset is saved as an uncompressed Arrow file. +Otherwise the dataset is saved as an uncompressed Arrow file. #### Dask From 88daa8a8e9d96a65e81a2587571cc44e06f80f8f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 18:26:27 +0200 Subject: [PATCH 33/51] revert changes for remote cache_dir --- src/datasets/builder.py | 129 ++++++++++++++++++++++------------------ 1 file changed, 71 insertions(+), 58 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index c12ff19cb50..8d517cf5206 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -31,6 +31,7 @@ from typing import Dict, Mapping, Optional, Tuple, Union import fsspec +from hamcrest import is_ from . import config, utils from .arrow_dataset import Dataset @@ -58,7 +59,7 @@ from .splits import Split, SplitDict, SplitGenerator from .streaming import extend_dataset_builder_for_streaming from .utils import logging -from .utils.file_utils import cached_path +from .utils.file_utils import cached_path, is_remote_url from .utils.filelock import FileLock from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( @@ -318,26 +319,36 @@ def __init__( self.info.features = features # Prepare data dirs: - self._cache_dir_root = str(cache_dir) or os.path.expanduser(config.HF_DATASETS_CACHE) - self._cache_dir = self._build_cache_dir() + # cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) + self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE) + self._cache_dir_root = ( + self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root) + ) + path_join = posixpath.join if is_remote_url(self._cache_dir_root) else os.path.join self._cache_downloaded_dir = ( - os.path.join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) + path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR) if cache_dir - else os.path.expanduser(config.DOWNLOADED_DATASETS_PATH) + else str(config.DOWNLOADED_DATASETS_PATH) ) - - os.makedirs(self._cache_dir_root, exist_ok=True) - lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") - with FileLock(lock_path): - if os.path.exists(self._cache_dir): # check if data exist - if len(os.listdir(self._cache_dir)) > 0: - logger.info("Overwrite dataset info from restored data version.") - self.info = DatasetInfo.from_directory(self._cache_dir) - else: # dir exists but no data, remove the empty dir as data aren't available anymore - logger.warning( - f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. 
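One way to see the effect of `file_format` is to prepare the same dataset twice, once per format, and compare the footprint in the bucket (a sketch only; the output directories are illustrative and `fs` is the filesystem instance created earlier):

```py
>>> builder.download_and_prepare("s3://my-bucket/imdb-parquet", storage_options=storage_options, file_format="parquet")
>>> builder.download_and_prepare("s3://my-bucket/imdb-arrow", storage_options=storage_options, file_format="arrow")
>>> fs.du("s3://my-bucket/imdb-parquet"), fs.du("s3://my-bucket/imdb-arrow")  # compressed Parquet is typically much smaller
```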
" - ) - os.rmdir(self._cache_dir) + self._cache_downloaded_dir = ( + self._cache_downloaded_dir + if is_remote_url(self._cache_downloaded_dir) + else os.path.expanduser(self._cache_downloaded_dir) + ) + self._cache_dir = self._build_cache_dir() + if not is_remote_url(self._cache_dir_root): + os.makedirs(self._cache_dir_root, exist_ok=True) + lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock") + with FileLock(lock_path): + if os.path.exists(self._cache_dir): # check if data exist + if len(os.listdir(self._cache_dir)) > 0: + logger.info("Overwrite dataset info from restored data version.") + self.info = DatasetInfo.from_directory(self._cache_dir) + else: # dir exists but no data, remove the empty dir as data aren't available anymore + logger.warning( + f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. " + ) + os.rmdir(self._cache_dir) # Store in the cache by default unless the user specifies a custom output_dir to download_and_prepare self._output_dir = self._cache_dir @@ -487,7 +498,7 @@ def builder_configs(cls): def cache_dir(self): return self._cache_dir - def _relative_data_dir(self, with_version=True, with_hash=True) -> str: + def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) -> str: """Relative path of this dataset in cache_dir: Will be: self.name/self.config.version/self.hash/ @@ -499,21 +510,28 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str: builder_data_dir = self.name if namespace is None else f"{namespace}___{self.name}" builder_config = self.config hash = self.hash + path_join = os.path.join if is_local else posixpath.join if builder_config: # use the enriched name instead of the name to make it unique - builder_data_dir = os.path.join(builder_data_dir, self.config_id) + builder_data_dir = path_join(builder_data_dir, self.config_id) if with_version: - builder_data_dir = os.path.join(builder_data_dir, str(self.config.version)) + builder_data_dir = path_join(builder_data_dir, str(self.config.version)) if with_hash and hash and isinstance(hash, str): - builder_data_dir = os.path.join(builder_data_dir, hash) + builder_data_dir = path_join(builder_data_dir, hash) return builder_data_dir def _build_cache_dir(self): """Return the data directory for the current version.""" - builder_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=False)) - version_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=True)) + is_local = not is_remote_url(self._cache_dir_root) + path_join = os.path.join if is_local else posixpath.join + builder_data_dir = path_join( + self._cache_dir_root, self._relative_data_dir(with_version=False, is_local=is_local) + ) + version_data_dir = path_join( + self._cache_dir_root, self._relative_data_dir(with_version=True, is_local=is_local) + ) - def _other_versions(): + def _other_versions_on_disk(): """Returns previous versions on disk.""" if not os.path.exists(builder_data_dir): return [] @@ -528,16 +546,17 @@ def _other_versions(): return version_dirnames # Check and warn if other versions exist - version_dirs = _other_versions() - if version_dirs: - other_version = version_dirs[0][0] - if other_version != self.config.version: - warn_msg = ( - f"Found a different version {str(other_version)} of dataset {self.name} in " - f"cache_dir {self._cache_dir_root}. Using currently defined version " - f"{str(self.config.version)}." 
- ) - logger.warning(warn_msg) + if not is_remote_url(builder_data_dir): + version_dirs = _other_versions_on_disk() + if version_dirs: + other_version = version_dirs[0][0] + if other_version != self.config.version: + warn_msg = ( + f"Found a different version {str(other_version)} of dataset {self.name} in " + f"cache_dir {self._cache_dir_root}. Using currently defined version " + f"{str(self.config.version)}." + ) + logger.warning(warn_msg) return version_data_dir @@ -612,15 +631,6 @@ def download_and_prepare( if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") - if self._fs._strip_protocol(self._output_dir) == "": - # We don't support the root directory, because it has no dirname, - # and we need a dirname to use a .incomplete directory - # when the dataset is being written - raise RuntimeError( - f"Unable to download and prepare the dataset at the root {self._output_dir}. " - f"Please specify a subdirectory, e.g. '{self._output_dir + self.name}'" - ) - if dl_manager is None: if download_config is None: download_config = DownloadConfig( @@ -677,20 +687,23 @@ def download_and_prepare( @contextlib.contextmanager def incomplete_dir(dirname): """Create temporary dir for dirname and rename on exit.""" - tmp_dir = dirname + ".incomplete" - self._fs.makedirs(tmp_dir, exist_ok=True) - try: - yield tmp_dir - if self._fs.isdir(dirname): - self._fs.rm(dirname, recursive=True) - if is_local: - # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - shutil.move(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) - else: - self._fs.mv(tmp_dir, dirname, recursive=True) - finally: - if self._fs.exists(tmp_dir): - self._fs.rm(tmp_dir, recursive=True) + if not is_local: + yield dirname + else: + tmp_dir = dirname + ".incomplete" + self._fs.makedirs(tmp_dir, exist_ok=True) + try: + yield tmp_dir + if self._fs.isdir(dirname): + self._fs.rm(dirname, recursive=True) + if is_local: + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory + shutil.move(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) + else: + self._fs.mv(tmp_dir, dirname, recursive=True) + finally: + if self._fs.exists(tmp_dir): + self._fs.rm(tmp_dir, recursive=True) # Print is intentional: we want this to always go to stdout so user has # information needed to cancel download/preparation if needed. From fdf7252c3f26aa98409588f971b435dd1892178e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 18:30:13 +0200 Subject: [PATCH 34/51] fix wording in the docs: load -> download and prepare --- docs/source/filesystems.mdx | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index 333c167ca16..f2da377568f 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -40,7 +40,7 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever >>> storage_options = {"session": s3_session} ``` -3. Load your FileSystem instance +3. Create your FileSystem instance ```py >>> import s3fs @@ -67,7 +67,7 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever >>> storage_options={"project": "my-google-project", "token": TOKEN} ``` -3. Load your FileSystem instance +3. 
Create your FileSystem instance ```py >>> import gcsfs @@ -94,7 +94,7 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever >>> storage_options={"tenant_id": TENANT_ID, "client_id": CLIENT_ID, "client_secret": CLIENT_SECRET} ``` -3. Load your FileSystem instance +3. Create your FileSystem instance ```py >>> import adlfs @@ -103,16 +103,16 @@ Otherwise, include your `aws_access_key_id` and `aws_secret_access_key` whenever ## Load and Save your datasets using your cloud storage FileSystem -### Load datasets into a cloud storage +### Download and prepare a dataset into a cloud storage -You can load a dataset into your cloud storage by specifying a remote `output_dir` in `download_and_prepare`. +You can download and prepare a dataset into your cloud storage by specifying a remote `output_dir` in `download_and_prepare`. Don't forget to use the previously defined `storage_options` containing your credentials to write into a private cloud storage. The `download_and_prepare` method works in two steps: 1. it first downloads the raw data files (if any) in your local cache. You can set your cache directory by passing `cache_dir` to [`load_dataset_builder`] 2. then it generates the dataset in Arrow or Parquet format in your cloud storage by iterating over the raw data files. -Load a dataset from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)): +Load a dataset builder from the Hugging Face Hub (see [how to load from the Hugging Face Hub](./loading#hugging-face-hub)): ```py >>> output_dir = "s3://my-bucket/imdb" @@ -120,7 +120,7 @@ Load a dataset from the Hugging Face Hub (see [how to load from the Hugging Face >>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` -Load a dataset using a loading script (see [how to load a local loading script](./loading#local-loading-script)): +Load a dataset builder using a loading script (see [how to load a local loading script](./loading#local-loading-script)): ```py >>> output_dir = "s3://my-bucket/imdb" @@ -128,7 +128,7 @@ Load a dataset using a loading script (see [how to load a local loading script]( >>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet") ``` -Load your own data files (see [how to load local and remote files](./loading#local-and-remote-files)): +Use your own data files (see [how to load local and remote files](./loading#local-and-remote-files)): ```py >>> data_files = {"train": ["path/to/train.csv"]} From 22aaf7b613127fbc89e527f6419813707b6e724e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 18:32:46 +0200 Subject: [PATCH 35/51] style --- src/datasets/builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 8d517cf5206..afb0613c50b 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -31,7 +31,6 @@ from typing import Dict, Mapping, Optional, Tuple, Union import fsspec -from hamcrest import is_ from . 
import config, utils from .arrow_dataset import Dataset From c051b31ed3b5237f1ceb060e8a4c0b5bed1b8235 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 18:48:15 +0200 Subject: [PATCH 36/51] fix --- src/datasets/builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index afb0613c50b..b10ea9785fe 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -687,6 +687,7 @@ def download_and_prepare( def incomplete_dir(dirname): """Create temporary dir for dirname and rename on exit.""" if not is_local: + self._fs.makedirs(dirname, exist_ok=True) yield dirname else: tmp_dir = dirname + ".incomplete" From e0a7742569491ecabd6e695ccd3ce6704950dc13 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 26 Aug 2022 19:31:11 +0200 Subject: [PATCH 37/51] simplify incomplete_dir --- src/datasets/builder.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index b10ea9785fe..2f0290d4254 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -691,19 +691,16 @@ def incomplete_dir(dirname): yield dirname else: tmp_dir = dirname + ".incomplete" - self._fs.makedirs(tmp_dir, exist_ok=True) + os.makedirs(tmp_dir, exist_ok=True) try: yield tmp_dir - if self._fs.isdir(dirname): - self._fs.rm(dirname, recursive=True) - if is_local: - # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory - shutil.move(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) - else: - self._fs.mv(tmp_dir, dirname, recursive=True) + if os.path.isdir(dirname): + shutil.rmtree(dirname, recursive=True) + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory + shutil.move(tmp_dir, dirname) finally: - if self._fs.exists(tmp_dir): - self._fs.rm(tmp_dir, recursive=True) + if os.path.exists(tmp_dir): + shutil.rmtree(tmp_dir, recursive=True) # Print is intentional: we want this to always go to stdout so user has # information needed to cancel download/preparation if needed. From 53d46cc87366505b157bb73c4a842f069135a67c Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 29 Aug 2022 11:43:09 +0200 Subject: [PATCH 38/51] fix tests --- src/datasets/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 2f0290d4254..078f6771488 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -695,12 +695,12 @@ def incomplete_dir(dirname): try: yield tmp_dir if os.path.isdir(dirname): - shutil.rmtree(dirname, recursive=True) + shutil.rmtree(dirname) # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory shutil.move(tmp_dir, dirname) finally: if os.path.exists(tmp_dir): - shutil.rmtree(tmp_dir, recursive=True) + shutil.rmtree(tmp_dir) # Print is intentional: we want this to always go to stdout so user has # information needed to cancel download/preparation if needed. 
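
For reference, a minimal, self-contained sketch of the `.incomplete`-directory pattern that the `incomplete_dir` changes above converge on: write into a temporary sibling directory and move it into place only if preparation succeeds, so a half-written dataset directory is never left behind. This is an illustrative standalone helper, not the builder's exact code; the `is_local` flag and the remote branch that simply yields the target directory mirror the behaviour added in these patches.

```python
import contextlib
import os
import shutil


@contextlib.contextmanager
def incomplete_dir(dirname: str, is_local: bool = True):
    """Yield a directory to write into; publish it by renaming on success (local only)."""
    if not is_local:
        # Remote filesystems: write in place; directory creation is handled by the caller's fs.
        yield dirname
        return
    tmp_dir = dirname + ".incomplete"
    os.makedirs(tmp_dir, exist_ok=True)
    try:
        yield tmp_dir
        if os.path.isdir(dirname):
            shutil.rmtree(dirname)  # drop any previously prepared version
        # renaming a local directory is cheaper than copy + delete
        shutil.move(tmp_dir, dirname)
    finally:
        if os.path.exists(tmp_dir):
            shutil.rmtree(tmp_dir)  # only still present if preparation was interrupted
```

On success the temporary directory has already been moved into place, so the `finally` cleanup only removes it when an exception interrupted the preparation.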
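
The `is_local` flag used above comes from the same fsspec resolution the builder applies to `cache_dir` and `output_dir`. A rough sketch of that resolution, assuming only public fsspec APIs (`get_fs_token_paths`, `unstrip_protocol`); the function name and return shape are illustrative, not the library's actual interface.

```python
from typing import Optional, Tuple

import fsspec
from fsspec.implementations.local import LocalFileSystem


def resolve_path(path: str, storage_options: Optional[dict] = None) -> Tuple[fsspec.AbstractFileSystem, str, bool]:
    """Resolve a local path or a URL such as "s3://bucket/dir" into (filesystem, normalized path, is_local)."""
    fs, _, paths = fsspec.get_fs_token_paths(path, storage_options=storage_options)
    is_local = isinstance(fs, LocalFileSystem)
    # keep the protocol for remote paths so they remain usable as URLs; local paths stay plain
    resolved = paths[0] if is_local else fs.unstrip_protocol(paths[0])
    return fs, resolved, is_local
```

Note that fsspec expands relative local paths to absolute ones, which is why the `output_dir` tests later in this series compare against `Path(output_dir).resolve().as_posix()`.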
From 606951f4755720877db42fd67f9c2e98c00bae84 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 5 Sep 2022 18:58:28 +0200 Subject: [PATCH 39/51] albert's comments --- src/datasets/builder.py | 6 ++++++ src/datasets/info.py | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 078f6771488..955078dff27 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -595,6 +595,8 @@ def download_and_prepare( Args: output_dir (:obj:`str`, optional): output directory for the dataset. Default to this builder's ``cache_dir``, which is inside ~/.cache/huggingface/datasets by default. + + download_config (:class:`DownloadConfig`, optional): specific download configuration parameters. download_mode (:class:`DownloadMode`, optional): select the download/generate mode - Default to ``REUSE_DATASET_IF_EXISTS`` ignore_verifications (:obj:`bool`): Ignore the verifications of the downloaded/processed dataset information (checksums/size/splits/...) @@ -606,7 +608,11 @@ def download_and_prepare( If True, will get token from ~/.huggingface. file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. + + **download_and_prepare_kwargs (additional keyword arguments): Keyword arguments. Example: diff --git a/src/datasets/info.py b/src/datasets/info.py index 6969e1afc79..2b5acfdde92 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -186,6 +186,10 @@ def write_to_directory(self, dataset_info_dir, pretty_print=False, fs=None): Args: dataset_info_dir (str): Destination directory. pretty_print (bool, default ``False``): If True, the JSON will be pretty-printed with the indent level of 4. + fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): + Instance of the remote filesystem used to download the files from. + + Example: @@ -255,6 +259,10 @@ def from_directory(cls, dataset_info_dir: str, fs=None) -> "DatasetInfo": Args: dataset_info_dir (`str`): The directory containing the metadata file. This should be the root directory of a specific dataset version. + fs (``fsspec.spec.AbstractFileSystem``, optional, defaults ``None``): + Instance of the remote filesystem used to download the files from. 
+ + Example: From bd94afbd8f32dc267741212f3bb173b5448ab805 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 5 Sep 2022 19:11:45 +0200 Subject: [PATCH 40/51] set arrow to default --- src/datasets/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 87a0fd04a1a..9e0aaaf6472 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -596,7 +596,7 @@ def download_and_prepare( dl_manager: Optional[DownloadManager] = None, base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, - file_format: Optional[str] = None, + file_format: str = "arrow", max_shard_size: Optional[int] = None, storage_options: Optional[dict] = None, **download_and_prepare_kwargs, From 0ea79c1d379b59b5d2f3e1c5b78b7e57b1688166 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 5 Sep 2022 19:59:10 +0200 Subject: [PATCH 41/51] style --- src/datasets/builder.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 80ba4b7dbe2..3f455835c1a 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1621,10 +1621,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwarg options=beam_options, ) super()._download_and_prepare( - dl_manager, - verify_infos=False, - pipeline=pipeline, - **prepare_splits_kwargs + dl_manager, verify_infos=False, pipeline=pipeline, **prepare_splits_kwargs ) # TODO handle verify_infos in beam datasets # Run pipeline pipeline_results = pipeline.run() From 0a950fdce79aae892a462a5f5338f2a8cdf3c07a Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 15:31:44 +0200 Subject: [PATCH 42/51] add config.MAX_SHARD_SIZE --- src/datasets/arrow_dataset.py | 6 +++--- src/datasets/builder.py | 20 +++++++++--------- src/datasets/config.py | 3 +++ src/datasets/dataset_dict.py | 2 +- src/datasets/utils/py_utils.py | 2 +- tests/test_builder.py | 37 ++++++++++++++++++++++++++++++++++ tests/test_upstream_hub.py | 29 ++++++++++++++++++++++++++ 7 files changed, 85 insertions(+), 14 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index bf916558c36..fab285f54d8 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4063,7 +4063,7 @@ def _push_parquet_shards_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[str] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, embed_external_files: bool = True, ) -> Tuple[str, str, int, int]: """Pushes the dataset to the hub. 
@@ -4109,7 +4109,7 @@ def _push_parquet_shards_to_hub( >>> dataset.push_to_hub("/", split="evaluation") ``` """ - max_shard_size = convert_file_size_to_int(max_shard_size) + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) api = HfApi(endpoint=config.HF_ENDPOINT) token = token if token is not None else HfFolder.get_token() @@ -4278,7 +4278,7 @@ def push_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[str] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, shard_size: Optional[int] = "deprecated", embed_external_files: bool = True, ): diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 3f455835c1a..39f0fe28ec1 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -622,7 +622,7 @@ def download_and_prepare( max_shard_size (:obj:`Union[str, int]`, optional): Maximum number of bytes written per shard. - Supports only the "parquet" format with a default of "500MB". The size is based on uncompressed data size, + Only available for the "parquet" format with a default of "500MB". The size is based on uncompressed data size, so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. @@ -654,8 +654,8 @@ def download_and_prepare( ```py >>> from datasets import load_dataset_builder >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} - >>> builder = load_dataset_builder("rotten_tomatoes", cache_dir="s3://my-bucket/datasets-cache", storage_options=storage_options) - >>> ds = builder.download_and_prepare(file_format="parquet") + >>> builder = load_dataset_builder("rotten_tomatoes") + >>> ds = builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet") ``` """ self._output_dir = output_dir if output_dir is not None else self._cache_dir @@ -1327,12 +1327,13 @@ def _prepare_split( is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - if max_shard_size is not None: - max_shard_size = convert_file_size_to_int(max_shard_size) - if file_format == "arrow": + if file_format == "arrow": + if max_shard_size is not None: raise NotImplementedError( "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." ) + else: + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) if self.info.splits is not None: split_info = self.info.splits[split_generator.name] @@ -1454,12 +1455,13 @@ def _prepare_split( is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - if max_shard_size is not None: - if file_format == "arrow": + if file_format == "arrow": + if max_shard_size is not None: raise NotImplementedError( "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." 
) - max_shard_size = convert_file_size_to_int(max_shard_size or "500MB") + else: + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" diff --git a/src/datasets/config.py b/src/datasets/config.py index 87b321441cc..2bd5419cbe3 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -172,6 +172,9 @@ # For big tables, we write them on disk instead MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30 +# Max shard size in bytes (e.g. to shard parquet datasets in push_to_hub or download_and_prepare) +MAX_SHARD_SIZE = "500MB" + # Offline mode HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 16ca8dfe208..6ab8a2163c1 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1286,7 +1286,7 @@ def push_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[None] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, shard_size: Optional[int] = "deprecated", embed_external_files: bool = True, ): diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index f56b765a7c2..01d0673a2f6 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -114,7 +114,7 @@ def convert_file_size_to_int(size: Union[int, str]) -> int: if size.upper().endswith("KB"): int_size = int(size[:-2]) * (10**3) return int_size // 8 if size.endswith("b") else int_size - raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") + raise ValueError(f"`size={size}` is not in a valid format. 
Use an integer followed by the unit, e.g., '5GB'.") def string_to_dict(string: str, pattern: str) -> Dict[str, str]: diff --git a/tests/test_builder.py b/tests/test_builder.py index d7dccfd13e6..c535feb9c11 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -959,6 +959,25 @@ def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + writer_batch_size = 25 + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) + with patch("datasets.config.MAX_SHARD_SIZE", 1): # one batch per shard + builder.download_and_prepare(file_format="parquet") + expected_num_shards = 100 // writer_batch_size + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + +def test_generator_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path): writer_batch_size = 25 builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one batch per shard @@ -988,6 +1007,24 @@ def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + with patch("datasets.config.MAX_SHARD_SIZE", 1): # one batch per shard + builder.download_and_prepare(file_format="parquet") + expected_num_shards = 10 + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + +def test_arrow_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one table per shard expected_num_shards = 10 diff --git a/tests/test_upstream_hub.py b/tests/test_upstream_hub.py index 0503a87b08e..5833b367af7 100644 --- a/tests/test_upstream_hub.py +++ b/tests/test_upstream_hub.py @@ -131,6 +131,35 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo): local_ds = DatasetDict({"train": ds}) + with temporary_repo(f"{CI_HUB_USER}/test-{int(time.time() * 10e3)}") as ds_name: + with patch("datasets.config.MAX_SHARD_SIZE", "16KB"): + local_ds.push_to_hub(ds_name, token=self._token) + hub_ds = load_dataset(ds_name, download_mode="force_redownload") + + assert local_ds.column_names == hub_ds.column_names + assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys()) + assert local_ds["train"].features == 
hub_ds["train"].features + + # Ensure that there are two files on the repository that have the correct name + files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) + assert all( + fnmatch.fnmatch(file, expected_file) + for file, expected_file in zip( + files, + [ + ".gitattributes", + "data/train-00000-of-00002-*.parquet", + "data/train-00001-of-00002-*.parquet", + "dataset_infos.json", + ], + ) + ) + + def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, temporary_repo): + ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) + + local_ds = DatasetDict({"train": ds}) + with temporary_repo(f"{CI_HUB_USER}/test-{int(time.time() * 10e3)}") as ds_name: local_ds.push_to_hub(ds_name, token=self._token, max_shard_size="16KB") hub_ds = load_dataset(ds_name, download_mode="force_redownload") From fbc8fe19df5ecf98f11592a481041e40b1388c4d Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 15:36:06 +0200 Subject: [PATCH 43/51] nit --- src/datasets/builder.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 39f0fe28ec1..2d9a29984ae 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -645,8 +645,8 @@ def download_and_prepare( ```py >>> from datasets import load_dataset_builder - >>> builder = load_dataset_builder("rotten_tomatoes", cache_dir="path/to/local/datasets-cache") - >>> ds = builder.download_and_prepare(file_format="parquet") + >>> builder = load_dataset_builder("rotten_tomatoes") + >>> ds = builder.download_and_prepare("./output_dir", file_format="parquet") ``` Downdload and prepare the dataset as sharded Parquet files in a cloud storage @@ -1235,7 +1235,7 @@ def _prepare_split( self, split_generator: SplitGenerator, file_format: str = "arrow", - max_shard_size: Union[None, str, int] = None, + max_shard_size: Optional[Union[str, int]] = None, **kwargs, ): """Generate the examples and record them on disk. @@ -1245,8 +1245,8 @@ def _prepare_split( file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. max_shard_size (:obj:`Union[str, int]`, optional): Approximate maximum number of bytes written per shard. - Supports only the "parquet" format with a default of "500MB". The size is computed using the uncompressed data, - so in practice your shard files may be smaller than `max_shard_size` thanks to compression. + Only available for the "parquet" format with a default of "500MB". The size is based on uncompressed data size, + so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. 
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1322,7 +1322,7 @@ def _prepare_split( split_generator: SplitGenerator, check_duplicate_keys: bool, file_format="arrow", - max_shard_size: Union[None, int, str] = None, + max_shard_size: Optional[Union[int, str]] = None, ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join @@ -1450,7 +1450,7 @@ def _generate_tables(self, **kwargs): raise NotImplementedError() def _prepare_split( - self, split_generator: SplitGenerator, file_format: str = "arrow", max_shard_size: Union[None, str, int] = None + self, split_generator: SplitGenerator, file_format: str = "arrow", max_shard_size: Optional[Union[str, int]] = None ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join @@ -1650,7 +1650,7 @@ def _save_info(self): self.info._dump_license(f) def _prepare_split( - self, split_generator, pipeline, file_format="arrow", max_shard_size: Union[None, str, int] = None + self, split_generator, pipeline, file_format="arrow", max_shard_size: Optional[Union[str, int]] = None ): import apache_beam as beam From c88a797b0a8b0dceefc07d991e6d13b7886b1985 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 15:36:30 +0200 Subject: [PATCH 44/51] style --- src/datasets/builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 2d9a29984ae..2c2175765f6 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1450,7 +1450,10 @@ def _generate_tables(self, **kwargs): raise NotImplementedError() def _prepare_split( - self, split_generator: SplitGenerator, file_format: str = "arrow", max_shard_size: Optional[Union[str, int]] = None + self, + split_generator: SplitGenerator, + file_format: str = "arrow", + max_shard_size: Optional[Union[str, int]] = None, ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join From d59b68f27da94df845bda35cb4b9884e3bb170f5 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 16:17:36 +0200 Subject: [PATCH 45/51] fix for relative output_dir --- src/datasets/builder.py | 15 ++++++++------- tests/test_builder.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 2c2175765f6..da09eaba156 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -658,22 +658,23 @@ def download_and_prepare( >>> ds = builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet") ``` """ - self._output_dir = output_dir if output_dir is not None else self._cache_dir + output_dir = output_dir if output_dir is not None else self._cache_dir # output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) - fs_token_paths = fsspec.get_fs_token_paths(self._output_dir, storage_options=storage_options) - self._fs = fs_token_paths[0] + fs_token_paths = fsspec.get_fs_token_paths(output_dir, storage_options=storage_options) + self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + is_local = not is_remote_filesystem(self._fs) + self._output_dir = self._fs._strip_protocol(fs_token_paths[2][0]) if is_local else fs_token_paths[2][0] download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications 
base_path = base_path if base_path is not None else self.base_path - is_local = not is_remote_filesystem(self._fs) if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") if file_format == "arrow" and max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." ) if self._fs._strip_protocol(self._output_dir) == "": @@ -1330,7 +1331,7 @@ def _prepare_split( if file_format == "arrow": if max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." ) else: max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) @@ -1461,7 +1462,7 @@ def _prepare_split( if file_format == "arrow": if max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." ) else: max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) diff --git a/tests/test_builder.py b/tests/test_builder.py index c535feb9c11..a3144b3b0ec 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -25,7 +25,13 @@ from datasets.streaming import xjoin from datasets.utils.file_utils import is_local_path -from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss +from .utils import ( + assert_arrow_memory_doesnt_increase, + assert_arrow_memory_increases, + require_beam, + require_faiss, + set_current_working_directory_to_temp_dir, +) class DummyBuilder(DatasetBuilder): @@ -925,6 +931,27 @@ def test_builder_config_version(builder_class, kwargs, tmp_path): assert builder.config.version == "2.0.0" +def test_builder_download_and_prepare_with_absolute_output_dir(tmp_path): + builder = DummyGeneratorBasedBuilder() + output_dir = str(tmp_path) + builder.download_and_prepare(output_dir) + assert builder._output_dir.startswith(output_dir) + assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) + assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) + assert not os.path.exists(os.path.join(output_dir + ".incomplete")) + + +def test_builder_download_and_prepare_with_relative_output_dir(): + with set_current_working_directory_to_temp_dir(): + builder = DummyGeneratorBasedBuilder() + output_dir = "test-out" + builder.download_and_prepare(output_dir) + assert builder._output_dir.startswith(str(Path(output_dir).resolve())) + assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) + assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) + assert not os.path.exists(os.path.join(output_dir + ".incomplete")) + + def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs): builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) From 211b38b55875812ef70041f194d1c8e77dd6fc8f Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 16:18:26 +0200 Subject: 
[PATCH 46/51] typo --- src/datasets/builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index da09eaba156..3f873614985 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -674,7 +674,7 @@ def download_and_prepare( if file_format == "arrow" and max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." ) if self._fs._strip_protocol(self._output_dir) == "": @@ -1331,7 +1331,7 @@ def _prepare_split( if file_format == "arrow": if max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." ) else: max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) @@ -1462,7 +1462,7 @@ def _prepare_split( if file_format == "arrow": if max_shard_size is not None: raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='paquet'." + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." ) else: max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) From 361e32a7b0c736896a53b581f6775c92271be066 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 6 Sep 2022 17:18:17 +0200 Subject: [PATCH 47/51] fix test --- src/datasets/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 3f873614985..19eb6cb4e4c 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -663,7 +663,7 @@ def download_and_prepare( fs_token_paths = fsspec.get_fs_token_paths(output_dir, storage_options=storage_options) self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] is_local = not is_remote_filesystem(self._fs) - self._output_dir = self._fs._strip_protocol(fs_token_paths[2][0]) if is_local else fs_token_paths[2][0] + self._output_dir = fs_token_paths[2][0] if is_local else self._fs.unstrip_protocol(fs_token_paths[2][0]) download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications From aa146eadbf81fd667be64a9d73d9cdb138456ee8 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 9 Sep 2022 18:21:40 +0200 Subject: [PATCH 48/51] fix test --- tests/test_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index a3144b3b0ec..01a9a219997 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -946,7 +946,7 @@ def test_builder_download_and_prepare_with_relative_output_dir(): builder = DummyGeneratorBasedBuilder() output_dir = "test-out" builder.download_and_prepare(output_dir) - assert builder._output_dir.startswith(str(Path(output_dir).resolve())) + assert builder._output_dir.startswith(str(Path(output_dir).resolve().as_posix())) assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) assert not os.path.exists(os.path.join(output_dir + ".incomplete")) From 11bd133424071d908b5b8b3ca57fc53c27a272f5 Mon Sep 17 
00:00:00 2001 From: Quentin Lhoest Date: Mon, 12 Sep 2022 18:59:41 +0200 Subject: [PATCH 49/51] fix win tests --- tests/test_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index 01a9a219997..5cd3c40ad6a 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -935,7 +935,7 @@ def test_builder_download_and_prepare_with_absolute_output_dir(tmp_path): builder = DummyGeneratorBasedBuilder() output_dir = str(tmp_path) builder.download_and_prepare(output_dir) - assert builder._output_dir.startswith(output_dir) + assert builder._output_dir.startswith(tmp_path.resolve().as_posix()) assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) assert not os.path.exists(os.path.join(output_dir + ".incomplete")) @@ -946,7 +946,7 @@ def test_builder_download_and_prepare_with_relative_output_dir(): builder = DummyGeneratorBasedBuilder() output_dir = "test-out" builder.download_and_prepare(output_dir) - assert builder._output_dir.startswith(str(Path(output_dir).resolve().as_posix())) + assert Path(builder._output_dir).resolve().as_posix().startswith(Path(output_dir).resolve().as_posix()) assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) assert not os.path.exists(os.path.join(output_dir + ".incomplete")) From 6e480e0dbea3b37ef5f3fc2097db9127417b6bc9 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Tue, 13 Sep 2022 18:39:36 +0200 Subject: [PATCH 50/51] Update src/datasets/builder.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mario Šaško --- src/datasets/builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 19eb6cb4e4c..5278288d25a 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -597,7 +597,7 @@ def download_and_prepare( base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = None, file_format: str = "arrow", - max_shard_size: Optional[int] = None, + max_shard_size: Optional[Union[int, str]] = None, storage_options: Optional[dict] = None, **download_and_prepare_kwargs, ): From d3357425dfefdd01bfd1d21058e47cd7c0ea0859 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Date: Thu, 15 Sep 2022 14:52:39 +0200 Subject: [PATCH 51/51] Update src/datasets/builder.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mario Šaško --- src/datasets/builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 5278288d25a..016c30edd0e 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -1660,8 +1660,8 @@ def _prepare_split( if max_shard_size is not None: raise NotImplementedError( - "max_shard_size is not supported for Beam datasets, please." - "Set it to None to use the default Apache Beam sharding and get the best performance." + "max_shard_size is not supported for Beam datasets." + "Please set it to None to use the default Apache Beam sharding and get the best performance." ) # To write examples in filesystem: