diff --git a/docs/source/dataset_script.mdx b/docs/source/dataset_script.mdx index 65bba37ccd8..e46d013ff29 100644 --- a/docs/source/dataset_script.mdx +++ b/docs/source/dataset_script.mdx @@ -296,3 +296,59 @@ Congratulations, you can now load your dataset from the Hub! 🥳 >>> from datasets import load_dataset >>> load_dataset("/my_dataset") ``` + +## Advanced features + +### Sharding + +If your dataset is made of many big files, 🤗 Datasets automatically runs your script in parallel to make it super fast! +It can help if you have hundreds or thousands of TAR archives, or JSONL files like [oscar](https://huggingface.co/datasets/oscar/blob/main/oscar.py) for example. + +To make it work, we consider lists of files in `gen_kwargs` to be shards. +Therefore 🤗 Datasets can automatically spawn several workers to run `_generate_examples` in parallel, and each worker is given a subset of shards to process. + + +```python + +class MyShardedDataset(datasets.GeneratorBasedBuilder): + + def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]: + downloaded_files = dl_manager.download([f"data/shard_{i}.jsonl" for i in range(1024)]) + return [ + datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": downloaded_files}), + ] + + def _generate_examples(self, filepaths): + # Each worker can be given a slice of the original `filepaths` list defined in the `gen_kwargs` + # so that this code can run in parallel on several shards at the same time + for filepath in filepaths: + ... +``` + +Users can also pass `num_proc=` to `load_dataset()` to set the number of processes to use as workers. + +### ArrowBasedBuilder + +For some datasets it can be much faster to yield batches of data rather than examples one by one. +You can speed up the dataset generation by yielding Arrow tables directly, instead of examples. 
+This is especially useful if your data comes from Pandas DataFrames for example, since the conversion from Pandas to Arrow is as simple as: + +```python +import pyarrow as pa +pa_table = pa.Table.from_pandas(df) +``` + +To yield Arrow tables instead of single examples, make your dataset builder inherit from [`ArrowBasedBuilder`] instead of [`GeneratorBasedBuilder`], and use `_generate_tables` instead of `_generate_examples`: + +```python +class MySuperFastDataset(datasets.ArrowBasedBuilder): + + def _generate_tables(self, filepaths): + idx = 0 + for filepath in filepaths: + ... + yield idx, pa_table + idx += 1 +``` + +Don't forget to keep your script memory efficient, in case users run it on machines with a low amount of RAM. diff --git a/docs/source/loading.mdx b/docs/source/loading.mdx index 3dd51e98639..ba3b5ded659 100644 --- a/docs/source/loading.mdx +++ b/docs/source/loading.mdx @@ -223,6 +223,21 @@ You can specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalche +## Multiprocessing + +When a dataset is made of several files (that we call "shards"), it is possible to significantly speed up the dataset downloading and preparation step. + +You can choose how many processes you'd like to use to prepare a dataset in parallel using `num_proc`. +In this case, each process is given a subset of shards to prepare: + +```python +from datasets import load_dataset + +oscar_afrikaans = load_dataset("oscar-corpus/OSCAR-2201", "af", num_proc=8) +imagenet = load_dataset("imagenet-1k", num_proc=8) +ml_librispeech_spanish = load_dataset("facebook/multilingual_librispeech", "spanish", num_proc=8) +``` + ## In-memory data 🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames. 
diff --git a/src/datasets/arrow_reader.py b/src/datasets/arrow_reader.py index fb53b35f42e..1f1c384549c 100644 --- a/src/datasets/arrow_reader.py +++ b/src/datasets/arrow_reader.py @@ -27,7 +27,7 @@ import pyarrow.parquet as pq from .download.download_config import DownloadConfig -from .naming import _split_re, filename_for_dataset_split +from .naming import _split_re, filenames_for_dataset_split from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables from .utils import logging from .utils.file_utils import cached_path @@ -35,7 +35,7 @@ if TYPE_CHECKING: from .info import DatasetInfo # noqa: F401 - from .splits import Split # noqa: F401 + from .splits import Split, SplitInfo # noqa: F401 logger = logging.get_logger(__name__) @@ -88,7 +88,13 @@ class FileInstructions: file_instructions: List[dict] -def make_file_instructions(name, split_infos, instruction, filetype_suffix=None): +def make_file_instructions( + name: str, + split_infos: List["SplitInfo"], + instruction: Union[str, "ReadInstruction"], + filetype_suffix: Optional[str] = None, + prefix_path: Optional[str] = None, +): """Returns instructions of the split dict. 
Args: @@ -101,31 +107,48 @@ def make_file_instructions(name, split_infos, instruction, filetype_suffix=None) file_intructions: FileInstructions instance """ name2len = {info.name: info.num_examples for info in split_infos} + name2shard_lengths = {info.name: info.shard_lengths for info in split_infos} + name2filenames = { + info.name: filenames_for_dataset_split( + path=prefix_path, + dataset_name=name, + split=info.name, + filetype_suffix=filetype_suffix, + shard_lengths=name2shard_lengths[info.name], + ) + for info in split_infos + } if not isinstance(instruction, ReadInstruction): instruction = ReadInstruction.from_spec(instruction) # Create the absolute instruction (per split) absolute_instructions = instruction.to_absolute(name2len) - return _make_file_instructions_from_absolutes( - name=name, name2len=name2len, absolute_instructions=absolute_instructions, filetype_suffix=filetype_suffix - ) - - -def _make_file_instructions_from_absolutes(name, name2len, absolute_instructions, filetype_suffix=None): - """Returns the files instructions from the absolute instructions list.""" # For each split, return the files instruction (skip/take) file_instructions = [] num_examples = 0 for abs_instr in absolute_instructions: - length = name2len[abs_instr.splitname] - filename = filename_for_dataset_split( - dataset_name=name, split=abs_instr.splitname, filetype_suffix=filetype_suffix - ) + split_length = name2len[abs_instr.splitname] + filenames = name2filenames[abs_instr.splitname] + shard_lengths = name2shard_lengths[abs_instr.splitname] from_ = 0 if abs_instr.from_ is None else abs_instr.from_ - to = length if abs_instr.to is None else abs_instr.to - num_examples += to - from_ - single_file_instructions = [{"filename": filename, "skip": from_, "take": to - from_}] - file_instructions.extend(single_file_instructions) + to = split_length if abs_instr.to is None else abs_instr.to + if shard_lengths is None: # not sharded + for filename in filenames: + num_examples += to - 
from_ + file_instructions.append({"filename": filename, "skip": from_, "take": to - from_}) + else: # sharded + index_start = 0 # Beginning (included) of moving window. + index_end = 0 # End (excluded) of moving window. + for filename, shard_length in zip(filenames, shard_lengths): + index_end += shard_length + if from_ < index_end and to > index_start: # There is something to take. + skip = from_ - index_start if from_ > index_start else 0 + take = to - index_start - skip if to < index_end else -1 + if take == 0: + continue + file_instructions.append({"filename": filename, "skip": skip, "take": take}) + num_examples += shard_length - skip if take == -1 else take + index_start += shard_length return FileInstructions( num_examples=num_examples, file_instructions=file_instructions, @@ -182,7 +205,7 @@ def _read_files(self, files, in_memory=False) -> Table: def get_file_instructions(self, name, instruction, split_infos): """Return list of dict {'filename': str, 'skip': int, 'take': int}""" file_instructions = make_file_instructions( - name, split_infos, instruction, filetype_suffix=self._filetype_suffix + name, split_infos, instruction, filetype_suffix=self._filetype_suffix, prefix_path=self._path ) files = file_instructions.file_instructions return files @@ -304,6 +327,8 @@ def _get_table_from_filename(self, filename_skip_take, in_memory=False) -> Table filename_skip_take["take"] if "take" in filename_skip_take else None, ) table = ArrowReader.read_table(filename, in_memory=in_memory) + if take == -1: + take = len(table) - skip # here we don't want to slice an empty table, or it may segfault if skip is not None and take is not None and not (skip == 0 and take == len(table)): table = table.slice(skip, take) diff --git a/src/datasets/arrow_writer.py b/src/datasets/arrow_writer.py index f254f877105..ab0057c599f 100644 --- a/src/datasets/arrow_writer.py +++ b/src/datasets/arrow_writer.py @@ -636,9 +636,17 @@ def finalize(self, metrics_query_result: dict): from .utils 
import beam_utils + shards_metadata = [ + metadata + for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list + ] + shards = [metadata.path for metadata in shards_metadata] + num_bytes = sum([metadata.size_in_bytes for metadata in shards_metadata]) + shard_lengths = get_parquet_lengths(shards) + # Convert to arrow if self._path.endswith(".arrow"): - logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}") + logger.info(f"Converting parquet files {self._parquet_path} to arrow {self._path}") shards = [ metadata.path for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ @@ -646,9 +654,15 @@ def finalize(self, metrics_query_result: dict): ].metadata_list ] try: # stream conversion - sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards] - with beam.io.filesystems.FileSystems.create(self._path) as dest: - parquet_to_arrow(sources, dest) + disable = not logging.is_progress_bar_enabled() + num_bytes = 0 + for shard in logging.tqdm(shards, unit="shards", disable=disable): + with beam.io.filesystems.FileSystems.open(shard) as source: + with beam.io.filesystems.FileSystems.create( + shard.replace(".parquet", ".arrow") + ) as destination: + shard_num_bytes, _ = parquet_to_arrow(source, destination) + num_bytes += shard_num_bytes except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead if e.errno != errno.EPIPE: # not a broken pipe raise @@ -657,41 +671,41 @@ def finalize(self, metrics_query_result: dict): ) local_convert_dir = os.path.join(self._cache_dir, "beam_convert") os.makedirs(local_convert_dir, exist_ok=True) - local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow") - local_shards = [] - for shard in shards: + disable = not logging.is_progress_bar_enabled() + num_bytes = 0 + for shard in logging.tqdm(shards, unit="shards", disable=disable): 
local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet") - local_shards.append(local_parquet_path) beam_utils.download_remote_to_local(shard, local_parquet_path) - parquet_to_arrow(local_shards, local_arrow_path) - beam_utils.upload_local_to_remote(local_arrow_path, self._path) - output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0] - num_bytes = output_file_metadata.size_in_bytes - else: - num_bytes = sum( - [ - metadata.size_in_bytes - for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[ - 0 - ].metadata_list - ] - ) + local_arrow_path = local_parquet_path.replace(".parquet", ".arrow") + shard_num_bytes, _ = parquet_to_arrow(local_parquet_path, local_arrow_path) + num_bytes += shard_num_bytes + remote_arrow_path = shard.replace(".parquet", ".arrow") + beam_utils.upload_local_to_remote(local_arrow_path, remote_arrow_path) # Save metrics counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]} self._num_examples = counters_dict["num_examples"] self._num_bytes = num_bytes + self._shard_lengths = shard_lengths return self._num_examples, self._num_bytes -def parquet_to_arrow(sources, destination): - """Convert parquet files to arrow file. Inputs can be str paths or file-like objects""" - stream = None if isinstance(destination, str) else destination +def get_parquet_lengths(sources) -> List[int]: + shard_lengths = [] disable = not logging.is_progress_bar_enabled() + for source in logging.tqdm(sources, unit="parquet files", disable=disable): + parquet_file = pa.parquet.ParquetFile(source) + shard_lengths.append(parquet_file.metadata.num_rows) + return shard_lengths + + +def parquet_to_arrow(source, destination) -> List[int]: + """Convert parquet file to arrow file. 
Inputs can be str paths or file-like objects""" + stream = None if isinstance(destination, str) else destination with ArrowWriter(path=destination, stream=stream) as writer: - for source in logging.tqdm(sources, unit="sources", disable=disable): - parquet_file = pa.parquet.ParquetFile(source) - for record_batch in parquet_file.iter_batches(): - pa_table = pa.Table.from_batches([record_batch]) - writer.write_table(pa_table) - return destination + parquet_file = pa.parquet.ParquetFile(source) + for record_batch in parquet_file.iter_batches(): + pa_table = pa.Table.from_batches([record_batch]) + writer.write_table(pa_table) + num_bytes, num_examples = writer.finalize() + return num_bytes, num_examples diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 65a698c9fbd..6ac6ad75570 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -23,14 +23,16 @@ import posixpath import shutil import textwrap +import time import urllib import warnings from dataclasses import dataclass from functools import partial from pathlib import Path -from typing import Dict, Mapping, Optional, Tuple, Union +from typing import Dict, Iterable, Mapping, Optional, Tuple, Union import fsspec +from multiprocess import Pool from tqdm.contrib.concurrent import thread_map from . 
import config, utils @@ -56,7 +58,7 @@ from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper from .keyhash import DuplicatedKeysError from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase -from .splits import Split, SplitDict, SplitGenerator +from .splits import Split, SplitDict, SplitGenerator, SplitInfo from .streaming import extend_dataset_builder_for_streaming from .utils import logging from .utils.file_utils import cached_path, is_remote_url @@ -66,11 +68,13 @@ classproperty, convert_file_size_to_int, has_sufficient_disk_space, + iflatmap_unordered, map_nested, memoize, size_str, temporary_assignment, ) +from .utils.sharding import _number_of_shards_in_gen_kwargs, _split_gen_kwargs logger = logging.get_logger(__name__) @@ -605,6 +609,7 @@ def download_and_prepare( use_auth_token: Optional[Union[bool, str]] = None, file_format: str = "arrow", max_shard_size: Optional[Union[int, str]] = None, + num_proc: Optional[int] = None, storage_options: Optional[dict] = None, **download_and_prepare_kwargs, ): @@ -634,6 +639,10 @@ def download_and_prepare( so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. + num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. + + storage_options (:obj:`dict`, *optional*): Key/value pairs to be passed on to the caching file-system backend, if any. @@ -680,11 +689,6 @@ def download_and_prepare( if file_format is not None and file_format not in ["arrow", "parquet"]: raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") - if file_format == "arrow" and max_shard_size is not None: - raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." 
- ) - if self._fs._strip_protocol(self._output_dir) == "": # We don't support the root directory, because it has no dirname, # and we need a dirname to use a .incomplete directory @@ -806,11 +810,11 @@ def incomplete_dir(dirname): except ConnectionError: logger.warning("HF google storage unreachable. Downloading and preparing it from source") if not downloaded_from_gcs: - prepare_split_kwargs = { - "file_format": file_format, - "max_shard_size": max_shard_size, - **download_and_prepare_kwargs, - } + prepare_split_kwargs = {"file_format": file_format} + if max_shard_size is not None: + prepare_split_kwargs["max_shard_size"] = max_shard_size + if num_proc is not None: + prepare_split_kwargs["num_proc"] = num_proc self._download_and_prepare( dl_manager=dl_manager, verify_infos=verify_infos, @@ -1245,6 +1249,7 @@ def _prepare_split( split_generator: SplitGenerator, file_format: str = "arrow", max_shard_size: Optional[Union[str, int]] = None, + num_proc: Optional[int] = None, **kwargs, ): """Generate the examples and record them on disk. @@ -1256,6 +1261,10 @@ def _prepare_split( max_shard_size (:obj:`Union[str, int]`, optional): Approximate maximum number of bytes written per shard. Only available for the "parquet" format with a default of "500MB". The size is based on uncompressed data size, so in practice your shard files may be smaller than `max_shard_size` thanks to Parquet compression. + num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. 
+ + **kwargs: Additional kwargs forwarded from _download_and_prepare (ex: beam pipeline) """ @@ -1331,63 +1340,183 @@ def _prepare_split( split_generator: SplitGenerator, check_duplicate_keys: bool, file_format="arrow", + num_proc: Optional[int] = None, max_shard_size: Optional[Union[int, str]] = None, ): + + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - if file_format == "arrow": - if max_shard_size is not None: - raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." - ) - else: - max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) - if self.info.splits is not None: split_info = self.info.splits[split_generator.name] else: split_info = split_generator.split_info - suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" - fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" + fname = f"{self.name}-{split_generator.name}{SUFFIX}.{file_format}" fpath = path_join(self._output_dir, fname) - generator = self._generate_examples(**split_generator.gen_kwargs) + num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs) + if num_input_shards <= 1 and num_proc is not None: + logger.warning( + f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard." + ) + num_proc = 1 + elif num_proc is not None and num_input_shards < num_proc: + logger.info( + f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards." 
+ ) + num_proc = num_input_shards + + pbar = logging.tqdm( + disable=not logging.is_progress_bar_enabled(), + unit=" examples", + total=split_info.num_examples, + leave=False, + desc=f"Generating {split_info.name} split", + ) + _prepare_split_args = { + "fpath": fpath, + "file_format": file_format, + "max_shard_size": max_shard_size, + "split_info": split_info, + "check_duplicate_keys": check_duplicate_keys, + } + + if num_proc is None or num_proc == 1: + result = None + gen_kwargs = split_generator.gen_kwargs + job_id = 0 + for job_id, done, content in self._prepare_split_single( + {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + ): + if done: + result = content + else: + pbar.update(content) + # wrapping everything into lists for consistency with the multiprocessed code path + assert result is not None, "Failed to retrieve results from prepare_split" + examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = [ + [item] for item in result + ] + else: + args_per_job = [ + {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + for job_id, gen_kwargs in enumerate( + _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) + ) + ] + num_jobs = len(args_per_job) + + examples_per_job = [None] * num_jobs + bytes_per_job = [None] * num_jobs + features_per_job = [None] * num_jobs + shards_per_job = [None] * num_jobs + shard_lengths_per_job = [None] * num_jobs + + with Pool(num_proc) as pool: + for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job): + if done: + # the content is the result of the job + ( + examples_per_job[job_id], + bytes_per_job[job_id], + features_per_job[job_id], + shards_per_job[job_id], + shard_lengths_per_job[job_id], + ) = content + else: + # the content is the number of examples progress update + pbar.update(content) + + assert ( + None not in examples_per_job + ), f"Failed to retrieve results from prepare_split: result list 
{examples_per_job} still contains None - at least one worker failed to return its results" + + total_shards = sum(shards_per_job) + total_num_examples = sum(examples_per_job) + total_num_bytes = sum(bytes_per_job) + features = features_per_job[0] + + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes + + # should rename everything at the end + logger.debug(f"Renaming {total_shards} shards.") + if total_shards > 1: + # use the -SSSSS-of-NNNNN pattern + + def _rename_shard(shard_and_job: Tuple[int]): + shard_id, job_id = shard_and_job + global_shard_id = sum(shards_per_job[:job_id]) + shard_id + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), + fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), + ) + + shards_and_jobs = [ + (shard_id, job_id) + for job_id, num_shards in enumerate(shards_per_job) + for shard_id in range(num_shards) + ] + thread_map(_rename_shard, shards_and_jobs, disable=True, max_workers=64) + + split_generator.split_info.shard_lengths = [ + shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths + ] + else: + # don't use any pattern + shard_id, job_id = 0, 0 + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), + fpath.replace(SUFFIX, ""), + ) + + if self.info.features is None: + self.info.features = features + + def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: + gen_kwargs: dict = arg["gen_kwargs"] + fpath: str = arg["fpath"] + file_format: str = arg["file_format"] + max_shard_size: int = arg["max_shard_size"] + split_info: SplitInfo = arg["split_info"] + check_duplicate_keys: bool = arg["check_duplicate_keys"] + job_id: int = arg["job_id"] + refresh_rate = 0.05 # 20 progress updates per sec + + generator = self._generate_examples(**gen_kwargs) writer_class = ParquetWriter if 
file_format == "parquet" else ArrowWriter embed_local_files = file_format == "parquet" + shard_lengths = [] + total_num_examples, total_num_bytes = 0, 0 shard_id = 0 - # TODO: embed the images/audio files inside parquet files. + num_examples_progress_update = 0 writer = writer_class( features=self.info.features, - path=fpath.replace("SSSSS", f"{shard_id:05d}"), + path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, storage_options=self._fs.storage_options, embed_local_files=embed_local_files, ) - total_num_examples, total_num_bytes = 0, 0 try: - for key, record in logging.tqdm( - generator, - unit=" examples", - total=split_info.num_examples, - leave=False, - disable=not logging.is_progress_bar_enabled(), - desc=f"Generating {split_info.name} split", - ): + _time = time.time() + for key, record in generator: if max_shard_size is not None and writer._num_bytes > max_shard_size: num_examples, num_bytes = writer.finalize() writer.close() + shard_lengths.append(num_examples) total_num_examples += num_examples total_num_bytes += num_bytes shard_id += 1 writer = writer_class( features=writer._features, - path=fpath.replace("SSSSS", f"{shard_id:05d}"), + path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), writer_batch_size=self._writer_batch_size, hash_salt=split_info.name, check_duplicates=check_duplicate_keys, @@ -1396,28 +1525,21 @@ def _prepare_split( ) example = self.info.features.encode_example(record) if self.info.features is not None else record writer.write(example, key) + num_examples_progress_update += 1 + if time.time() > _time + refresh_rate: + _time = time.time() + yield job_id, False, num_examples_progress_update + num_examples_progress_update = 0 finally: + yield job_id, False, num_examples_progress_update num_shards = shard_id + 1 num_examples, num_bytes = writer.finalize() writer.close() + 
shard_lengths.append(num_examples) total_num_examples += num_examples total_num_bytes += num_bytes - if file_format == "parquet": - - def _rename_shard(shard_id: int): - self._rename( - fpath.replace("SSSSS", f"{shard_id:05d}"), - fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), - ) - - logger.debug(f"Renaming {num_shards} shards.") - thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) - - split_generator.split_info.num_examples = total_num_examples - split_generator.split_info.num_bytes = total_num_bytes - if self.info.features is None: - self.info.features = writer._features + yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): super()._download_and_prepare( @@ -1467,79 +1589,196 @@ def _prepare_split( self, split_generator: SplitGenerator, file_format: str = "arrow", + num_proc: Optional[int] = None, max_shard_size: Optional[Union[str, int]] = None, ): + + max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) is_local = not is_remote_filesystem(self._fs) path_join = os.path.join if is_local else posixpath.join - if file_format == "arrow": - if max_shard_size is not None: - raise NotImplementedError( - "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." 
- ) + if self.info.splits is not None: + split_info = self.info.splits[split_generator.name] else: - max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) + split_info = split_generator.split_info - suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" - fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" + SUFFIX = "-JJJJJ-SSSSS-of-NNNNN" + fname = f"{self.name}-{split_generator.name}{SUFFIX}.{file_format}" fpath = path_join(self._output_dir, fname) - generator = self._generate_tables(**split_generator.gen_kwargs) + num_input_shards = _number_of_shards_in_gen_kwargs(split_generator.gen_kwargs) + if num_input_shards <= 1 and num_proc is not None: + logger.warning( + f"Setting num_proc from {num_proc} back to 1 for the {split_info.name} split to disable multiprocessing as it only contains one shard." + ) + num_proc = 1 + elif num_proc is not None and num_input_shards < num_proc: + logger.info( + f"Setting num_proc from {num_proc} to {num_input_shards} for the {split_info.name} split as it only contains {num_input_shards} shards." 
+ ) + num_proc = num_input_shards + + pbar = logging.tqdm( + disable=not logging.is_progress_bar_enabled(), + unit=" examples", + total=split_info.num_examples, + leave=False, + desc=f"Generating {split_info.name} split", + ) + + _prepare_split_args = { + "fpath": fpath, + "file_format": file_format, + "max_shard_size": max_shard_size, + "split_info": split_info, + } + + if num_proc is None or num_proc == 1: + result = None + gen_kwargs = split_generator.gen_kwargs + job_id = 0 + for job_id, done, content in self._prepare_split_single( + {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + ): + if done: + result = content + else: + pbar.update(content) + # wrapping everything into lists for consistency with the multiprocessed code path + assert result is not None, "Failed to retrieve results from prepare_split" + examples_per_job, bytes_per_job, features_per_job, shards_per_job, shard_lengths_per_job = [ + [item] for item in result + ] + else: + args_per_job = [ + {"gen_kwargs": gen_kwargs, "job_id": job_id, **_prepare_split_args} + for job_id, gen_kwargs in enumerate( + _split_gen_kwargs(split_generator.gen_kwargs, max_num_jobs=num_proc) + ) + ] + num_jobs = len(args_per_job) + + examples_per_job = [None] * num_jobs + bytes_per_job = [None] * num_jobs + features_per_job = [None] * num_jobs + shards_per_job = [None] * num_jobs + shard_lengths_per_job = [None] * num_jobs + + with Pool(num_proc) as pool: + for job_id, done, content in iflatmap_unordered(pool, self._prepare_split_single, args_per_job): + if done: + # the content is the result of the job + ( + examples_per_job[job_id], + bytes_per_job[job_id], + features_per_job[job_id], + shards_per_job[job_id], + shard_lengths_per_job[job_id], + ) = content + else: + # the content is the number of examples progress update + pbar.update(content) + + assert ( + None not in examples_per_job + ), f"Failed to retrieve results from prepare_split: result list {examples_per_job} still contains None - at 
least one worker failed to return its results" + + total_shards = sum(shards_per_job) + total_num_examples = sum(examples_per_job) + total_num_bytes = sum(bytes_per_job) + features = features_per_job[0] + + split_generator.split_info.num_examples = total_num_examples + split_generator.split_info.num_bytes = total_num_bytes + + # should rename everything at the end + logger.debug(f"Renaming {total_shards} shards.") + if total_shards > 1: + # use the -SSSSS-of-NNNNN pattern + + def _rename_shard(shard_id_and_job: Tuple[int]): + shard_id, job_id = shard_id_and_job + global_shard_id = sum(shards_per_job[:job_id]) + shard_id + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), + fpath.replace("JJJJJ-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), + ) + + shard_ids_and_jobs = [ + (shard_id, job_id) + for job_id, num_shards in enumerate(shards_per_job) + for shard_id in range(num_shards) + ] + thread_map(_rename_shard, shard_ids_and_jobs, disable=True, max_workers=64) + + split_generator.split_info.shard_lengths = [ + shard_length for shard_lengths in shard_lengths_per_job for shard_length in shard_lengths + ] + else: + # don't use any pattern + shard_id, job_id = 0, 0 + self._rename( + fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), + fpath.replace(SUFFIX, ""), + ) + + if self.info.features is None: + self.info.features = features + + def _prepare_split_single(self, arg: dict) -> Iterable[Tuple[int, bool, Union[int, tuple]]]: + gen_kwargs: dict = arg["gen_kwargs"] + fpath: str = arg["fpath"] + file_format: str = arg["file_format"] + max_shard_size: int = arg["max_shard_size"] + job_id: int = arg["job_id"] + refresh_rate = 0.05 # 20 progress updates per sec + generator = self._generate_tables(**gen_kwargs) writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter embed_local_files = file_format == "parquet" + shard_lengths = [] + total_num_examples, 
total_num_bytes = 0, 0 shard_id = 0 - # TODO: embed the images/audio files inside parquet files. + num_examples_progress_update = 0 writer = writer_class( features=self.info.features, - path=fpath.replace("SSSSS", f"{shard_id:05d}"), + path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), storage_options=self._fs.storage_options, embed_local_files=embed_local_files, ) - total_num_examples, total_num_bytes = 0, 0 try: - for key, table in logging.tqdm( - generator, - unit=" tables", - leave=False, - disable=not logging.is_progress_bar_enabled(), - ): + _time = time.time() + for _, table in generator: if max_shard_size is not None and writer._num_bytes > max_shard_size: num_examples, num_bytes = writer.finalize() writer.close() + shard_lengths.append(num_examples) total_num_examples += num_examples total_num_bytes += num_bytes shard_id += 1 writer = writer_class( features=writer._features, - path=fpath.replace("SSSSS", f"{shard_id:05d}"), + path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"), storage_options=self._fs.storage_options, embed_local_files=embed_local_files, ) writer.write_table(table) + num_examples_progress_update += len(table) + if time.time() > _time + refresh_rate: + _time = time.time() + yield job_id, False, num_examples_progress_update + num_examples_progress_update = 0 finally: + yield job_id, False, num_examples_progress_update num_shards = shard_id + 1 num_examples, num_bytes = writer.finalize() writer.close() + shard_lengths.append(num_examples) total_num_examples += num_examples total_num_bytes += num_bytes - if file_format == "parquet": - - def _rename_shard(shard_id: int): - self._rename( - fpath.replace("SSSSS", f"{shard_id:05d}"), - fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), - ) - - logger.debug(f"Renaming {num_shards} shards.") - thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) - - split_generator.split_info.num_examples 
= total_num_examples - split_generator.split_info.num_bytes = total_num_bytes - if self.info.features is None: - self.info.features = writer._features + yield job_id, True, (total_num_examples, total_num_bytes, writer._features, num_shards, shard_lengths) def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: return ExamplesIterable( @@ -1634,11 +1873,18 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwarg f"\n\t`{usage_example}`" ) - beam_options = beam_options or beam.options.pipeline_options.PipelineOptions() # Beam type checking assumes transforms multiple outputs are of same type, # which is not our case. Plus it doesn't handle correctly all types, so we # are better without it. - beam_options.view_as(beam.options.pipeline_options.TypeOptions).pipeline_type_check = False + pipeline_options = {"pipeline_type_check": False} + if "num_proc" in prepare_splits_kwargs: + num_workers = prepare_splits_kwargs.pop("num_proc") + pipeline_options["direct_num_workers"] = num_workers + pipeline_options["num_workers"] = num_workers + pipeline_options["direct_running_mode"] = "multi_processing" + # TODO: Fix ModuleNotFoundError: No module named 'datasets_modules' when running multiprocessed DirectRunner + raise NotImplementedError("Using a DirectRunner with `num_proc` for multiprocessing it not supported yet.") + beam_options = beam_options or beam.options.pipeline_options.PipelineOptions.from_dictionary(pipeline_options) # Use a single pipeline for all splits pipeline = beam_utils.BeamPipeline( runner=beam_runner, @@ -1659,6 +1905,18 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwarg split_info = split_dict[split_name] split_info.num_examples = num_examples split_info.num_bytes = num_bytes + if hasattr(beam_writer, "_shard_lengths") and len(beam_writer._shard_lengths) > 1: + # keep the -SSSSS-of-NNNNN pattern + split_info.shard_lengths = beam_writer._shard_lengths + 
else: + # don't use any pattern + file_format = prepare_splits_kwargs.get("file_format", "arrow") + src_fname = f"{self.name}-{split_name}-00000-of-00001.{file_format}" + dst_fname = f"{self.name}-{split_name}.{file_format}" + path_join = os.path.join if not is_remote_filesystem(self._fs) else posixpath.join + src_fpath = path_join(self._output_dir, src_fname) + dst_fpath = path_join(self._output_dir, dst_fname) + self._rename(src_fpath, dst_fpath) def _save_info(self): import apache_beam as beam @@ -1705,4 +1963,4 @@ def _build_pcollection(pipeline): return beam_writer.write_from_pcollection(pcoll_examples) # Add the PCollection to the pipeline - _ = pipeline | split_name >> _build_pcollection() # pylint: disable=no-value-for-parameter + _ = pipeline | split_name >> _build_pcollection() # pylint: disable=no-value-for-parameter max_bytes_per_shard diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 1e626187186..37a0db769db 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -16,6 +16,7 @@ from .info import DatasetInfo from .splits import NamedSplit from .table import table_cast +from .utils.sharding import _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs def _infer_features_from_batch(batch: Dict[str, list], try_features: Optional[Features] = None) -> Features: @@ -96,43 +97,6 @@ def n_shards(self) -> int: raise NotImplementedError(f"{type(self)} doesn't implement n_shards yet") -def _shuffle_kwargs(rng: np.random.Generator, kwargs: dict) -> dict: - """Return a shuffled copy of the input kwargs""" - # We must shuffle all the lists, and lists of the same size must have the same shuffling. - # This way entangled lists of (shard, shard_metadata) are still in the right order. 
- - # First, let's generate the shuffled indices per list size - list_sizes = set(len(value) for value in kwargs.values() if isinstance(value, list)) - indices_per_size = {} - for size in list_sizes: - indices_per_size[size] = list(range(size)) - rng.shuffle(indices_per_size[size]) - # Now let's copy the kwargs and shuffle the lists based on their sizes - shuffled_kwargs = dict(kwargs) - for key, value in shuffled_kwargs.items(): - if isinstance(value, list): - shuffled_kwargs[key] = [value[i] for i in indices_per_size[len(value)]] - return shuffled_kwargs - - -def _shard_kwargs(shard_idx: int, kwargs: dict) -> dict: - """Return a copy of the input kwargs but with only one shard""" - # Having lists of different sizes makes sharding ambigious, raise an error in this case - # until we decide how to define sharding without ambiguity for users - lists_lengths = {key: len(value) for key, value in kwargs.items() if isinstance(value, list)} - if len(set(lists_lengths.values())) > 1: - raise RuntimeError( - ( - "Sharding is ambiguous for this dataset: " - + "we found several data sources lists of different lengths, and we don't know over which list we should parallelize:\n" - + "\n".join(f"\t- key {key} has length {length}" for key, length in lists_lengths.items()) - + "\nTo fix this, check the dataset script 'gen_kwargs' and make sure to use lists only for data sources, " - + "and use tuples otherwise. In the end there should only be one single list, or several lists with the same length." 
- ) - ) - return {key: [value[shard_idx]] if isinstance(value, list) else value for key, value in kwargs.items()} - - class ExamplesIterable(_BaseExamplesIterable): def __init__(self, generate_examples_fn: Callable, kwargs: dict): self.generate_examples_fn = generate_examples_fn @@ -146,13 +110,12 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesItera def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable": """Keep only the requested shard.""" - kwargs_with_requested_data_source = _shard_kwargs(shard_idx, self.kwargs) + kwargs_with_requested_data_source = _split_gen_kwargs(self.kwargs, max_num_jobs=self.n_shards)[shard_idx] yield from self.generate_examples_fn(**kwargs_with_requested_data_source) @property def n_shards(self) -> int: - max_length = max((len(value) for value in self.kwargs.values() if isinstance(value, list)), default=0) - return max(1, max_length) + return _number_of_shards_in_gen_kwargs(self.kwargs) class ShardShuffledExamplesIterable(ExamplesIterable): @@ -163,14 +126,16 @@ def __init__(self, generate_examples_fn: Callable, kwargs: dict, generator: np.r def __iter__(self): """Shuffle the kwargs order to shuffle shards""" rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_kwargs(rng, self.kwargs) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) yield from self.generate_examples_fn(**kwargs_with_shuffled_shards) def shard_data_sources(self, shard_idx: int) -> "MappedExamplesIterable": """Keep only the requested shard.""" rng = deepcopy(self.generator) - kwargs_with_shuffled_shards = _shuffle_kwargs(rng, self.kwargs) - kwargs_with_requested_data_source = _shard_kwargs(shard_idx, kwargs_with_shuffled_shards) + kwargs_with_shuffled_shards = _shuffle_gen_kwargs(rng, self.kwargs) + kwargs_with_requested_data_source = _split_gen_kwargs(kwargs_with_shuffled_shards, max_num_jobs=self.n_shards)[ + shard_idx + ] return ExamplesIterable(self.generate_examples_fn, 
kwargs_with_requested_data_source) diff --git a/src/datasets/load.py b/src/datasets/load.py index dc8754b1587..8c0ec9ef97c 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1543,6 +1543,7 @@ def load_dataset( use_auth_token: Optional[Union[bool, str]] = None, task: Optional[Union[str, TaskTemplate]] = None, streaming: bool = False, + num_proc: int = None, **config_kwargs, ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]: """Load a dataset from the Hugging Face Hub, or a local dataset. @@ -1632,6 +1633,10 @@ def load_dataset( Note that streaming works for datasets that use data formats that support being iterated over like txt, csv, jsonl for example. Json files may be downloaded completely. Also streaming from remote zip or gzip files is supported but other compressed formats like rar and xz are not yet supported. The tgz format doesn't allow streaming. + num_proc (:obj:`int`, optional, default `None`): Number of processes when downloading and generating the dataset locally. + Multiprocessing is disabled by default. + + **config_kwargs (additional keyword arguments): Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`. @@ -1700,6 +1705,12 @@ def load_dataset( "Please use `load_from_disk` instead." ) + if streaming and num_proc is not None: + raise NotImplementedError( + "Loading a streaming dataset in parallel with `num_proc` is not implemented. " + "To parallelize streaming, you can wrap the dataset with a PyTorch DataLoader using `num_workers` > 1 instead." 
+ ) + download_mode = DownloadMode(download_mode or DownloadMode.REUSE_DATASET_IF_EXISTS) ignore_verifications = ignore_verifications or save_infos @@ -1733,6 +1744,7 @@ def load_dataset( ignore_verifications=ignore_verifications, try_from_hf_gcs=try_from_hf_gcs, use_auth_token=use_auth_token, + num_proc=num_proc, ) # Build dataset for splits diff --git a/src/datasets/naming.py b/src/datasets/naming.py index 09f9977ae78..2bfd8d82694 100644 --- a/src/datasets/naming.py +++ b/src/datasets/naming.py @@ -14,7 +14,6 @@ # Lint as: python3 """Utilities for file names.""" - import itertools import os import re @@ -67,18 +66,19 @@ def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix return f"{filepath}*" -def filename_for_dataset_split(dataset_name, split, filetype_suffix=None): - prefix = filename_prefix_for_split(dataset_name, split) - if filetype_suffix: - prefix += f".{filetype_suffix}" - return prefix - +def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None): -def filepath_for_dataset_split(dataset_name, split, data_dir, filetype_suffix=None): - filename = filename_for_dataset_split( - dataset_name=dataset_name, - split=split, - filetype_suffix=filetype_suffix, - ) - filepath = os.path.join(data_dir, filename) - return filepath + prefix = filename_prefix_for_split(dataset_name, split) + prefix = os.path.join(path, prefix) + + if shard_lengths: + num_shards = len(shard_lengths) + filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)] + if filetype_suffix: + filenames = [filename + f".{filetype_suffix}" for filename in filenames] + return filenames + else: + filename = prefix + if filetype_suffix: + filename += f".{filetype_suffix}" + return [filename] diff --git a/src/datasets/splits.py b/src/datasets/splits.py index 3a934a097b8..f1ea57ee6a5 100644 --- a/src/datasets/splits.py +++ b/src/datasets/splits.py @@ -34,6 +34,7 @@ class SplitInfo: name: str = "" 
num_bytes: int = 0 num_examples: int = 0 + shard_lengths: Optional[List[int]] = None # Deprecated # For backward compatibility, this field needs to always be included in files like @@ -579,6 +580,9 @@ def copy(self): def _to_yaml_list(self) -> list: out = [asdict(s) for s in self.to_split_dict()] + # we don't need the shard lengths in YAML, since it depends on max_shard_size and num_proc + for split_info_dict in out: + split_info_dict.pop("shard_lengths", None) # we don't need the dataset_name attribute that is deprecated for split_info_dict in out: split_info_dict.pop("dataset_name", None) diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py index 9e9d2226e39..ee33b358abf 100644 --- a/src/datasets/utils/py_utils.py +++ b/src/datasets/utils/py_utils.py @@ -20,19 +20,24 @@ import copy import functools import itertools +import multiprocessing.pool import os +import queue import re import types from contextlib import contextmanager from dataclasses import fields, is_dataclass from io import BytesIO as StringIO -from multiprocessing import Pool, RLock +from multiprocessing import Manager, Pool, RLock +from queue import Empty from shutil import disk_usage from types import CodeType, FunctionType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union from urllib.parse import urlparse import dill +import multiprocess +import multiprocess.pool import numpy as np from packaging import version from tqdm.auto import tqdm @@ -1328,3 +1333,32 @@ def copyfunc(func): result = types.FunctionType(func.__code__, func.__globals__, func.__name__, func.__defaults__, func.__closure__) result.__kwdefaults__ = func.__kwdefaults__ return result + + +X = TypeVar("X") +Y = TypeVar("Y") + + +def _write_generator_to_queue(queue: queue.Queue, func: Callable[[X], Iterable[Y]], arg: X) -> int: + for i, result in enumerate(func(arg)): + queue.put(result) + return i + + 
+def iflatmap_unordered( + pool: Union[multiprocessing.pool.Pool, multiprocess.pool.Pool], + func: Callable[[X], Iterable[Y]], + iterable: Iterable[X], +) -> Iterable[Y]: + manager_cls = Manager if isinstance(pool, multiprocessing.pool.Pool) else multiprocess.Manager + with manager_cls() as manager: + queue = manager.Queue() + async_results = [pool.apply_async(_write_generator_to_queue, (queue, func, arg)) for arg in iterable] + while True: + try: + yield queue.get(timeout=0.05) + except Empty: + if all(async_result.ready() for async_result in async_results) and queue.empty(): + break + # we get the result in case there's an error to raise + [async_result.get() for async_result in async_results] diff --git a/src/datasets/utils/sharding.py b/src/datasets/utils/sharding.py new file mode 100644 index 00000000000..52cc0fe04e9 --- /dev/null +++ b/src/datasets/utils/sharding.py @@ -0,0 +1,87 @@ +from typing import List + +import numpy as np + + +def _number_of_shards_in_gen_kwargs(gen_kwargs: dict) -> int: + """Return the number of possible shards according to the input gen_kwargs""" + # Having lists of different sizes makes sharding ambiguous, raise an error in this case + # until we decide how to define sharding without ambiguity for users + lists_lengths = {key: len(value) for key, value in gen_kwargs.items() if isinstance(value, list)} + if len(set(lists_lengths.values())) > 1: + raise RuntimeError( + ( + "Sharding is ambiguous for this dataset: " + + "we found several data sources lists of different lengths, and we don't know over which list we should parallelize:\n" + + "\n".join(f"\t- key {key} has length {length}" for key, length in lists_lengths.items()) + + "\nTo fix this, check the 'gen_kwargs' and make sure to use lists only for data sources, " + + "and use tuples otherwise. In the end there should only be one single list, or several lists with the same length." 
+ ) + ) + max_length = max(lists_lengths.values(), default=0) + return max(1, max_length) + + +def _distribute_shards(num_shards: int, max_num_jobs: int) -> List[range]: + """ + Get the range of shard indices per job. + If num_shards<max_num_jobs, then num_shards jobs are given a range of one shard. + The shards indices order is preserved: e.g. all the first shards are given the first job. + Moreover all the jobs are given approximately the same number of shards. + + Example: + + ```python + >>> _distribute_shards(2, max_num_jobs=4) + [range(0, 1), range(1, 2)] + >>> _distribute_shards(10, max_num_jobs=3) + [range(0, 4), range(4, 7), range(7, 10)] + ``` + """ + shards_indices_per_group = [] + for group_idx in range(max_num_jobs): + num_shards_to_add = num_shards // max_num_jobs + (group_idx < (num_shards % max_num_jobs)) + if num_shards_to_add == 0: + break + start = shards_indices_per_group[-1].stop if shards_indices_per_group else 0 + shard_indices = range(start, start + num_shards_to_add) + shards_indices_per_group.append(shard_indices) + return shards_indices_per_group + + +def _split_gen_kwargs(gen_kwargs: dict, max_num_jobs: int) -> List[dict]: + """Split the gen_kwargs into at most `max_num_jobs` gen_kwargs""" + # Having lists of different sizes makes sharding ambiguous, raise an error in this case + num_shards = _number_of_shards_in_gen_kwargs(gen_kwargs) + if num_shards == 1: + return [dict(gen_kwargs)] + else: + shard_indices_per_group = _distribute_shards(num_shards=num_shards, max_num_jobs=max_num_jobs) + return [ + { + key: [value[shard_idx] for shard_idx in shard_indices_per_group[group_idx]] + if isinstance(value, list) + else value + for key, value in gen_kwargs.items() + } + for group_idx in range(len(shard_indices_per_group)) + ] + + +def _shuffle_gen_kwargs(rng: np.random.Generator, gen_kwargs: dict) -> dict: + """Return a shuffled copy of the input gen_kwargs""" + # We must shuffle all the lists, and lists of the same size must have the same shuffling. + # This way entangled lists of (shard, shard_metadata) are still in the right order. 
+ + # First, let's generate the shuffled indices per list size + list_sizes = set(len(value) for value in gen_kwargs.values() if isinstance(value, list)) + indices_per_size = {} + for size in list_sizes: + indices_per_size[size] = list(range(size)) + rng.shuffle(indices_per_size[size]) + # Now let's copy the gen_kwargs and shuffle the lists based on their sizes + shuffled_kwargs = dict(gen_kwargs) + for key, value in shuffled_kwargs.items(): + if isinstance(value, list): + shuffled_kwargs[key] = [value[i] for i in indices_per_size[len(value)]] + return shuffled_kwargs diff --git a/tests/test_arrow_reader.py b/tests/test_arrow_reader.py index 9c935c43542..580d58ca5f5 100644 --- a/tests/test_arrow_reader.py +++ b/tests/test_arrow_reader.py @@ -7,7 +7,7 @@ import pytest from datasets.arrow_dataset import Dataset -from datasets.arrow_reader import ArrowReader, BaseReader, ReadInstruction +from datasets.arrow_reader import ArrowReader, BaseReader, FileInstructions, ReadInstruction, make_file_instructions from datasets.info import DatasetInfo from datasets.splits import NamedSplit, Split, SplitDict, SplitInfo @@ -29,6 +29,8 @@ def _get_table_from_filename(self, filename_skip_take, in_memory=False): ) open(os.path.join(filename), "wb").close() pa_table = pa.Table.from_pydict({"filename": [Path(filename).name] * 100}) + if take == -1: + take = len(pa_table) - skip if skip is not None and take is not None: pa_table = pa_table.slice(skip, take) return pa_table @@ -71,6 +73,24 @@ def test_read(self): self.assertEqual(str(test_dset.split), "test[:33%]") del train_dset, test_dset + def test_read_sharded(self): + name = "my_name" + train_info = SplitInfo(name="train", num_examples=1000, shard_lengths=[100] * 10) + split_infos = [train_info] + split_dict = SplitDict() + split_dict.add(train_info) + info = DatasetInfo(splits=split_dict) + + with tempfile.TemporaryDirectory() as tmp_dir: + reader = ReaderTest(tmp_dir, info) + + instructions = "train[:33%]" + dset = 
Dataset(**reader.read(name, instructions, split_infos)) + self.assertEqual(dset["filename"][0], f"{name}-train-00000-of-00010") + self.assertEqual(dset["filename"][-1], f"{name}-train-00003-of-00010") + self.assertEqual(dset.num_rows, 330) + self.assertEqual(dset.num_columns, 1) + def test_read_files(self): train_info = SplitInfo(name="train", num_examples=100) test_info = SplitInfo(name="test", num_examples=100) @@ -136,3 +156,29 @@ def test_read_instruction_spec(): spec_train_test_pct_rounding = "train[:10%](pct1_dropremainder)+test[-10%:](pct1_dropremainder)" assert ReadInstruction.from_spec(spec_train_test_pct_rounding).to_spec() == spec_train_test_pct_rounding + + +def test_make_file_instructions(): + name = "dummy" + split_infos = [SplitInfo(name="train", num_examples=100)] + instruction = "train[:33%]" + filetype_suffix = "arrow" + prefix_path = "prefix" + + file_instructions = make_file_instructions(name, split_infos, instruction, filetype_suffix, prefix_path) + assert isinstance(file_instructions, FileInstructions) + assert file_instructions.num_examples == 33 + assert file_instructions.file_instructions == [ + {"filename": os.path.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33} + ] + + split_infos = [SplitInfo(name="train", num_examples=100, shard_lengths=[10] * 10)] + file_instructions = make_file_instructions(name, split_infos, instruction, filetype_suffix, prefix_path) + assert isinstance(file_instructions, FileInstructions) + assert file_instructions.num_examples == 33 + assert file_instructions.file_instructions == [ + {"filename": os.path.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": os.path.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": os.path.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1}, + {"filename": os.path.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3}, + ] diff 
--git a/tests/test_beam.py b/tests/test_beam.py index 63158243d31..999262b58d1 100644 --- a/tests/test_beam.py +++ b/tests/test_beam.py @@ -1,6 +1,8 @@ import os import tempfile +from functools import partial from unittest import TestCase +from unittest.mock import patch import datasets import datasets.config @@ -81,6 +83,43 @@ def test_download_and_prepare(self): ) del dset + @require_beam + def test_download_and_prepare_sharded(self): + import apache_beam as beam + + original_write_parquet = beam.io.parquetio.WriteToParquet + + expected_num_examples = len(get_test_dummy_examples()) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = DummyBeamDataset(cache_dir=tmp_cache_dir, beam_runner="DirectRunner") + with patch("apache_beam.io.parquetio.WriteToParquet") as write_parquet_mock: + write_parquet_mock.side_effect = partial(original_write_parquet, num_shards=2) + builder.download_and_prepare() + self.assertTrue( + os.path.exists( + os.path.join( + tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00002.arrow" + ) + ) + ) + self.assertTrue( + os.path.exists( + os.path.join( + tmp_cache_dir, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00002.arrow" + ) + ) + ) + self.assertDictEqual(builder.info.features, datasets.Features({"content": datasets.Value("string")})) + dset = builder.as_dataset() + self.assertEqual(dset["train"].num_rows, expected_num_examples) + self.assertEqual(dset["train"].info.splits["train"].num_examples, expected_num_examples) + # Order is not preserved when sharding, so we just check that all the elements are there + self.assertListEqual(sorted(dset["train"]["content"]), sorted(["foo", "bar", "foobar"])) + self.assertTrue( + os.path.exists(os.path.join(tmp_cache_dir, builder.name, "default", "0.0.0", "dataset_info.json")) + ) + del dset + @require_beam def test_no_beam_options(self): with tempfile.TemporaryDirectory() as tmp_cache_dir: diff --git a/tests/test_builder.py 
b/tests/test_builder.py index 20237a2b81d..e55c0a0de58 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -164,6 +164,36 @@ def _split_generators(self, dl_manager): return [SplitGenerator(name=Split.TRAIN)] +class DummyArrowBasedBuilderWithShards(ArrowBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})] + + def _generate_tables(self, filepaths): + idx = 0 + for filepath in filepaths: + for i in range(10): + yield idx, pa.table({"id": range(10 * i, 10 * (i + 1)), "filepath": [filepath] * 10}) + idx += 1 + + +class DummyGeneratorBasedBuilderWithShards(GeneratorBasedBuilder): + def _info(self): + return DatasetInfo(features=Features({"id": Value("int8"), "filepath": Value("string")})) + + def _split_generators(self, dl_manager): + return [SplitGenerator(name=Split.TRAIN, gen_kwargs={"filepaths": [f"data{i}.txt" for i in range(4)]})] + + def _generate_examples(self, filepaths): + idx = 0 + for filepath in filepaths: + for i in range(100): + yield idx, {"id": i, "filepath": filepath} + idx += 1 + + def _run_concurrent_download_and_prepare(tmp_dir): builder = DummyBuilder(cache_dir=tmp_dir) builder.download_and_prepare(try_from_hf_gcs=False, download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS) @@ -741,7 +771,7 @@ def test_arrow_based_download_and_prepare(tmp_path): ) ) assert builder.info.features, Features({"text": Value("string")}) - assert builder.info.splits["train"].num_examples, 100 + assert builder.info.splits["train"].num_examples == 100 assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) @@ -759,7 +789,7 @@ def test_beam_based_download_and_prepare(tmp_path): ) ) assert builder.info.features, Features({"text": Value("string")}) - assert builder.info.splits["train"].num_examples, 
100 + assert builder.info.splits["train"].num_examples == 100 assert os.path.exists(os.path.join(tmp_path, builder.name, "default", "0.0.0", "dataset_info.json")) @@ -977,21 +1007,19 @@ def test_builder_with_filesystem_download_and_prepare_reload(tmp_path, mockfs, c def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare(file_format="parquet") - assert builder.info.splits["train"].num_examples, 100 - parquet_path = os.path.join( - tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" - ) + assert builder.info.splits["train"].num_examples == 100 + parquet_path = os.path.join(tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train.parquet") assert os.path.exists(parquet_path) assert pq.ParquetFile(parquet_path) is not None -def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_path): +def test_generator_based_builder_download_and_prepare_sharded(tmp_path): writer_batch_size = 25 builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) with patch("datasets.config.MAX_SHARD_SIZE", 1): # one batch per shard builder.download_and_prepare(file_format="parquet") expected_num_shards = 100 // writer_batch_size - assert builder.info.splits["train"].num_examples, 100 + assert builder.info.splits["train"].num_examples == 100 parquet_path = os.path.join( tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" ) @@ -1004,12 +1032,12 @@ def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_pat assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 -def test_generator_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path): +def test_generator_based_builder_download_and_prepare_with_max_shard_size(tmp_path): writer_batch_size = 25 builder = 
DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size) builder.download_and_prepare(file_format="parquet", max_shard_size=1) # one batch per shard expected_num_shards = 100 // writer_batch_size - assert builder.info.splits["train"].num_examples, 100 + assert builder.info.splits["train"].num_examples == 100 parquet_path = os.path.join( tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet" ) @@ -1022,23 +1050,39 @@ def test_generator_based_builder_download_and_prepare_as_sharded_parquet_with_ma assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100 +def test_generator_based_builder_download_and_prepare_with_num_proc(tmp_path): + builder = DummyGeneratorBasedBuilderWithShards(cache_dir=tmp_path) + builder.download_and_prepare(num_proc=2) + expected_num_shards = 2 + assert builder.info.splits["train"].num_examples == 400 + assert builder.info.splits["train"].shard_lengths == [200, 200] + arrow_path = os.path.join( + tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.arrow" + ) + assert os.path.exists(arrow_path) + ds = builder.as_dataset("train") + assert len(ds) == 400 + assert ds.to_dict() == { + "id": [i for _ in range(4) for i in range(100)], + "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)], + } + + def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare(file_format="parquet") - assert builder.info.splits["train"].num_examples, 100 - parquet_path = os.path.join( - tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet" - ) + assert builder.info.splits["train"].num_examples == 100 + parquet_path = os.path.join(tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train.parquet") assert os.path.exists(parquet_path) assert 
pq.ParquetFile(parquet_path) is not None
 
 
-def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path):
+def test_arrow_based_builder_download_and_prepare_sharded(tmp_path):
     builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
     with patch("datasets.config.MAX_SHARD_SIZE", 1):  # one batch per shard
         builder.download_and_prepare(file_format="parquet")
     expected_num_shards = 10
-    assert builder.info.splits["train"].num_examples, 100
+    assert builder.info.splits["train"].num_examples == 100
     parquet_path = os.path.join(
         tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet"
     )
@@ -1051,11 +1095,11 @@ def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path):
     assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
 
 
-def test_arrow_based_builder_download_and_prepare_as_sharded_parquet_with_max_shard_size(tmp_path):
+def test_arrow_based_builder_download_and_prepare_with_max_shard_size(tmp_path):
     builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
     builder.download_and_prepare(file_format="parquet", max_shard_size=1)  # one table per shard
     expected_num_shards = 10
-    assert builder.info.splits["train"].num_examples, 100
+    assert builder.info.splits["train"].num_examples == 100
     parquet_path = os.path.join(
         tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet"
     )
@@ -1068,13 +1112,30 @@ def test_arrow_based_builder_download_and_prepare_as_sharded_parquet_with_max_sh
     assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
 
 
+def test_arrow_based_builder_download_and_prepare_with_num_proc(tmp_path):
+    """With `num_proc=2`, the 400 examples should be written as two Arrow shards of 200 examples each."""
+    builder = DummyArrowBasedBuilderWithShards(cache_dir=tmp_path)
+    builder.download_and_prepare(num_proc=2)
+    expected_num_shards = 2
+    assert builder.info.splits["train"].num_examples == 400
+    assert builder.info.splits["train"].shard_lengths == [200, 200]
+    arrow_path = os.path.join(
+        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.arrow"
+    )
+    assert os.path.exists(arrow_path)
+    ds = builder.as_dataset("train")
+    assert len(ds) == 400
+    assert ds.to_dict() == {
+        "id": [i for _ in range(4) for i in range(100)],
+        "filepath": [f"data{i}.txt" for i in range(4) for _ in range(100)],
+    }
+
+
 @require_beam
 def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
     builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
     builder.download_and_prepare(file_format="parquet")
-    assert builder.info.splits["train"].num_examples, 100
-    parquet_path = os.path.join(
-        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-00001.parquet"
-    )
+    assert builder.info.splits["train"].num_examples == 100
+    parquet_path = os.path.join(tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train.parquet")
     assert os.path.exists(parquet_path)
     assert pq.ParquetFile(parquet_path) is not None
diff --git a/tests/test_py_utils.py b/tests/test_py_utils.py
index 7769fe46d63..f0b391fbf5c 100644
--- a/tests/test_py_utils.py
+++ b/tests/test_py_utils.py
@@ -1,11 +1,22 @@
+import time
 from dataclasses import dataclass
+from multiprocessing import Pool
 from unittest import TestCase
 from unittest.mock import patch
 
+import multiprocess
 import numpy as np
 import pytest
 
-from datasets.utils.py_utils import NestedDataStructure, asdict, map_nested, temp_seed, temporary_assignment, zip_dict
+from datasets.utils.py_utils import (
+    NestedDataStructure,
+    asdict,
+    iflatmap_unordered,
+    map_nested,
+    temp_seed,
+    temporary_assignment,
+    zip_dict,
+)
 
 from .utils import require_tf, require_torch
 
@@ -227,3 +238,38 @@ def test_asdict():
 
     with pytest.raises(TypeError):
         asdict([1, A(x=10, y="foo")])
+
+
+def _2seconds_generator_of_2items_with_timing(content):
+    """Yield `(time.time(), content)` twice, sleeping 2 seconds between the two yields."""
+    yield (time.time(), content)
+    time.sleep(2)
+    yield (time.time(), content)
+
+
+def test_iflatmap_unordered():
+    """Check `iflatmap_unordered` with stdlib and `multiprocess` pools, and that items arrive promptly."""
+
+    with Pool(2) as pool:
+        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        assert out.count("hello") == 10
+        assert out.count("there") == 10
+        assert len(out) == 20
+
+    # check multiprocess from pathos (uses dill for pickling)
+    with multiprocess.Pool(2) as pool:
+        out = list(iflatmap_unordered(pool, str.split, ["hello there"] * 10))
+        assert out.count("hello") == 10
+        assert out.count("there") == 10
+        assert len(out) == 20
+
+    # check that we get items as fast as possible
+    with Pool(2) as pool:
+        out = []
+        for yield_time, content in iflatmap_unordered(pool, _2seconds_generator_of_2items_with_timing, ["a", "b"]):
+            # the item must reach us less than 0.1s after the worker yielded it
+            assert time.time() < yield_time + 0.1, "we should get each item directly after it was yielded"
+            out.append(content)
+        assert out.count("a") == 2
+        assert out.count("b") == 2
+        assert len(out) == 4
diff --git a/tests/test_sharding_utils.py b/tests/test_sharding_utils.py
new file mode 100644
index 00000000000..51c83cb478c
--- /dev/null
+++ b/tests/test_sharding_utils.py
@@ -0,0 +1,57 @@
+import pytest
+
+from datasets.utils.sharding import _distribute_shards, _number_of_shards_in_gen_kwargs, _split_gen_kwargs
+
+
+@pytest.mark.parametrize(
+    "kwargs, expected",
+    [
+        ({"num_shards": 0, "max_num_jobs": 1}, []),
+        ({"num_shards": 10, "max_num_jobs": 1}, [range(10)]),
+        ({"num_shards": 10, "max_num_jobs": 10}, [range(i, i + 1) for i in range(10)]),
+        ({"num_shards": 1, "max_num_jobs": 10}, [range(1)]),
+        ({"num_shards": 10, "max_num_jobs": 3}, [range(0, 4), range(4, 7), range(7, 10)]),
+        ({"num_shards": 3, "max_num_jobs": 10}, [range(0, 1), range(1, 2), range(2, 3)]),
+    ],
+)
+def test_distribute_shards(kwargs, expected):
+    """Check that `_distribute_shards` returns the expected list of shard index ranges."""
+    out = _distribute_shards(**kwargs)
+    assert out == expected
+
+
+@pytest.mark.parametrize(
+    "gen_kwargs, max_num_jobs, expected",
+    [
+        ({"foo": 0}, 10, [{"foo": 0}]),
+        ({"shards": [0, 1, 2, 3]}, 1, [{"shards": [0, 1, 2, 3]}]),
+        ({"shards": [0, 1, 2, 3]}, 4, [{"shards": [0]}, {"shards": [1]}, {"shards": [2]}, {"shards": [3]}]),
+        ({"shards": [0, 1]}, 4, [{"shards": [0]}, {"shards": [1]}]),
+        ({"shards": [0, 1, 2, 3]}, 2, [{"shards": [0, 1]}, {"shards": [2, 3]}]),
+    ],
+)
+def test_split_gen_kwargs(gen_kwargs, max_num_jobs, expected):
+    """Check that `_split_gen_kwargs` returns the expected list of per-job gen_kwargs dicts."""
+    out = _split_gen_kwargs(gen_kwargs, max_num_jobs)
+    assert out == expected
+
+
+@pytest.mark.parametrize(
+    "gen_kwargs, expected",
+    [
+        ({"foo": 0}, 1),
+        ({"shards": [0]}, 1),
+        ({"shards": [0, 1, 2, 3]}, 4),
+        ({"shards": [0, 1, 2, 3], "foo": 0}, 4),
+        ({"shards": [0, 1, 2, 3], "other": (0, 1)}, 4),
+        ({"shards": [0, 1, 2, 3], "shards2": [0, 1]}, RuntimeError),
+    ],
+)
+def test_number_of_shards_in_gen_kwargs(gen_kwargs, expected):
+    """Check the returned shard count, or that `RuntimeError` is raised when that is the expected outcome."""
+    if expected is RuntimeError:
+        with pytest.raises(expected):
+            _number_of_shards_in_gen_kwargs(gen_kwargs)
+    else:
+        out = _number_of_shards_in_gen_kwargs(gen_kwargs)
+        assert out == expected