Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a802ba5
multiprocessing-compatible naming scheme and refactor
TevenLeScao Oct 11, 2022
ea56329
multiprocessed shard writing for GeneratorBasedBuilder
TevenLeScao Oct 12, 2022
9536184
multiprocessed shard writing for ArrowBasedBuilder
TevenLeScao Oct 12, 2022
31d8395
style
TevenLeScao Oct 12, 2022
9c5843a
multiprocessed dataset loading
TevenLeScao Oct 15, 2022
328112e
compatibility with non-sharded datasets
TevenLeScao Oct 15, 2022
9dc8539
bugfix
TevenLeScao Oct 17, 2022
21a603a
bugfix
TevenLeScao Oct 17, 2022
55cb365
Merge remote-tracking branch 'origin/multiprocessed_dataset_prep' int…
TevenLeScao Oct 19, 2022
94efbdb
removed unused import
TevenLeScao Oct 19, 2022
bac2b2f
fixed bad ordering
TevenLeScao Oct 19, 2022
3e4f337
less misleading tqdm
TevenLeScao Oct 19, 2022
b2f634d
fix gen_kwargs distribution + read shards
lhoestq Oct 20, 2022
296302f
minor
lhoestq Oct 20, 2022
9b312d4
minor2
lhoestq Oct 20, 2022
d2e70f2
support beam datasets
lhoestq Oct 21, 2022
e3a30fa
docstrings + minor
lhoestq Oct 25, 2022
cf6fd25
add iflatmap_unordered for parallel write & progress updates
lhoestq Oct 26, 2022
3e5d0cc
use 1 tqdm bar receiving updates from subprocesses
lhoestq Oct 26, 2022
09c13a7
docs
lhoestq Oct 26, 2022
a2e83d5
add test_iflatmap_unordered
lhoestq Oct 27, 2022
e3bc7a7
style
lhoestq Oct 27, 2022
e8923e2
test arrow_reader.py
lhoestq Oct 27, 2022
ef9c7f1
fix test_iflatmap_unordered
lhoestq Oct 28, 2022
088dbb1
add Beam test_download_and_prepare_sharded
lhoestq Oct 28, 2022
eb1fc58
test gen_kwargs distribution
lhoestq Oct 28, 2022
e035339
test download_and_prepare with num_proc
lhoestq Oct 28, 2022
06c5d33
Merge branch 'main' into multiprocessed_dataset_prep
lhoestq Oct 28, 2022
e50ec74
style
lhoestq Oct 28, 2022
525c829
improve test
lhoestq Nov 2, 2022
eae6491
don't close the pool
lhoestq Nov 2, 2022
93f355d
Merge branch 'main' into multiprocessed_dataset_prep
lhoestq Nov 2, 2022
b321c61
fix multiprocessing on windows
lhoestq Nov 2, 2022
b05e551
keep multiprocessing disabled by default
lhoestq Nov 2, 2022
020eb89
again + docs
lhoestq Nov 2, 2022
142f822
more docs
lhoestq Nov 2, 2022
f22c162
more docs
lhoestq Nov 2, 2022
08b8626
Merge remote-tracking branch 'upstream/main' into multiprocessed_data…
lhoestq Nov 3, 2022
4ce2d12
some var renaming
lhoestq Nov 3, 2022
e05ad83
style
lhoestq Nov 3, 2022
c621cb6
Apply suggestions from code review
lhoestq Nov 8, 2022
22d965e
Apply suggestions from code review
lhoestq Nov 8, 2022
dc0ef15
added utils/sharding.py
lhoestq Nov 8, 2022
95cdd0b
Merge remote-tracking branch 'upstream/main' into multiprocessed_data…
lhoestq Nov 8, 2022
12d69f3
style
lhoestq Nov 8, 2022
db45b3b
style
lhoestq Nov 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 199 additions & 60 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import Dict, Mapping, Optional, Tuple, Union

import fsspec
from multiprocess import Pool
from tqdm.contrib.concurrent import thread_map

from . import config, utils
Expand All @@ -53,7 +54,13 @@
from .filesystems import is_remote_filesystem
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ExamplesIterable, IterableDataset, _generate_examples_from_tables_wrapper
from .iterable_dataset import (
ExamplesIterable,
IterableDataset,
_all_shard_kwargs,
_generate_examples_from_tables_wrapper,
_shard_number,
)
from .keyhash import DuplicatedKeysError
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .splits import Split, SplitDict, SplitGenerator
Expand Down Expand Up @@ -605,6 +612,7 @@ def download_and_prepare(
use_auth_token: Optional[Union[bool, str]] = None,
file_format: str = "arrow",
max_shard_size: Optional[Union[int, str]] = None,
num_proc=None,
storage_options: Optional[dict] = None,
**download_and_prepare_kwargs,
):
Expand Down Expand Up @@ -680,11 +688,6 @@ def download_and_prepare(
if file_format is not None and file_format not in ["arrow", "parquet"]:
raise ValueError(f"Unsupported file_format: {file_format}. Expected 'arrow' or 'parquet'")

if file_format == "arrow" and max_shard_size is not None:
raise NotImplementedError(
"Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'."
)

if self._fs._strip_protocol(self._output_dir) == "":
# We don't support the root directory, because it has no dirname,
# and we need a dirname to use a <dirname>.incomplete directory
Expand Down Expand Up @@ -809,6 +812,7 @@ def incomplete_dir(dirname):
prepare_split_kwargs = {
"file_format": file_format,
"max_shard_size": max_shard_size,
"num_proc": num_proc,
**download_and_prepare_kwargs,
}
self._download_and_prepare(
Expand Down Expand Up @@ -865,7 +869,7 @@ def _download_prepared_from_hf_gcs(self, download_config: DownloadConfig):
logger.info(f"Couldn't download resourse file {resource_file_name} from Hf google storage.")
logger.info("Dataset downloaded from Hf google storage.")

def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs):
def _download_and_prepare(self, dl_manager, verify_infos, num_proc=None, **prepare_split_kwargs):
"""Downloads and prepares dataset for reading.

This is the internal implementation to overwrite called when user calls
Expand Down Expand Up @@ -902,7 +906,7 @@ def _download_and_prepare(self, dl_manager, verify_infos, **prepare_split_kwargs

try:
# Prepare split will record examples associated to the split
self._prepare_split(split_generator, **prepare_split_kwargs)
self._prepare_split(split_generator, num_proc=num_proc, **prepare_split_kwargs)
except OSError as e:
raise OSError(
"Cannot find data file. "
Expand Down Expand Up @@ -1245,6 +1249,7 @@ def _prepare_split(
split_generator: SplitGenerator,
file_format: str = "arrow",
max_shard_size: Optional[Union[str, int]] = None,
num_proc=None,
**kwargs,
):
"""Generate the examples and record them on disk.
Expand Down Expand Up @@ -1331,45 +1336,125 @@ def _prepare_split(
split_generator: SplitGenerator,
check_duplicate_keys: bool,
file_format="arrow",
num_proc=None,
max_shard_size: Optional[Union[int, str]] = None,
):

max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
is_local = not is_remote_filesystem(self._fs)
path_join = os.path.join if is_local else posixpath.join

if file_format == "arrow":
if max_shard_size is not None:
raise NotImplementedError(
"Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'."
)
else:
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)

if self.info.splits is not None:
split_info = self.info.splits[split_generator.name]
else:
split_info = split_generator.split_info

suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else ""
suffix = "-RRRRR-SSSSS-of-NNNNN"
fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
fpath = path_join(self._output_dir, fname)

generator = self._generate_examples(**split_generator.gen_kwargs)
# Default to using 16-way parallelism for preparation if the number of files is higher than 16.
if num_proc is None and len(_all_shard_kwargs(split_generator.gen_kwargs)) >= 16:
num_proc = 16

if _shard_number(split_generator.gen_kwargs) <= 1 and num_proc is not None:
logger.warning(
f"Setting num_proc from {num_proc} back to 1 to disable multiprocessing as dataset is not shardable"
)
num_proc = None

if num_proc is None or num_proc == 1:
result = self._prepare_split_single(
split_generator.gen_kwargs,
fpath=fpath,
file_format=file_format,
max_shard_size=max_shard_size,
split_info=split_info,
check_duplicate_keys=check_duplicate_keys,
)
# wrapping everything into lists for consistency with the multiprocessed code path
examples_per_rank, bytes_per_rank, features_per_rank, shards_per_rank = [[item] for item in result]
else:
args_per_shard = [
(
shard_kwargs,
fpath,
file_format,
max_shard_size,
split_info,
check_duplicate_keys,
)
for shard_kwargs in _all_shard_kwargs(split_generator.gen_kwargs)
]

examples_per_rank = [None] * len(args_per_shard)
bytes_per_rank = [None] * len(args_per_shard)
features_per_rank = [None] * len(args_per_shard)
shards_per_rank = [None] * len(args_per_shard)

with Pool(num_proc) as pool:
results = {
rank: pool.apply_async(self._prepare_split_single, args=args, kwds={"rank": rank})
for rank, args in enumerate(args_per_shard)
}
for index, async_result in results.items():
result = async_result.get()
(
examples_per_rank[index],
bytes_per_rank[index],
features_per_rank[index],
shards_per_rank[index],
) = result

assert None not in examples_per_rank, f"result list {examples_per_rank} still contains None"
# wrapping everything into lists for consistency with the multiprocessed code path

total_shards = sum(shards_per_rank)
total_num_examples = sum(examples_per_rank)
total_num_bytes = sum(bytes_per_rank)
features = features_per_rank[0]

# should rename everything at the end, scheme still TBD
def _rename_shard(shard_id_and_rank: Tuple[int]):
shard_id, rank = shard_id_and_rank
global_shard_id = sum(shards_per_rank[:rank]) + shard_id
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"),
fpath.replace("RRRRR-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
)

logger.debug(f"Renaming {total_shards} shards.")
shard_ids_and_ranks = [
(shard_id, rank) for rank, num_shards in enumerate(shards_per_rank) for shard_id in range(num_shards)
]
thread_map(_rename_shard, shard_ids_and_ranks, disable=True, max_workers=64)

split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes

if self.info.features is None:
self.info.features = features

def _prepare_split_single(
self, shard_kwargs, fpath, file_format, max_shard_size, split_info, check_duplicate_keys, rank=0
):

generator = self._generate_examples(**shard_kwargs)
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
total_num_examples, total_num_bytes = 0, 0

shard_id = 0
# TODO: embed the images/audio files inside parquet files.
writer = writer_class(
features=self.info.features,
path=fpath.replace("SSSSS", f"{shard_id:05d}"),
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"),
writer_batch_size=self._writer_batch_size,
hash_salt=split_info.name,
check_duplicates=check_duplicate_keys,
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
total_num_examples, total_num_bytes = 0, 0
try:
for key, record in logging.tqdm(
generator,
Expand Down Expand Up @@ -1403,21 +1488,7 @@ def _prepare_split(
total_num_examples += num_examples
total_num_bytes += num_bytes

if file_format == "parquet":

def _rename_shard(shard_id: int):
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}"),
fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"),
)

logger.debug(f"Renaming {num_shards} shards.")
thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64)

split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes
if self.info.features is None:
self.info.features = writer._features
return total_num_examples, total_num_bytes, writer._features, num_shards

def _download_and_prepare(self, dl_manager, verify_infos, **prepare_splits_kwargs):
super()._download_and_prepare(
Expand Down Expand Up @@ -1467,43 +1538,125 @@ def _prepare_split(
self,
split_generator: SplitGenerator,
file_format: str = "arrow",
num_proc=None,
max_shard_size: Optional[Union[str, int]] = None,
):

max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
is_local = not is_remote_filesystem(self._fs)
path_join = os.path.join if is_local else posixpath.join

if file_format == "arrow":
if max_shard_size is not None:
raise NotImplementedError(
"Writing sharded arrow files is not supported. Please don't use max_shard_size or use file_format='parquet'."
)
if self.info.splits is not None:
split_info = self.info.splits[split_generator.name]
else:
max_shard_size = convert_file_size_to_int(max_shard_size or config.MAX_SHARD_SIZE)
split_info = split_generator.split_info

suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else ""
suffix = "-RRRRR-SSSSS-of-NNNNN"
fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
fpath = path_join(self._output_dir, fname)

generator = self._generate_tables(**split_generator.gen_kwargs)
# Default to using 16-way parallelism for preparation if the number of files is higher than 16.
if num_proc is None and len(_all_shard_kwargs(split_generator.gen_kwargs)) >= 16:
num_proc = 16

if _shard_number(split_generator.gen_kwargs) <= 1 and num_proc is not None:
logger.warning(
f"Setting num_proc from {num_proc} back to 1 to disable multiprocessing as dataset is not shardable"
)
num_proc = None

if num_proc is None or num_proc == 1:
result = self._prepare_split_single(
split_generator.gen_kwargs,
fpath=fpath,
file_format=file_format,
max_shard_size=max_shard_size,
split_info=split_info,
)
# wrapping everything into lists for consistency with the multiprocessed code path
examples_per_rank, bytes_per_rank, features_per_rank, shards_per_rank = [[item] for item in result]
else:
args_per_shard = [
(
shard_kwargs,
fpath,
file_format,
max_shard_size,
split_info,
)
for shard_kwargs in _all_shard_kwargs(split_generator.gen_kwargs)
]

examples_per_rank = [None] * len(args_per_shard)
bytes_per_rank = [None] * len(args_per_shard)
features_per_rank = [None] * len(args_per_shard)
shards_per_rank = [None] * len(args_per_shard)

with Pool(num_proc) as pool:
results = {
rank: pool.apply_async(self._prepare_split_single, args=args, kwds={"rank": rank})
for rank, args in enumerate(args_per_shard)
}
for index, async_result in results.items():
result = async_result.get()
(
examples_per_rank[index],
bytes_per_rank[index],
features_per_rank[index],
shards_per_rank[index],
) = result

assert None not in examples_per_rank, f"result list {examples_per_rank} still contains None"
# wrapping everything into lists for consistency with the multiprocessed code path

total_shards = sum(shards_per_rank)
total_num_examples = sum(examples_per_rank)
total_num_bytes = sum(bytes_per_rank)
features = features_per_rank[0]

# should rename everything at the end, scheme still TBD
def _rename_shard(shard_id_and_rank: Tuple[int]):
shard_id, rank = shard_id_and_rank
global_shard_id = sum(shards_per_rank[:rank]) + shard_id
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"),
fpath.replace("RRRRR-SSSSS", f"{global_shard_id:05d}").replace("NNNNN", f"{total_shards:05d}"),
)
Copy link
Member

@lhoestq lhoestq Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this preserve the order of the original dataset ? If so that's amazing :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does! Or at least, this preserves the order of the shards in split_generator.gen_kwargs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually it doesn't after testing, but I can't quite figure out why :/


logger.debug(f"Renaming {total_shards} shards.")
shard_ids_and_ranks = [
(shard_id, rank) for rank, num_shards in enumerate(shards_per_rank) for shard_id in range(num_shards)
]
thread_map(_rename_shard, shard_ids_and_ranks, disable=True, max_workers=64)

split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes

if self.info.features is None:
self.info.features = features

def _prepare_split_single(self, shard_kwargs, fpath, file_format, max_shard_size, split_info, rank=0):

generator = self._generate_tables(**shard_kwargs)
writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter
embed_local_files = file_format == "parquet"
total_num_examples, total_num_bytes = 0, 0

shard_id = 0
# TODO: embed the images/audio files inside parquet files.
writer = writer_class(
features=self.info.features,
path=fpath.replace("SSSSS", f"{shard_id:05d}"),
path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("RRRRR", f"{rank:05d}"),
storage_options=self._fs.storage_options,
embed_local_files=embed_local_files,
)
total_num_examples, total_num_bytes = 0, 0
try:
for key, table in logging.tqdm(
generator,
unit=" tables",
leave=False,
disable=not logging.is_progress_bar_enabled(),
desc=f"Generating {split_info.name} split",
):
if max_shard_size is not None and writer._num_bytes > max_shard_size:
num_examples, num_bytes = writer.finalize()
Expand All @@ -1525,21 +1678,7 @@ def _prepare_split(
total_num_examples += num_examples
total_num_bytes += num_bytes

if file_format == "parquet":

def _rename_shard(shard_id: int):
self._rename(
fpath.replace("SSSSS", f"{shard_id:05d}"),
fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"),
)

logger.debug(f"Renaming {num_shards} shards.")
thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64)

split_generator.split_info.num_examples = total_num_examples
split_generator.split_info.num_bytes = total_num_bytes
if self.info.features is None:
self.info.features = writer._features
return total_num_examples, total_num_bytes, writer._features, num_shards

def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
return ExamplesIterable(
Expand Down
Loading