-
Notifications
You must be signed in to change notification settings - Fork 3k
Multiprocessed dataset builder #5107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
a802ba5
ea56329
9536184
31d8395
9c5843a
328112e
9dc8539
21a603a
55cb365
94efbdb
bac2b2f
3e4f337
b2f634d
296302f
9b312d4
d2e70f2
e3a30fa
cf6fd25
3e5d0cc
09c13a7
a2e83d5
e3bc7a7
e8923e2
ef9c7f1
088dbb1
eb1fc58
e035339
06c5d33
e50ec74
525c829
eae6491
93f355d
b321c61
b05e551
020eb89
142f822
f22c162
08b8626
4ce2d12
e05ad83
c621cb6
22d965e
dc0ef15
95cdd0b
12d69f3
db45b3b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| from typing import Dict, Mapping, Optional, Tuple, Union | ||
|
|
||
| import fsspec | ||
| from multiprocess import Pool | ||
| from tqdm.contrib.concurrent import thread_map | ||
|
|
||
| from . import config, utils | ||
|
|
@@ -53,7 +54,13 @@ | |
| from .filesystems import is_remote_filesystem | ||
| from .fingerprint import Hasher | ||
| from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo | ||
| from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper | ||
| from .iterable_dataset import ( | ||
| ExamplesIterable, | ||
| IterableDataset, | ||
| _all_shard_kwargs, | ||
| _generate_examples_from_tables_wrapper, | ||
| _shard_number, | ||
| ) | ||
| from .keyhash import DuplicatedKeysError | ||
| from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase | ||
| from .splits import Split, SplitDict, SplitGenerator | ||
|
|
@@ -605,6 +612,7 @@ def download_and_prepare( | |
| use_auth_token: Optional[Union[bool, str]] = None, | ||
| file_format: str = "arrow", | ||
| max_shard_size: Optional[Union[int, str]] = None, | ||
| num_proc=None, | ||
| storage_options: Optional[dict] = None, | ||
| **download_and_prepare_kwargs, | ||
| ): | ||
|
|
@@ -680,11 +688,6 @@ def download_and_prepare( | |
| if file_format is not None and file_format not in ["arrow", "parquet"]: | ||
| raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'") | ||
|
|
||
| if file_format == "arrow" and max_shard_size is not None: | ||
| raise NotImplementedError( | ||
| "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." | ||
| ) | ||
|
|
||
| if self._fs._strip_protocol(self._output_dir) == "": | ||
| # We don't support the root directory, because it has no dirname, | ||
| # and we need a dirname to use a <dirname>.incomplete directory | ||
|
|
@@ -809,6 +812,7 @@ def incomplete_dir(dirname): | |
| prepare_split_kwargs = { | ||
| "file_format": file_format, | ||
| "max_shard_size": max_shard_size, | ||
| "num_proc": num_proc, | ||
| **download_and_prepare_kwargs, | ||
| } | ||
| self._download_and_prepare( | ||
|
|
@@ -865,7 +869,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig): | |
| logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.") | ||
| logger.info("Dataset downloaded from Hf google storage.") | ||
|
|
||
| def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs): | ||
| def _download_and_prepare(self, dl_manager, verify_infos, num_proc=None, **prepare_split_kwargs): | ||
| """Downloads and prepares dataset for reading. | ||
|
|
||
| This is the internal implementation to overwrite called when user calls | ||
|
|
@@ -902,7 +906,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs | |
|
|
||
| try: | ||
| # Prepare split will record examples associated to the split | ||
| self._prepare_split(split_generator, **prepare_split_kwargs) | ||
| self._prepare_split(split_generator, num_proc=num_proc, **prepare_split_kwargs) | ||
| except OSError as e: | ||
| raise OSError( | ||
| "Cannot find data file. " | ||
|
|
@@ -1245,6 +1249,7 @@ def _prepare_split( | |
| split_generator: SplitGenerator, | ||
| file_format: str = "arrow", | ||
| max_shard_size: Optional[Union[str, int]] = None, | ||
| num_proc=None, | ||
| **kwargs, | ||
| ): | ||
| """Generate the examples and record them on disk. | ||
|
|
@@ -1331,45 +1336,125 @@ def _prepare_split( | |
| split_generator: SplitGenerator, | ||
| check_duplicate_keys: bool, | ||
| file_format="arrow", | ||
| num_proc=None, | ||
| max_shard_size: Optional[Union[int, str]] = None, | ||
| ): | ||
|
|
||
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) | ||
| is_local = not is_remote_filesystem(self._fs) | ||
| path_join = os.path.join if is_local else posixpath.join | ||
|
|
||
| if file_format == "arrow": | ||
| if max_shard_size is not None: | ||
| raise NotImplementedError( | ||
| "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." | ||
| ) | ||
| else: | ||
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) | ||
|
|
||
| if self.info.splits is not None: | ||
| split_info = self.info.splits[split_generator.name] | ||
| else: | ||
| split_info = split_generator.split_info | ||
|
|
||
| suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" | ||
| suffix = "-RRRRR-SSSSS-of-NNNNN" | ||
| fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" | ||
| fpath = path_join(self._output_dir, fname) | ||
|
|
||
| generator = self._generate_examples(**split_generator.gen_kwargs) | ||
| # Default to 16-way parallelism for preparation when the split has at least 16 shards. | ||
| if num_proc is None and len(_all_shard_kwargs(split_generator.gen_kwargs)) >= 16: | ||
| num_proc = 16 | ||
|
|
||
| if _shard_number(split_generator.gen_kwargs) <= 1 and num_proc is not None: | ||
| logger.warning( | ||
| f"Setting num_proc from {num_proc} back to 1 to disable multiprocessing as dataset is not shardable" | ||
| ) | ||
| num_proc = None | ||
|
|
||
| if num_proc is None or num_proc == 1: | ||
| result = self._prepare_split_single( | ||
| split_generator.gen_kwargs, | ||
| fpath=fpath, | ||
| file_format=file_format, | ||
| max_shard_size=max_shard_size, | ||
| split_info=split_info, | ||
| check_duplicate_keys=check_duplicate_keys, | ||
| ) | ||
| # wrapping everything into lists for consistency with the multiprocessed code path | ||
| examples_per_rank, bytes_per_rank, features_per_rank, shards_per_rank = [[item] for item in result] | ||
| else: | ||
| args_per_shard = [ | ||
| ( | ||
| shard_kwargs, | ||
| fpath, | ||
| file_format, | ||
| max_shard_size, | ||
| split_info, | ||
| check_duplicate_keys, | ||
| ) | ||
| for shard_kwargs in _all_shard_kwargs(split_generator.gen_kwargs) | ||
| ] | ||
|
|
||
| examples_per_rank = [None] * len(args_per_shard) | ||
| bytes_per_rank = [None] * len(args_per_shard) | ||
| features_per_rank = [None] * len(args_per_shard) | ||
| shards_per_rank = [None] * len(args_per_shard) | ||
|
|
||
| with Pool(num_proc) as pool: | ||
| results = { | ||
| rank: pool.apply_async(self._prepare_split_single, args=args, kwds={"rank": rank}) | ||
| for rank, args in enumerate(args_per_shard) | ||
| } | ||
| for index, async_result in results.items(): | ||
| result = async_result.get() | ||
| ( | ||
| examples_per_rank[index], | ||
| bytes_per_rank[index], | ||
| features_per_rank[index], | ||
| shards_per_rank[index], | ||
| ) = result | ||
|
|
||
| assert None not in examples_per_rank, f"result list {examples_per_rank} still contains None" | ||
| # wrapping everything into lists for consistency with the multiprocessed code path | ||
|
|
||
| total_shards = sum(shards_per_rank) | ||
| total_num_examples = sum(examples_per_rank) | ||
| total_num_bytes = sum(bytes_per_rank) | ||
| features = features_per_rank[0] | ||
|
|
||
| # Rename all shards at the end, once the total number of shards is known (final naming scheme still TBD). | ||
| def _rename_shard(shard_id_and_rank: Tuple[int]): | ||
| shard_id, rank = shard_id_and_rank | ||
| global_shard_id = sum(shards_per_rank[:rank]) + shard_id | ||
| self._rename( | ||
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"), | ||
| fpath.replace("RRRRR-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), | ||
| ) | ||
|
|
||
| logger.debug(f"Renaming {total_shards} shards.") | ||
| shard_ids_and_ranks = [ | ||
| (shard_id, rank) for rank, num_shards in enumerate(shards_per_rank) for shard_id in range(num_shards) | ||
| ] | ||
| thread_map(_rename_shard, shard_ids_and_ranks, disable=True, max_workers=64) | ||
|
|
||
| split_generator.split_info.num_examples = total_num_examples | ||
| split_generator.split_info.num_bytes = total_num_bytes | ||
|
|
||
| if self.info.features is None: | ||
| self.info.features = features | ||
|
|
||
| def _prepare_split_single( | ||
| self, shard_kwargs, fpath, file_format, max_shard_size, split_info, check_duplicate_keys, rank=0 | ||
| ): | ||
|
|
||
| generator = self._generate_examples(**shard_kwargs) | ||
| writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter | ||
| embed_local_files = file_format == "parquet" | ||
| total_num_examples, total_num_bytes = 0, 0 | ||
|
|
||
| shard_id = 0 | ||
| # TODO: embed the images/audio files inside parquet files. | ||
| writer = writer_class( | ||
| features=self.info.features, | ||
| path=fpath.replace("SSSSS", f"{shard_id:05d}"), | ||
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"), | ||
| writer_batch_size=self._writer_batch_size, | ||
| hash_salt=split_info.name, | ||
| check_duplicates=check_duplicate_keys, | ||
| storage_options=self._fs.storage_options, | ||
| embed_local_files=embed_local_files, | ||
| ) | ||
| total_num_examples, total_num_bytes = 0, 0 | ||
| try: | ||
| for key, record in logging.tqdm( | ||
| generator, | ||
|
|
@@ -1403,21 +1488,7 @@ def _prepare_split( | |
| total_num_examples += num_examples | ||
| total_num_bytes += num_bytes | ||
|
|
||
| if file_format == "parquet": | ||
|
|
||
| def _rename_shard(shard_id: int): | ||
| self._rename( | ||
| fpath.replace("SSSSS", f"{shard_id:05d}"), | ||
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), | ||
| ) | ||
|
|
||
| logger.debug(f"Renaming {num_shards} shards.") | ||
| thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) | ||
|
|
||
| split_generator.split_info.num_examples = total_num_examples | ||
| split_generator.split_info.num_bytes = total_num_bytes | ||
| if self.info.features is None: | ||
| self.info.features = writer._features | ||
| return total_num_examples, total_num_bytes, writer._features, num_shards | ||
|
|
||
| def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs): | ||
| super()._download_and_prepare( | ||
|
|
@@ -1467,43 +1538,125 @@ def _prepare_split( | |
| self, | ||
| split_generator: SplitGenerator, | ||
| file_format: str = "arrow", | ||
| num_proc=None, | ||
| max_shard_size: Optional[Union[str, int]] = None, | ||
| ): | ||
|
|
||
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) | ||
| is_local = not is_remote_filesystem(self._fs) | ||
| path_join = os.path.join if is_local else posixpath.join | ||
|
|
||
| if file_format == "arrow": | ||
| if max_shard_size is not None: | ||
| raise NotImplementedError( | ||
| "Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'." | ||
| ) | ||
| if self.info.splits is not None: | ||
| split_info = self.info.splits[split_generator.name] | ||
| else: | ||
| max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE) | ||
| split_info = split_generator.split_info | ||
|
|
||
| suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" | ||
| suffix = "-RRRRR-SSSSS-of-NNNNN" | ||
| fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" | ||
| fpath = path_join(self._output_dir, fname) | ||
|
|
||
| generator = self._generate_tables(**split_generator.gen_kwargs) | ||
| # Default to 16-way parallelism for preparation when the split has at least 16 shards. | ||
| if num_proc is None and len(_all_shard_kwargs(split_generator.gen_kwargs)) >= 16: | ||
| num_proc = 16 | ||
|
|
||
| if _shard_number(split_generator.gen_kwargs) <= 1 and num_proc is not None: | ||
| logger.warning( | ||
| f"Setting num_proc from {num_proc} back to 1 to disable multiprocessing as dataset is not shardable" | ||
| ) | ||
| num_proc = None | ||
|
|
||
| if num_proc is None or num_proc == 1: | ||
| result = self._prepare_split_single( | ||
| split_generator.gen_kwargs, | ||
| fpath=fpath, | ||
| file_format=file_format, | ||
| max_shard_size=max_shard_size, | ||
| split_info=split_info, | ||
| ) | ||
| # wrapping everything into lists for consistency with the multiprocessed code path | ||
| examples_per_rank, bytes_per_rank, features_per_rank, shards_per_rank = [[item] for item in result] | ||
| else: | ||
| args_per_shard = [ | ||
| ( | ||
| shard_kwargs, | ||
| fpath, | ||
| file_format, | ||
| max_shard_size, | ||
| split_info, | ||
| ) | ||
| for shard_kwargs in _all_shard_kwargs(split_generator.gen_kwargs) | ||
| ] | ||
|
|
||
| examples_per_rank = [None] * len(args_per_shard) | ||
| bytes_per_rank = [None] * len(args_per_shard) | ||
| features_per_rank = [None] * len(args_per_shard) | ||
| shards_per_rank = [None] * len(args_per_shard) | ||
|
|
||
| with Pool(num_proc) as pool: | ||
| results = { | ||
| rank: pool.apply_async(self._prepare_split_single, args=args, kwds={"rank": rank}) | ||
| for rank, args in enumerate(args_per_shard) | ||
| } | ||
| for index, async_result in results.items(): | ||
| result = async_result.get() | ||
| ( | ||
| examples_per_rank[index], | ||
| bytes_per_rank[index], | ||
| features_per_rank[index], | ||
| shards_per_rank[index], | ||
| ) = result | ||
|
|
||
| assert None not in examples_per_rank, f"result list {examples_per_rank} still contains None" | ||
| # wrapping everything into lists for consistency with the multiprocessed code path | ||
|
|
||
| total_shards = sum(shards_per_rank) | ||
| total_num_examples = sum(examples_per_rank) | ||
| total_num_bytes = sum(bytes_per_rank) | ||
| features = features_per_rank[0] | ||
|
|
||
| # Rename all shards at the end, once the total number of shards is known (final naming scheme still TBD). | ||
| def _rename_shard(shard_id_and_rank: Tuple[int]): | ||
| shard_id, rank = shard_id_and_rank | ||
| global_shard_id = sum(shards_per_rank[:rank]) + shard_id | ||
| self._rename( | ||
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"), | ||
| fpath.replace("RRRRR-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"), | ||
| ) | ||
|
||
|
|
||
| logger.debug(f"Renaming {total_shards} shards.") | ||
| shard_ids_and_ranks = [ | ||
| (shard_id, rank) for rank, num_shards in enumerate(shards_per_rank) for shard_id in range(num_shards) | ||
| ] | ||
| thread_map(_rename_shard, shard_ids_and_ranks, disable=True, max_workers=64) | ||
|
|
||
| split_generator.split_info.num_examples = total_num_examples | ||
| split_generator.split_info.num_bytes = total_num_bytes | ||
|
|
||
| if self.info.features is None: | ||
| self.info.features = features | ||
|
|
||
| def _prepare_split_single(self, shard_kwargs, fpath, file_format, max_shard_size, split_info, rank=0): | ||
|
|
||
| generator = self._generate_tables(**shard_kwargs) | ||
| writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter | ||
| embed_local_files = file_format == "parquet" | ||
| total_num_examples, total_num_bytes = 0, 0 | ||
|
|
||
| shard_id = 0 | ||
| # TODO: embed the images/audio files inside parquet files. | ||
| writer = writer_class( | ||
| features=self.info.features, | ||
| path=fpath.replace("SSSSS", f"{shard_id:05d}"), | ||
| path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"), | ||
| storage_options=self._fs.storage_options, | ||
| embed_local_files=embed_local_files, | ||
| ) | ||
| total_num_examples, total_num_bytes = 0, 0 | ||
| try: | ||
| for key, table in logging.tqdm( | ||
| generator, | ||
| unit=" tables", | ||
| leave=False, | ||
| disable=not logging.is_progress_bar_enabled(), | ||
| desc=f"Generating {split_info.name} split", | ||
| ): | ||
| if max_shard_size is not None and writer._num_bytes > max_shard_size: | ||
| num_examples, num_bytes = writer.finalize() | ||
|
|
@@ -1525,21 +1678,7 @@ def _prepare_split( | |
| total_num_examples += num_examples | ||
| total_num_bytes += num_bytes | ||
|
|
||
| if file_format == "parquet": | ||
|
|
||
| def _rename_shard(shard_id: int): | ||
| self._rename( | ||
| fpath.replace("SSSSS", f"{shard_id:05d}"), | ||
| fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), | ||
| ) | ||
|
|
||
| logger.debug(f"Renaming {num_shards} shards.") | ||
| thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) | ||
|
|
||
| split_generator.split_info.num_examples = total_num_examples | ||
| split_generator.split_info.num_bytes = total_num_bytes | ||
| if self.info.features is None: | ||
| self.info.features = writer._features | ||
| return total_num_examples, total_num_bytes, writer._features, num_shards | ||
|
|
||
| def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: | ||
| return ExamplesIterable( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.