diff --git a/docs/source/filesystems.mdx b/docs/source/filesystems.mdx index f2da377568f..962966bd65f 100644 --- a/docs/source/filesystems.mdx +++ b/docs/source/filesystems.mdx @@ -140,6 +140,12 @@ Use your own data files (see [how to load local and remote files](./loading#loca It is highly recommended to save the files as compressed Parquet files to optimize I/O by specifying `file_format="parquet"`. Otherwise the dataset is saved as an uncompressed Arrow file. +You can also specify the size of the Parquet shard using `max_shard_size` (default is 500MB): + +```py +>>> builder.download_and_prepare(output_dir, storage_options=storage_options, file_format="parquet", max_shard_size="1GB") +``` + #### Dask Dask is a parallel computing library and it has a pandas-like API for working with larger than memory Parquet datasets in parallel. diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 24454bdd66c..d7b0532306e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -4052,7 +4052,7 @@ def _push_parquet_shards_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[str] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, embed_external_files: bool = True, ) -> Tuple[str, str, int, int]: """Pushes the dataset to the hub. 
@@ -4098,8 +4098,7 @@ def _push_parquet_shards_to_hub( >>> dataset.push_to_hub("/", split="evaluation") ``` """ - - max_shard_size = convert_file_size_to_int(max_shard_size) + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) api = HfApi(endpoint=config.HF_ENDPOINT) token = token if token is not None else HfFolder.get_token() @@ -4270,7 +4269,7 @@ def push_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[str] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, shard_size: Optional[int] = "deprecated", embed_external_files: bool = True, ): diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 955078dff27..016c30edd0e 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -31,6 +31,7 @@ from typing import Dict, Mapping, Optional, Tuple, Union import fsspec +from tqdm.contrib.concurrent import thread_map from . import config, utils from .arrow_dataset import Dataset @@ -63,6 +64,7 @@ from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits from .utils.py_utils import ( classproperty, + convert_file_size_to_int, has_sufficient_disk_space, map_nested, memoize, @@ -576,6 +578,14 @@ def get_imported_module_dir(cls): """Return the path of the module of this class or subclass.""" return os.path.dirname(inspect.getfile(inspect.getmodule(cls))) + def _rename(self, src: str, dst: str): + is_local = not is_remote_filesystem(self._fs) + if is_local: + # LocalFileSystem.mv does copy + rm, it is more efficient to simply move a local directory + shutil.move(self._fs._strip_protocol(src), self._fs._strip_protocol(dst)) + else: + self._fs.mv(src, dst, recursive=True) + def download_and_prepare( self, output_dir: Optional[str] = None, @@ -586,7 +596,8 @@ def download_and_prepare( dl_manager: Optional[DownloadManager] = None, base_path: Optional[str] = None, use_auth_token: Optional[Union[bool, str]] = 
None, -        file_format: Optional[str] = None, +        file_format: str = "arrow", +        max_shard_size: Optional[Union[int, str]] = None, storage_options: Optional[dict] = None, **download_and_prepare_kwargs, ): @@ -609,6 +620,11 @@ file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + + max_shard_size (:obj:`Union[str, int]`, optional): Maximum number of bytes written per shard. + Only available for the "parquet" format with a default of "500MB". The size is based on uncompressed data size, + so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. @@ -617,25 +633,59 @@ Example: + Download and prepare the dataset as Arrow files that can be loaded as a Dataset using `builder.as_dataset()` + ```py >>> from datasets import load_dataset_builder - >>> builder = load_dataset_builder('rotten_tomatoes') + >>> builder = load_dataset_builder("rotten_tomatoes") >>> ds = builder.download_and_prepare() ``` + + Download and prepare the dataset as sharded Parquet files locally + + ```py + >>> from datasets import load_dataset_builder + >>> builder = load_dataset_builder("rotten_tomatoes") + >>> ds = builder.download_and_prepare("./output_dir", file_format="parquet") + ``` + + Download and prepare the dataset as sharded Parquet files in a cloud storage + + ```py + >>> from datasets import load_dataset_builder + >>> storage_options = {"key": aws_access_key_id, "secret": aws_secret_access_key} + >>> builder = load_dataset_builder("rotten_tomatoes") + >>> ds = builder.download_and_prepare("s3://my-bucket/my_rotten_tomatoes", storage_options=storage_options, file_format="parquet") + ``` """ -        self._output_dir = output_dir if output_dir is not None else self._cache_dir + 
output_dir = output_dir if output_dir is not None else self._cache_dir # output_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing) - fs_token_paths = fsspec.get_fs_token_paths(self._output_dir, storage_options=storage_options) - self._fs = fs_token_paths[0] + fs_token_paths = fsspec.get_fs_token_paths(output_dir, storage_options=storage_options) + self._fs: fsspec.AbstractFileSystem = fs_token_paths[0] + is_local = not is_remote_filesystem(self._fs) + self._output_dir = fs_token_paths[2][0] if is_local else self._fs.unstrip_protocol(fs_token_paths[2][0]) download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) verify_infos = not ignore_verifications base_path = base_path if base_path is not None else self.base_path - is_local = not is_remote_filesystem(self._fs) if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") + if file_format == "arrow" and max_shard_size is not None: + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." + ) + + if self._fs._strip_protocol(self._output_dir) == "": + # We don't support the root directory, because it has no dirname, + # and we need a dirname to use a .incomplete directory + # when the dataset is being written + raise RuntimeError( + f"Unable to download and prepare the dataset at the root {self._output_dir}. " + f"Please specify a subdirectory, e.g. 
'{self._output_dir + self.name}'" +            ) +         if dl_manager is None: if download_config is None: download_config = DownloadConfig( @@ -656,7 +706,12 @@ else False, ) -        elif isinstance(dl_manager, MockDownloadManager) or not is_local: +        if ( +            isinstance(dl_manager, MockDownloadManager) +            or not is_local +            or file_format != "arrow" +            or max_shard_size is not None +        ): try_from_hf_gcs = False self.dl_manager = dl_manager @@ -743,10 +798,14 @@ def incomplete_dir(dirname): except ConnectionError: logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: +            prepare_split_kwargs = { +                "file_format": file_format, +                "max_shard_size": max_shard_size, +            } self._download_and_prepare( dl_manager=dl_manager, verify_infos=verify_infos, -                file_format=file_format, +                **prepare_split_kwargs, **download_and_prepare_kwargs, ) # Sync info @@ -797,7 +857,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") logger.info("Dataset downloaded from Hf google storage.") -    def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **prepare_split_kwargs): +    def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs): """Downloads and prepares dataset for reading. This is the internal implementation to overwrite called when user calls @@ -807,9 +867,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr Args: dl_manager: (:obj:`DownloadManager`) `DownloadManager` used to download and cache data. verify_infos (:obj:`bool`): if False, do not perform checksums and size tests. -            file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. -                Supported formats: "arrow", "parquet". Default to "arrow" format. -            prepare_split_kwargs: Additional options. 
+ prepare_split_kwargs: Additional options, such as file_format, max_shard_size """ # Generating data for all splits split_dict = SplitDict(dataset_name=self.name) @@ -836,7 +894,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **pr try: # Prepare split will record examples associated to the split - self._prepare_split(split_generator, file_format=file_format, **prepare_split_kwargs) + self._prepare_split(split_generator, **prepare_split_kwargs) except OSError as e: raise OSError( "Cannot find data file. " @@ -1174,13 +1232,22 @@ def _split_generators(self, dl_manager: DownloadManager): raise NotImplementedError() @abc.abstractmethod - def _prepare_split(self, split_generator: SplitGenerator, file_format: Optional[str] = None, **kwargs): + def _prepare_split( + self, + split_generator: SplitGenerator, + file_format: str = "arrow", + max_shard_size: Optional[Union[str, int]] = None, + **kwargs, + ): """Generate the examples and record them on disk. Args: split_generator: `SplitGenerator`, Split generator to process file_format (:obj:`str`, optional): format of the data files in which the dataset will be written. Supported formats: "arrow", "parquet". Default to "arrow" format. + max_shard_size (:obj:`Union[str, int]`, optional): Approximate maximum number of bytes written per shard. + Only available for the "parquet" format with a default of "500MB". The size is based on uncompressed data size, + so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. 
**kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1251,17 +1318,30 @@ def _generate_examples(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): + def _prepare_split( + self, + split_generator: SplitGenerator, + check_duplicate_keys: bool, + file_format="arrow", + max_shard_size: Optional[Union[int, str]] = None, + ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join + if file_format == "arrow": + if max_shard_size is not None: + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." + ) + else: + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + if self.info.splits is not None: split_info = self.info.splits[split_generator.name] else: split_info = split_generator.split_info - file_format = file_format or "arrow" - suffix = "-00000-of-00001" if file_format == "parquet" else "" + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" fpath = path_join(self._output_dir, fname) @@ -1269,35 +1349,66 @@ def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + shard_id = 0 # TODO: embed the images/audio files inside parquet files. 
- with writer_class( + writer = writer_class( features=self.info.features, - path=fpath, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, storage_options=self._fs.storage_options, - ) as writer: - try: - for key, record in logging.tqdm( - generator, - unit=" examples", - total=split_info.num_examples, - leave=False, - disable=not logging.is_progress_bar_enabled(), - desc=f"Generating {split_info.name} split", - ): - example = self.info.features.encode_example(record) - writer.write(example, key) - finally: - num_examples, num_bytes = writer.finalize() + ) + total_num_examples, total_num_bytes = 0, 0 + try: + for key, record in logging.tqdm( + generator, + unit=" examples", + total=split_info.num_examples, + leave=False, + disable=not logging.is_progress_bar_enabled(), + desc=f"Generating {split_info.name} split", + ): + if max_shard_size is not None and writer._num_bytes > max_shard_size: + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes + shard_id += 1 + writer = writer_class( + features=writer._features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + writer_batch_size=self._writer_batch_size, + hash_salt=split_info.name, + check_duplicates=check_duplicate_keys, + storage_options=self._fs.storage_options, + ) + example = self.info.features.encode_example(record) + writer.write(example, key) + finally: + num_shards = shard_id + 1 + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes - split_generator.split_info.num_examples = num_examples - split_generator.split_info.num_bytes = num_bytes + if file_format == "parquet": - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + def _rename_shard(shard_id: int): + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}"), + fpath.replace("SSSSS", 
f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), + ) + + logger.debug(f"Renaming {num_shards} shards.") + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) + + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes + if self.info.features is None: + self.info.features = writer._features + + def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): super()._download_and_prepare( - dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos + dl_manager, verify_infos, check_duplicate_keys=verify_infos, **prepare_splits_kwargs ) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: @@ -1339,27 +1450,76 @@ def _generate_tables(self, **kwargs): """ raise NotImplementedError() - def _prepare_split(self, split_generator, file_format=None): + def _prepare_split( + self, + split_generator: SplitGenerator, + file_format: str = "arrow", + max_shard_size: Optional[Union[str, int]] = None, + ): is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - file_format = file_format or "arrow" - suffix = "-00000-of-00001" if file_format == "parquet" else "" + if file_format == "arrow": + if max_shard_size is not None: + raise NotImplementedError( + "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." + ) + else: + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" fpath = path_join(self._output_dir, fname) generator = self._generate_tables(**split_generator.gen_kwargs) + writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter + + shard_id = 0 # TODO: embed the images/audio files inside parquet files. 
- with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: + writer = writer_class( + features=self.info.features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + storage_options=self._fs.storage_options, + ) + total_num_examples, total_num_bytes = 0, 0 + try: for key, table in logging.tqdm( - generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) + generator, + unit=" tables", + leave=False, + disable=not logging.is_progress_bar_enabled(), ): + if max_shard_size is not None and writer._num_bytes > max_shard_size: + num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes + shard_id += 1 + writer = writer_class( + features=writer._features, + path=fpath.replace("SSSSS", f"{shard_id:05d}"), + storage_options=self._fs.storage_options, + ) writer.write_table(table) + finally: + num_shards = shard_id + 1 num_examples, num_bytes = writer.finalize() + total_num_examples += num_examples + total_num_bytes += num_bytes - split_generator.split_info.num_examples = num_examples - split_generator.split_info.num_bytes = num_bytes + if file_format == "parquet": + + def _rename_shard(shard_id: int): + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}"), + fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), + ) + + logger.debug(f"Renaming {num_shards} shards.") + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) + + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes if self.info.features is None: self.info.features = writer._features @@ -1432,7 +1592,7 @@ def _build_pcollection(pipeline, extracted_dir=None): """ raise NotImplementedError() - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): + def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): # Create 
the Beam pipeline and forward it to _prepare_split import apache_beam as beam @@ -1467,10 +1627,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): options=beam_options, ) super()._download_and_prepare( - dl_manager, - verify_infos=False, - pipeline=pipeline, - file_format=file_format, + dl_manager, verify_infos=False, pipeline=pipeline, **prepare_splits_kwargs ) # TODO handle verify_infos in beam datasets # Run pipeline pipeline_results = pipeline.run() @@ -1496,12 +1653,19 @@ def _save_info(self): with fs.create(path_join(self._output_dir, config.LICENSE_FILENAME)) as f: self.info._dump_license(f) - def _prepare_split(self, split_generator, pipeline, file_format=None): + def _prepare_split( + self, split_generator, pipeline, file_format="arrow", max_shard_size: Optional[Union[str, int]] = None + ): import apache_beam as beam + if max_shard_size is not None: + raise NotImplementedError( + "max_shard_size is not supported for Beam datasets." + "Please set it to None to use the default Apache Beam sharding and get the best performance." + ) + # To write examples in filesystem: split_name = split_generator.split_info.name - file_format = file_format or "arrow" fname = f"{self.name}-{split_name}.{file_format}" path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join fpath = path_join(self._output_dir, fname) diff --git a/src/datasets/config.py b/src/datasets/config.py index 87b321441cc..2bd5419cbe3 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -172,6 +172,9 @@ # For big tables, we write them on disk instead MAX_TABLE_NBYTES_FOR_PICKLING = 4 << 30 +# Max shard size in bytes (e.g. 
to shard parquet datasets in push_to_hub or download_and_prepare) +MAX_SHARD_SIZE = "500MB" + # Offline mode HF_DATASETS_OFFLINE = os.environ.get("HF_DATASETS_OFFLINE", "AUTO").upper() in ENV_VARS_TRUE_VALUES diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 70ea04d7ea8..14142651028 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1287,7 +1287,7 @@ def push_to_hub( private: Optional[bool] = False, token: Optional[str] = None, branch: Optional[None] = None, - max_shard_size: Union[int, str] = "500MB", + max_shard_size: Optional[Union[int, str]] = None, shard_size: Optional[int] = "deprecated", embed_external_files: bool = True, ): diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index f56b765a7c2..01d0673a2f6 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -114,7 +114,7 @@ def convert_file_size_to_int(size: Union[int, str]) -> int: if size.upper().endswith("KB"): int_size = int(size[:-2]) * (10**3) return int_size // 8 if size.endswith("b") else int_size - raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.") + raise ValueError(f"`size={size}` is not in a valid format. 
Use an integer followed by the unit, e.g., '5GB'.") def string_to_dict(string: str, pattern: str) -> Dict[str, str]: diff --git a/tests/test_builder.py b/tests/test_builder.py index 04ea1c87ccc..5cd3c40ad6a 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -25,7 +25,13 @@ from datasets.streaming import xjoin from datasets.utils.file_utils import is_local_path -from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_beam, require_faiss +from .utils import ( + assert_arrow_memory_doesnt_increase, + assert_arrow_memory_increases, + require_beam, + require_faiss, + set_current_working_directory_to_temp_dir, +) class DummyBuilder(DatasetBuilder): @@ -757,6 +763,16 @@ def test_beam_based_download_and_prepare(tmp_path): assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) +@require_beam +def test_beam_based_as_dataset(tmp_path): + builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") + builder.download_and_prepare() + dataset = builder.as_dataset() + assert dataset + assert isinstance(dataset["train"], Dataset) + assert len(dataset["train"]) > 0 + + @pytest.mark.parametrize( "split, expected_dataset_class, expected_dataset_length", [ @@ -915,6 +931,27 @@ def test_builder_config_version(builder_class, kwargs, tmp_path): assert builder.config.version == "2.0.0" +def test_builder_download_and_prepare_with_absolute_output_dir(tmp_path): + builder = DummyGeneratorBasedBuilder() + output_dir = str(tmp_path) + builder.download_and_prepare(output_dir) + assert builder._output_dir.startswith(tmp_path.resolve().as_posix()) + assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) + assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) + assert not os.path.exists(os.path.join(output_dir + ".incomplete")) + + +def test_builder_download_and_prepare_with_relative_output_dir(): + with 
set_current_working_directory_to_temp_dir(): + builder = DummyGeneratorBasedBuilder() + output_dir = "test-out" + builder.download_and_prepare(output_dir) + assert Path(builder._output_dir).resolve().as_posix().startswith(Path(output_dir).resolve().as_posix()) + assert os.path.exists(os.path.join(output_dir, "dataset_info.json")) + assert os.path.exists(os.path.join(output_dir, f"{builder.name}-train.arrow")) + assert not os.path.exists(os.path.join(output_dir + ".incomplete")) + + def test_builder_with_filesystem_download_and_prepare(tmp_path, mockfs): builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare("mock://my_dataset", storage_options=mockfs.storage_options) @@ -948,6 +985,43 @@ def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): assert pq.ParquetFile(parquet_path) is not None +def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + writer_batch_size = 25 + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) + with patch("datasets.config.MAX_SHARD_SIZE", 1): # one batch per shard + builder.download_and_prepare(file_format="parquet") + expected_num_shards = 100 // writer_batch_size + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + +def test_generator_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path): + writer_batch_size = 25 + builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, 
writer_batch_size=writer_batch_size) + builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one batch per shard + expected_num_shards = 100 // writer_batch_size + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare(file_format="parquet") @@ -959,6 +1033,41 @@ def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): assert pq.ParquetFile(parquet_path) is not None +def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): + builder = DummyArrowBasedBuilder(cache_dir=tmp_path) + with patch("datasets.config.MAX_SHARD_SIZE", 1): # one batch per shard + builder.download_and_prepare(file_format="parquet") + expected_num_shards = 10 + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + +def test_arrow_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path): + builder = 
DummyArrowBasedBuilder(cache_dir=tmp_path) + builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one table per shard + expected_num_shards = 10 + assert builder.info.splits["train"].num_examples, 100 + parquet_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" + ) + assert os.path.exists(parquet_path) + parquet_files = [ + pq.ParquetFile(parquet_path) + for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet") + ] + assert len(parquet_files) == expected_num_shards + assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 + + def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner") builder.download_and_prepare(file_format="parquet") diff --git a/tests/test_upstream_hub.py b/tests/test_upstream_hub.py index 0503a87b08e..5833b367af7 100644 --- a/tests/test_upstream_hub.py +++ b/tests/test_upstream_hub.py @@ -131,6 +131,35 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo): local_ds = DatasetDict({"train": ds}) + with temporary_repo(f"{CI_HUB_USER}/test-{int(time.time() * 10e3)}") as ds_name: + with patch("datasets.config.MAX_SHARD_SIZE", "16KB"): + local_ds.push_to_hub(ds_name, token=self._token) + hub_ds = load_dataset(ds_name, download_mode="force_redownload") + + assert local_ds.column_names == hub_ds.column_names + assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys()) + assert local_ds["train"].features == hub_ds["train"].features + + # Ensure that there are two files on the repository that have the correct name + files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) + assert all( + fnmatch.fnmatch(file, expected_file) + for file, expected_file in zip( + files, + [ + ".gitattributes", + 
"data/train-00000-of-00002-*.parquet", + "data/train-00001-of-00002-*.parquet", + "dataset_infos.json", + ], + ) + ) + + def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, temporary_repo): + ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) + + local_ds = DatasetDict({"train": ds}) + with temporary_repo(f"{CI_HUB_USER}/test-{int(time.time() * 10e3)}") as ds_name: local_ds.push_to_hub(ds_name, token=self._token, max_shard_size="16KB") hub_ds = load_dataset(ds_name, download_mode="force_redownload")