Skip to content

Commit 4757930

Browse files
committed
shard parquet in download_and_prepare
1 parent 713f83c commit 4757930

File tree

2 files changed

+160
-36
lines changed

2 files changed

+160
-36
lines changed

src/datasets/builder.py

Lines changed: 125 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from typing import Dict, Mapping, Optional, Tuple, Union
3131

3232
import fsspec
33+
from tqdm.contrib.concurrent import thread_map
3334

3435
from . import config, utils
3536
from .arrow_dataset import Dataset
@@ -62,6 +63,7 @@
6263
from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits
6364
from .utils.py_utils import (
6465
classproperty,
66+
convert_file_size_to_int,
6567
has_sufficient_disk_space,
6668
map_nested,
6769
memoize,
@@ -575,6 +577,14 @@ def get_imported_module_dir(cls):
575577
"""Return the path of the module of this class or subclass."""
576578
return os.path.dirname(inspect.getfile(inspect.getmodule(cls)))
577579

580+
def _rename(self, src: str, dst: str):
    """Move ``src`` to ``dst`` on the builder's filesystem.

    On a local filesystem a plain ``os.rename`` is used, since
    ``LocalFileSystem.mv`` is implemented as copy + remove and is much
    slower for large directories.
    """
    if is_remote_filesystem(self._fs):
        self._fs.mv(src, dst, recursive=True)
    else:
        # LocalFileSystem.mv does copy + rm; renaming in place is cheaper.
        os.rename(self._fs._strip_protocol(src), self._fs._strip_protocol(dst))
578588
def download_and_prepare(
579589
self,
580590
download_config: Optional[DownloadConfig] = None,
@@ -672,11 +682,7 @@ def incomplete_dir(dirname):
672682
yield tmp_dir
673683
if self._fs.isdir(dirname):
674684
self._fs.rm(dirname, recursive=True)
675-
if is_local:
676-
# LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory
677-
os.rename(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname))
678-
else:
679-
self._fs.mv(tmp_dir, dirname, recursive=True)
685+
self._rename(tmp_dir, dirname)
680686
finally:
681687
if self._fs.exists(tmp_dir):
682688
self._fs.rm(tmp_dir, recursive=True)
@@ -1224,51 +1230,90 @@ def _generate_examples(self, **kwargs):
12241230
"""
12251231
raise NotImplementedError()
12261232

1227-
def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None):
1233+
def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None, max_shard_size=None):
    """Generate the examples of one split and write them to the cache dir.

    Args:
        split_generator: ``SplitGenerator`` whose ``gen_kwargs`` are forwarded to
            ``_generate_examples``; its ``split_info`` is updated in place with the
            written example/byte counts.
        check_duplicate_keys: whether the writer should reject duplicate example keys.
        file_format: ``"arrow"`` (default) or ``"parquet"``.
        max_shard_size: optional maximum shard size (int in bytes, or a string such
            as ``"500MB"``). Only supported for the parquet format: when a shard
            grows past this size, the current writer is finalized and a new shard
            file is started.

    Raises:
        NotImplementedError: if ``max_shard_size`` is given with the arrow format.
    """
    is_local = not is_remote_filesystem(self._fs)
    path_join = os.path.join if is_local else posixpath.join
    file_format = file_format or "arrow"

    if max_shard_size is not None:
        max_shard_size = convert_file_size_to_int(max_shard_size)
        if file_format == "arrow":
            raise NotImplementedError(
                "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet."
            )

    if self.info.splits is not None:
        split_info = self.info.splits[split_generator.name]
    else:
        split_info = split_generator.split_info

    # Shards are first written with "SSSSS" (shard index) and a literal "NNNNN"
    # placeholder for the total, because the final shard count is only known
    # once generation finishes; the files are renamed afterwards.
    suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else ""
    fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
    fpath = path_join(self._cache_dir, fname)

    generator = self._generate_examples(**split_generator.gen_kwargs)

    writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter

    shard_id = 0
    writer = writer_class(
        features=self.info.features,
        path=fpath.replace("SSSSS", f"{shard_id:05d}"),
        writer_batch_size=self._writer_batch_size,
        hash_salt=split_info.name,
        check_duplicates=check_duplicate_keys,
        storage_options=self._fs.storage_options,
    )
    total_num_examples, total_num_bytes = 0, 0
    try:
        for key, record in logging.tqdm(
            generator,
            unit=" examples",
            total=split_info.num_examples,
            leave=False,
            disable=not logging.is_progress_bar_enabled(),
            desc=f"Generating {split_info.name} split",
        ):
            # Roll over to a new shard once the current writer exceeds the limit.
            # NOTE(review): writer._num_bytes presumably tracks bytes written so
            # far — confirm against the ArrowWriter implementation.
            if max_shard_size is not None and writer._num_bytes > max_shard_size:
                num_examples, num_bytes = writer.finalize()
                total_num_examples += num_examples
                total_num_bytes += num_bytes
                shard_id += 1
                writer = writer_class(
                    features=writer._features,
                    path=fpath.replace("SSSSS", f"{shard_id:05d}"),
                    writer_batch_size=self._writer_batch_size,
                    hash_salt=split_info.name,
                    check_duplicates=check_duplicate_keys,
                    storage_options=self._fs.storage_options,
                )
            example = self.info.features.encode_example(record)
            writer.write(example, key)
    finally:
        # Always finalize the last writer so buffered data and the counters are
        # flushed even if generation raised part-way through.
        num_shards = shard_id + 1
        num_examples, num_bytes = writer.finalize()
        total_num_examples += num_examples
        total_num_bytes += num_bytes

    if file_format == "parquet":
        # Now that the shard count is known, substitute NNNNN in every shard name.
        def _rename_shard(shard_id: int):
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}"),
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"),
            )

        logger.debug(f"Renaming {num_shards} shards.")
        thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64)

    split_generator.split_info.num_examples = total_num_examples
    split_generator.split_info.num_bytes = total_num_bytes
    if self.info.features is None:
        self.info.features = writer._features
1313+
1314+
def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **kwargs):
    """Run the base download-and-prepare, enabling duplicate-key checks
    whenever info verification is requested."""
    super()._download_and_prepare(
        dl_manager,
        verify_infos,
        file_format=file_format,
        check_duplicate_keys=verify_infos,
        **kwargs,
    )
12731318

12741319
def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable:
@@ -1310,26 +1355,70 @@ def _generate_tables(self, **kwargs):
13101355
"""
13111356
raise NotImplementedError()
13121357

1313-
def _prepare_split(self, split_generator, file_format=None):
1358+
def _prepare_split(self, split_generator, file_format=None, max_shard_size=None):
    """Generate the tables of one split and write them to the cache dir.

    Args:
        split_generator: ``SplitGenerator`` whose ``gen_kwargs`` are forwarded to
            ``_generate_tables``; its ``split_info`` is updated in place with the
            written example/byte counts.
        file_format: ``"arrow"`` (default) or ``"parquet"``.
        max_shard_size: optional maximum shard size (int in bytes, or a string such
            as ``"500MB"``). Only supported for the parquet format: when a shard
            grows past this size, the current writer is finalized and a new shard
            file is started.

    Raises:
        NotImplementedError: if ``max_shard_size`` is given with the arrow format.
    """
    is_local = not is_remote_filesystem(self._fs)
    path_join = os.path.join if is_local else posixpath.join
    file_format = file_format or "arrow"

    # Fixed: previously this branch did `convert_file_size_to_int(max_shard_size or "500MB")`,
    # where the `or "500MB"` fallback was dead code (we are inside `is not None`)
    # except for the falsy value 0, which was silently turned into 500MB.
    # Convert first, then reject arrow — consistent with the generator-based builder.
    if max_shard_size is not None:
        max_shard_size = convert_file_size_to_int(max_shard_size)
        if file_format == "arrow":
            raise NotImplementedError(
                "Writing sharded arrow files is not supported. Please don't use max_shard_size or use parquet."
            )

    # Shards are first written with "SSSSS" (shard index) and a literal "NNNNN"
    # placeholder for the total, because the final shard count is only known
    # once generation finishes; the files are renamed afterwards.
    suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else ""
    fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}"
    fpath = path_join(self._cache_dir, fname)

    generator = self._generate_tables(**split_generator.gen_kwargs)

    writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter

    shard_id = 0
    writer = writer_class(
        features=self.info.features,
        path=fpath.replace("SSSSS", f"{shard_id:05d}"),
        storage_options=self._fs.storage_options,
    )
    total_num_examples, total_num_bytes = 0, 0
    try:
        for key, table in logging.tqdm(
            generator,
            unit=" tables",
            leave=False,
            disable=not logging.is_progress_bar_enabled(),
        ):
            # Roll over to a new shard once the current writer exceeds the limit.
            if max_shard_size is not None and writer._num_bytes > max_shard_size:
                num_examples, num_bytes = writer.finalize()
                total_num_examples += num_examples
                total_num_bytes += num_bytes
                shard_id += 1
                writer = writer_class(
                    features=writer._features,
                    path=fpath.replace("SSSSS", f"{shard_id:05d}"),
                    storage_options=self._fs.storage_options,
                )
            writer.write_table(table)
    finally:
        # Always finalize the last writer so buffered data and the counters are
        # flushed even if generation raised part-way through.
        num_shards = shard_id + 1
        num_examples, num_bytes = writer.finalize()
        total_num_examples += num_examples
        total_num_bytes += num_bytes

    if file_format == "parquet":
        # Now that the shard count is known, substitute NNNNN in every shard name.
        def _rename_shard(shard_id: int):
            self._rename(
                fpath.replace("SSSSS", f"{shard_id:05d}"),
                fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"),
            )

        logger.debug(f"Renaming {num_shards} shards.")
        thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64)

    split_generator.split_info.num_examples = total_num_examples
    split_generator.split_info.num_bytes = total_num_bytes
    if self.info.features is None:
        self.info.features = writer._features
13351424

tests/test_builder.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,24 @@ def test_generator_based_builder_download_and_prepare_as_parquet(tmp_path):
952952
assert pq.ParquetFile(parquet_path) is not None
953953

954954

955+
def test_generator_based_builder_download_and_prepare_as_sharded_parquet(tmp_path):
    """Preparing as parquet with a tiny max_shard_size yields one shard per writer batch."""
    writer_batch_size = 25
    builder = DummyGeneratorBasedBuilder(cache_dir=tmp_path, writer_batch_size=writer_batch_size)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)  # one batch per shard
    expected_num_shards = 100 // writer_batch_size
    # Fixed: `assert x, 100` only checked truthiness (100 was the assert message);
    # the intended equality check was never performed.
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet"
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet")
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
971+
972+
955973
def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
956974
builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
957975
builder.download_and_prepare(file_format="parquet")
@@ -963,6 +981,23 @@ def test_arrow_based_builder_download_and_prepare_as_parquet(tmp_path):
963981
assert pq.ParquetFile(parquet_path) is not None
964982

965983

984+
def test_arrow_based_builder_download_and_prepare_as_sharded_parquet(tmp_path):
    """Preparing as parquet with a tiny max_shard_size yields one shard per generated table."""
    builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
    builder.download_and_prepare(file_format="parquet", max_shard_size=1)  # one table per shard
    expected_num_shards = 10
    # Fixed: `assert x, 100` only checked truthiness (100 was the assert message);
    # the intended equality check was never performed.
    assert builder.info.splits["train"].num_examples == 100
    parquet_path = os.path.join(
        tmp_path, builder.name, "default", "0.0.0", f"{builder.name}-train-00000-of-{expected_num_shards:05d}.parquet"
    )
    assert os.path.exists(parquet_path)
    parquet_files = [
        pq.ParquetFile(parquet_path)
        for parquet_path in Path(tmp_path).rglob(f"{builder.name}-train-*-of-{expected_num_shards:05d}.parquet")
    ]
    assert len(parquet_files) == expected_num_shards
    assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
assert sum(parquet_file.metadata.num_rows for parquet_file in parquet_files) == 100
999+
1000+
9661001
def test_beam_based_builder_download_and_prepare_as_parquet(tmp_path):
9671002
builder = DummyBeamBasedBuilder(cache_dir=tmp_path, beam_runner="DirectRunner")
9681003
builder.download_and_prepare(file_format="parquet")

0 commit comments

Comments (0)