|
30 | 30 | from typing import Dict, Mapping, Optional, Tuple, Union |
31 | 31 |
|
32 | 32 | import fsspec |
| 33 | +from tqdm.contrib.concurrent import thread_map |
33 | 34 |
|
34 | 35 | from . import config, utils |
35 | 36 | from .arrow_dataset import Dataset |
|
62 | 63 | from .utils.info_utils import get_size_checksum_dict, verify_checksums, verify_splits |
63 | 64 | from .utils.py_utils import ( |
64 | 65 | classproperty, |
| 66 | + convert_file_size_to_int, |
65 | 67 | has_sufficient_disk_space, |
66 | 68 | map_nested, |
67 | 69 | memoize, |
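
For context, `convert_file_size_to_int` (imported above from `.utils.py_utils`) turns a human-readable size string such as `"500MB"` into a byte count, so `max_shard_size` can be given either as an int or as a string. A minimal sketch of such a parser; the unit table and error handling below are assumptions for illustration, not the actual `py_utils` implementation:

```python
# Hedged sketch of a size-string parser in the spirit of
# convert_file_size_to_int; the supported units here are an assumption.
from typing import Union

def convert_size_to_int_sketch(size: Union[int, str]) -> int:
    if isinstance(size, int):
        return size
    units = {"KIB": 2**10, "MIB": 2**20, "GIB": 2**30,
             "KB": 10**3, "MB": 10**6, "GB": 10**9}
    upper = size.upper()
    for suffix, factor in units.items():
        if upper.endswith(suffix):
            return int(upper[: -len(suffix)]) * factor
    raise ValueError(f"size {size!r} is not in a valid format")

assert convert_size_to_int_sketch("500MB") == 500_000_000
```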
@@ -575,6 +577,14 @@ def get_imported_module_dir(cls): |
575 | 577 | """Return the path of the module of this class or subclass.""" |
576 | 578 | return os.path.dirname(inspect.getfile(inspect.getmodule(cls))) |
577 | 579 |
|
| 580 | + def _rename(self, src: str, dst: str): |
| 581 | + is_local = not is_remote_filesystem(self._fs) |
| 582 | + if is_local: |
| 583 | + # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory |
| 584 | + os.rename(self._fs._strip_protocol(src), self._fs._strip_protocol(dst)) |
| 585 | + else: |
| 586 | + self._fs.mv(src, dst, recursive=True) |
| 587 | + |
578 | 588 | def download_and_prepare( |
579 | 589 | self, |
580 | 590 | download_config: Optional[DownloadConfig] = None, |
@@ -672,11 +682,7 @@ def incomplete_dir(dirname): |
672 | 682 | yield tmp_dir |
673 | 683 | if self._fs.isdir(dirname): |
674 | 684 | self._fs.rm(dirname, recursive=True) |
675 | | - if is_local: |
676 | | - # LocalFileSystem.mv does copy + rm, it is more efficient to simply rename a local directory |
677 | | - os.rename(self._fs._strip_protocol(tmp_dir), self._fs._strip_protocol(dirname)) |
678 | | - else: |
679 | | - self._fs.mv(tmp_dir, dirname, recursive=True) |
| 685 | + self._rename(tmp_dir, dirname) |
680 | 686 | finally: |
681 | 687 | if self._fs.exists(tmp_dir): |
682 | 688 | self._fs.rm(tmp_dir, recursive=True) |
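
To make the motivation for the new `_rename` helper concrete: as the diff's own comment notes, `LocalFileSystem.mv` moves a directory by copying it and then deleting the source, while `os.rename` on the same mount is a single metadata operation. A small standalone illustration of the local branch (the temp paths are made up for the demo):

```python
import os
import tempfile

import fsspec

# Illustration of the local fast path in _rename: strip the protocol
# from the fsspec-style path and hand bare paths to os.rename, which
# renames the directory in place instead of copying it file by file.
fs = fsspec.filesystem("file")
base = tempfile.mkdtemp()
src = os.path.join(base, "dataset.incomplete")
dst = os.path.join(base, "dataset")
os.makedirs(os.path.join(src, "train"))

# _strip_protocol removes the "file://" prefix from URL-style paths.
os.rename(fs._strip_protocol("file://" + src), fs._strip_protocol("file://" + dst))
assert os.path.isdir(os.path.join(dst, "train"))
```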
@@ -1224,51 +1230,90 @@ def _generate_examples(self, **kwargs): |
1224 | 1230 | """ |
1225 | 1231 | raise NotImplementedError() |
1226 | 1232 |
|
1227 | | - def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None): |
| 1233 | + def _prepare_split(self, split_generator, check_duplicate_keys, file_format=None, max_shard_size=None): |
1228 | 1234 | is_local = not is_remote_filesystem(self._fs) |
1229 | 1235 | path_join = os.path.join if is_local else posixpath.join |
| 1236 | + file_format = file_format or "arrow" |
| 1237 | + |
| 1238 | + if max_shard_size is not None: |
| 1239 | + max_shard_size = convert_file_size_to_int(max_shard_size) |
| 1240 | + if file_format == "arrow": |
| 1241 | + raise NotImplementedError( |
| 1242 | +                    "Writing sharded arrow files is not supported. Please use file_format='parquet' or leave max_shard_size unset."
| 1243 | + ) |
1230 | 1244 |
|
1231 | 1245 | if self.info.splits is not None: |
1232 | 1246 | split_info = self.info.splits[split_generator.name] |
1233 | 1247 | else: |
1234 | 1248 | split_info = split_generator.split_info |
1235 | 1249 |
|
1236 | | - file_format = file_format or "arrow" |
1237 | | - suffix = "-00000-of-00001" if file_format == "parquet" else "" |
| 1250 | + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" |
1238 | 1251 | fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" |
1239 | 1252 | fpath = path_join(self._cache_dir, fname) |
1240 | 1253 |
|
1241 | 1254 | generator = self._generate_examples(**split_generator.gen_kwargs) |
1242 | 1255 |
|
1243 | 1256 | writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter |
1244 | | - with writer_class( |
| 1257 | + |
| 1258 | + shard_id = 0 |
| 1259 | + writer = writer_class( |
1245 | 1260 | features=self.info.features, |
1246 | | - path=fpath, |
| 1261 | + path=fpath.replace("SSSSS", f"{shard_id:05d}"), |
1247 | 1262 | writer_batch_size=self._writer_batch_size, |
1248 | 1263 | hash_salt=split_info.name, |
1249 | 1264 | check_duplicates=check_duplicate_keys, |
1250 | 1265 | storage_options=self._fs.storage_options, |
1251 | | - ) as writer: |
1252 | | - try: |
1253 | | - for key, record in logging.tqdm( |
1254 | | - generator, |
1255 | | - unit=" examples", |
1256 | | - total=split_info.num_examples, |
1257 | | - leave=False, |
1258 | | - disable=not logging.is_progress_bar_enabled(), |
1259 | | - desc=f"Generating {split_info.name} split", |
1260 | | - ): |
1261 | | - example = self.info.features.encode_example(record) |
1262 | | - writer.write(example, key) |
1263 | | - finally: |
1264 | | - num_examples, num_bytes = writer.finalize() |
1265 | | - |
1266 | | - split_generator.split_info.num_examples = num_examples |
1267 | | - split_generator.split_info.num_bytes = num_bytes |
| 1266 | + ) |
| 1267 | + total_num_examples, total_num_bytes = 0, 0 |
| 1268 | + try: |
| 1269 | + for key, record in logging.tqdm( |
| 1270 | + generator, |
| 1271 | + unit=" examples", |
| 1272 | + total=split_info.num_examples, |
| 1273 | + leave=False, |
| 1274 | + disable=not logging.is_progress_bar_enabled(), |
| 1275 | + desc=f"Generating {split_info.name} split", |
| 1276 | + ): |
| 1277 | + if max_shard_size is not None and writer._num_bytes > max_shard_size: |
| 1278 | + num_examples, num_bytes = writer.finalize() |
| 1279 | + total_num_examples += num_examples |
| 1280 | + total_num_bytes += num_bytes |
| 1281 | + shard_id += 1 |
| 1282 | + writer = writer_class( |
| 1283 | + features=writer._features, |
| 1284 | + path=fpath.replace("SSSSS", f"{shard_id:05d}"), |
| 1285 | + writer_batch_size=self._writer_batch_size, |
| 1286 | + hash_salt=split_info.name, |
| 1287 | + check_duplicates=check_duplicate_keys, |
| 1288 | + storage_options=self._fs.storage_options, |
| 1289 | + ) |
| 1290 | + example = self.info.features.encode_example(record) |
| 1291 | + writer.write(example, key) |
| 1292 | + finally: |
| 1293 | + num_shards = shard_id + 1 |
| 1294 | + num_examples, num_bytes = writer.finalize() |
| 1295 | + total_num_examples += num_examples |
| 1296 | + total_num_bytes += num_bytes |
1268 | 1297 |
|
1269 | | - def _download_and_prepare(self, dl_manager, verify_infos, file_format=None): |
| 1298 | + if file_format == "parquet": |
| 1299 | + |
| 1300 | + def _rename_shard(shard_id: int): |
| 1301 | + self._rename( |
| 1302 | + fpath.replace("SSSSS", f"{shard_id:05d}"), |
| 1303 | + fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), |
| 1304 | + ) |
| 1305 | + |
| 1306 | + logger.debug(f"Renaming {num_shards} shards.") |
| 1307 | + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) |
| 1308 | + |
| 1309 | + split_generator.split_info.num_examples = total_num_examples |
| 1310 | + split_generator.split_info.num_bytes = total_num_bytes |
| 1311 | + if self.info.features is None: |
| 1312 | + self.info.features = writer._features |
| 1313 | + |
| 1314 | + def _download_and_prepare(self, dl_manager, verify_infos, file_format=None, **kwargs): |
1270 | 1315 | super()._download_and_prepare( |
1271 | | - dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos |
| 1316 | + dl_manager, verify_infos, file_format=file_format, check_duplicate_keys=verify_infos, **kwargs |
1272 | 1317 | ) |
1273 | 1318 |
|
1274 | 1319 | def _get_examples_iterable_for_split(self, split_generator: SplitGenerator) -> ExamplesIterable: |
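
Reduced to a self-contained sketch, the sharding logic above is a two-pass naming scheme: shards are written with an `SSSSS` index placeholder while the total count is unknown, and once writing finishes the `NNNNN` placeholder is stamped with the final shard count. `TextShardWriter` below is a stand-in for `ArrowWriter`/`ParquetWriter`, not a `datasets` API; only the control flow mirrors the diff:

```python
import os

class TextShardWriter:
    """Stand-in for ArrowWriter/ParquetWriter: tracks bytes and examples."""

    def __init__(self, path: str):
        self._f = open(path, "w")
        self._num_bytes = 0  # chars written; close enough to bytes for ASCII
        self._num_examples = 0

    def write(self, line: str):
        self._num_bytes += self._f.write(line + "\n")
        self._num_examples += 1

    def finalize(self):
        self._f.close()
        return self._num_examples, self._num_bytes

def write_sharded(records, fpath, max_shard_size):
    # fpath contains the placeholders, e.g. "demo-train-SSSSS-of-NNNNN.txt"
    shard_id = 0
    writer = TextShardWriter(fpath.replace("SSSSS", f"{shard_id:05d}"))
    total_num_examples, total_num_bytes = 0, 0
    try:
        for record in records:
            # Roll over to a fresh shard once the byte budget is exceeded.
            if writer._num_bytes > max_shard_size:
                num_examples, num_bytes = writer.finalize()
                total_num_examples += num_examples
                total_num_bytes += num_bytes
                shard_id += 1
                writer = TextShardWriter(fpath.replace("SSSSS", f"{shard_id:05d}"))
            writer.write(record)
    finally:
        num_shards = shard_id + 1
        num_examples, num_bytes = writer.finalize()
        total_num_examples += num_examples
        total_num_bytes += num_bytes
    # Second pass: the shard count is only known now, so stamp it into
    # every filename (the diff parallelizes this step with thread_map).
    for sid in range(num_shards):
        src = fpath.replace("SSSSS", f"{sid:05d}")
        os.rename(src, src.replace("NNNNN", f"{num_shards:05d}"))
    return total_num_examples, total_num_bytes, num_shards
```

Writing against the `SSSSS` placeholder first avoids buffering shards just to count them; the cost is the extra rename pass, which the diff runs through `tqdm.contrib.concurrent.thread_map` so that renames, which are network calls on remote filesystems, happen concurrently.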
@@ -1310,26 +1355,70 @@ def _generate_tables(self, **kwargs): |
1310 | 1355 | """ |
1311 | 1356 | raise NotImplementedError() |
1312 | 1357 |
|
1313 | | - def _prepare_split(self, split_generator, file_format=None): |
| 1358 | + def _prepare_split(self, split_generator, file_format=None, max_shard_size=None): |
1314 | 1359 | is_local = not is_remote_filesystem(self._fs) |
1315 | 1360 | path_join = os.path.join if is_local else posixpath.join |
1316 | | - |
1317 | 1361 | file_format = file_format or "arrow" |
1318 | | - suffix = "-00000-of-00001" if file_format == "parquet" else "" |
| 1362 | + |
| 1363 | + if max_shard_size is not None: |
| 1364 | + if file_format == "arrow": |
| 1365 | + raise NotImplementedError( |
| 1366 | +                    "Writing sharded arrow files is not supported. Please use file_format='parquet' or leave max_shard_size unset."
| 1367 | + ) |
| 1368 | +            max_shard_size = convert_file_size_to_int(max_shard_size)
| 1369 | + |
| 1370 | + suffix = "-SSSSS-of-NNNNN" if file_format == "parquet" else "" |
1319 | 1371 | fname = f"{self.name}-{split_generator.name}{suffix}.{file_format}" |
1320 | 1372 | fpath = path_join(self._cache_dir, fname) |
1321 | 1373 |
|
1322 | 1374 | generator = self._generate_tables(**split_generator.gen_kwargs) |
| 1375 | + |
1323 | 1376 | writer_class = ParquetWriter if file_format == "parquet" else ArrowWriter |
1324 | | - with writer_class(features=self.info.features, path=fpath, storage_options=self._fs.storage_options) as writer: |
| 1377 | + |
| 1378 | + shard_id = 0 |
| 1379 | + writer = writer_class( |
| 1380 | + features=self.info.features, |
| 1381 | + path=fpath.replace("SSSSS", f"{shard_id:05d}"), |
| 1382 | + storage_options=self._fs.storage_options, |
| 1383 | + ) |
| 1384 | + total_num_examples, total_num_bytes = 0, 0 |
| 1385 | + try: |
1325 | 1386 | for key, table in logging.tqdm( |
1326 | | - generator, unit=" tables", leave=False, disable=(not logging.is_progress_bar_enabled()) |
| 1387 | + generator, |
| 1388 | + unit=" tables", |
| 1389 | + leave=False, |
| 1390 | + disable=not logging.is_progress_bar_enabled(), |
1327 | 1391 | ): |
| 1392 | + if max_shard_size is not None and writer._num_bytes > max_shard_size: |
| 1393 | + num_examples, num_bytes = writer.finalize() |
| 1394 | + total_num_examples += num_examples |
| 1395 | + total_num_bytes += num_bytes |
| 1396 | + shard_id += 1 |
| 1397 | + writer = writer_class( |
| 1398 | + features=writer._features, |
| 1399 | + path=fpath.replace("SSSSS", f"{shard_id:05d}"), |
| 1400 | + storage_options=self._fs.storage_options, |
| 1401 | + ) |
1328 | 1402 | writer.write_table(table) |
| 1403 | + finally: |
| 1404 | + num_shards = shard_id + 1 |
1329 | 1405 | num_examples, num_bytes = writer.finalize() |
| 1406 | + total_num_examples += num_examples |
| 1407 | + total_num_bytes += num_bytes |
| 1408 | + |
| 1409 | + if file_format == "parquet": |
| 1410 | + |
| 1411 | + def _rename_shard(shard_id: int): |
| 1412 | + self._rename( |
| 1413 | + fpath.replace("SSSSS", f"{shard_id:05d}"), |
| 1414 | + fpath.replace("SSSSS", f"{shard_id:05d}").replace("NNNNN", f"{num_shards:05d}"), |
| 1415 | + ) |
| 1416 | + |
| 1417 | + logger.debug(f"Renaming {num_shards} shards.") |
| 1418 | + thread_map(_rename_shard, range(num_shards), disable=True, max_workers=64) |
1330 | 1419 |
|
1331 | | - split_generator.split_info.num_examples = num_examples |
1332 | | - split_generator.split_info.num_bytes = num_bytes |
| 1420 | + split_generator.split_info.num_examples = total_num_examples |
| 1421 | + split_generator.split_info.num_bytes = total_num_bytes |
1333 | 1422 | if self.info.features is None: |
1334 | 1423 | self.info.features = writer._features |
1335 | 1424 |
|
|