Skip to content

Commit c1dce72

Browse files
Support remote cache_dir (#4347)
* Support cache_dir when remote * Cast cache_dir to str * Force CI re-run * Fix cast cache_dir to str * Refactor data dirs preparation * Make explicit when only local cache dir is considered * Fix casting of cache_dir_root * Fix casting of cache_downloaded_dir * Use xjoin * Revert "Use xjoin" This reverts commit 43714db. * Add suggested comments
1 parent 7c8106d commit c1dce72

File tree

1 file changed

+66
-41
lines changed

1 file changed

+66
-41
lines changed

src/datasets/builder.py

Lines changed: 66 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import copy
2121
import inspect
2222
import os
23+
import posixpath
2324
import shutil
2425
import textwrap
2526
import urllib
@@ -286,25 +287,37 @@ def __init__(
286287
if features is not None:
287288
self.info.features = features
288289

289-
# prepare data dirs
290-
self._cache_dir_root = os.path.expanduser(cache_dir or config.HF_DATASETS_CACHE)
290+
# Prepare data dirs:
291+
# cache_dir can be a remote bucket on GCS or S3 (when using BeamBasedBuilder for distributed data processing)
292+
self._cache_dir_root = str(cache_dir or config.HF_DATASETS_CACHE)
293+
self._cache_dir_root = (
294+
self._cache_dir_root if is_remote_url(self._cache_dir_root) else os.path.expanduser(self._cache_dir_root)
295+
)
296+
path_join = posixpath.join if is_remote_url(self._cache_dir_root) else os.path.join
297+
self._cache_downloaded_dir = (
298+
path_join(self._cache_dir_root, config.DOWNLOADED_DATASETS_DIR)
299+
if cache_dir
300+
else str(config.DOWNLOADED_DATASETS_PATH)
301+
)
291302
self._cache_downloaded_dir = (
292-
os.path.join(cache_dir, config.DOWNLOADED_DATASETS_DIR) if cache_dir else config.DOWNLOADED_DATASETS_PATH
303+
self._cache_downloaded_dir
304+
if is_remote_url(self._cache_downloaded_dir)
305+
else os.path.expanduser(self._cache_downloaded_dir)
293306
)
294307
self._cache_dir = self._build_cache_dir()
295308
if not is_remote_url(self._cache_dir_root):
296309
os.makedirs(self._cache_dir_root, exist_ok=True)
297-
lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
298-
with FileLock(lock_path):
299-
if os.path.exists(self._cache_dir): # check if data exist
300-
if len(os.listdir(self._cache_dir)) > 0:
301-
logger.info("Overwrite dataset info from restored data version.")
302-
self.info = DatasetInfo.from_directory(self._cache_dir)
303-
else: # dir exists but no data, remove the empty dir as data aren't available anymore
304-
logger.warning(
305-
f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. "
306-
)
307-
os.rmdir(self._cache_dir)
310+
lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
311+
with FileLock(lock_path):
312+
if os.path.exists(self._cache_dir): # check if data exist
313+
if len(os.listdir(self._cache_dir)) > 0:
314+
logger.info("Overwrite dataset info from restored data version.")
315+
self.info = DatasetInfo.from_directory(self._cache_dir)
316+
else: # dir exists but no data, remove the empty dir as data aren't available anymore
317+
logger.warning(
318+
f"Old caching folder {self._cache_dir} for dataset {self.name} exists but not data were found. Removing it. "
319+
)
320+
os.rmdir(self._cache_dir)
308321

309322
# Set download manager
310323
self.dl_manager = None
@@ -439,7 +452,7 @@ def builder_configs(cls):
439452
def cache_dir(self):
440453
return self._cache_dir
441454

442-
def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
455+
def _relative_data_dir(self, with_version=True, with_hash=True, is_local=True) -> str:
443456
"""Relative path of this dataset in cache_dir:
444457
Will be:
445458
self.name/self.config.version/self.hash/
@@ -451,19 +464,26 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
451464
builder_data_dir = self.name if namespace is None else f"{namespace}___{self.name}"
452465
builder_config = self.config
453466
hash = self.hash
467+
path_join = os.path.join if is_local else posixpath.join
454468
if builder_config:
455469
# use the enriched name instead of the name to make it unique
456-
builder_data_dir = os.path.join(builder_data_dir, self.config_id)
470+
builder_data_dir = path_join(builder_data_dir, self.config_id)
457471
if with_version:
458-
builder_data_dir = os.path.join(builder_data_dir, str(self.config.version))
472+
builder_data_dir = path_join(builder_data_dir, str(self.config.version))
459473
if with_hash and hash and isinstance(hash, str):
460-
builder_data_dir = os.path.join(builder_data_dir, hash)
474+
builder_data_dir = path_join(builder_data_dir, hash)
461475
return builder_data_dir
462476

463477
def _build_cache_dir(self):
464478
"""Return the data directory for the current version."""
465-
builder_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=False))
466-
version_data_dir = os.path.join(self._cache_dir_root, self._relative_data_dir(with_version=True))
479+
is_local = not is_remote_url(self._cache_dir_root)
480+
path_join = os.path.join if is_local else posixpath.join
481+
builder_data_dir = path_join(
482+
self._cache_dir_root, self._relative_data_dir(with_version=False, is_local=is_local)
483+
)
484+
version_data_dir = path_join(
485+
self._cache_dir_root, self._relative_data_dir(with_version=True, is_local=is_local)
486+
)
467487

468488
def _other_versions_on_disk():
469489
"""Returns previous versions on disk."""
@@ -480,16 +500,17 @@ def _other_versions_on_disk():
480500
return version_dirnames
481501

482502
# Check and warn if other versions exist on disk
483-
version_dirs = _other_versions_on_disk()
484-
if version_dirs:
485-
other_version = version_dirs[0][0]
486-
if other_version != self.config.version:
487-
warn_msg = (
488-
f"Found a different version {str(other_version)} of dataset {self.name} in "
489-
f"cache_dir {self._cache_dir_root}. Using currently defined version "
490-
f"{str(self.config.version)}."
491-
)
492-
logger.warning(warn_msg)
503+
if not is_remote_url(builder_data_dir):
504+
version_dirs = _other_versions_on_disk()
505+
if version_dirs:
506+
other_version = version_dirs[0][0]
507+
if other_version != self.config.version:
508+
warn_msg = (
509+
f"Found a different version {str(other_version)} of dataset {self.name} in "
510+
f"cache_dir {self._cache_dir_root}. Using currently defined version "
511+
f"{str(self.config.version)}."
512+
)
513+
logger.warning(warn_msg)
493514

494515
return version_data_dir
495516

@@ -571,18 +592,22 @@ def download_and_prepare(
571592
self.dl_manager = dl_manager
572593

573594
# Prevent parallel disk operations
574-
lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
575-
with FileLock(lock_path):
576-
data_exists = os.path.exists(self._cache_dir)
577-
if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
578-
logger.warning(f"Reusing dataset {self.name} ({self._cache_dir})")
579-
# We need to update the info in case some splits were added in the meantime
580-
# for example when calling load_dataset from multiple workers.
581-
self.info = self._load_info()
582-
self.download_post_processing_resources(dl_manager)
583-
return
595+
is_local = not is_remote_url(self._cache_dir_root)
596+
if is_local:
597+
lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
598+
# File locking only with local paths; no file locking on GCS or S3
599+
with FileLock(lock_path) if is_local else contextlib.nullcontext():
600+
if is_local:
601+
data_exists = os.path.exists(self._cache_dir)
602+
if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
603+
logger.warning(f"Reusing dataset {self.name} ({self._cache_dir})")
604+
# We need to update the info in case some splits were added in the meantime
605+
# for example when calling load_dataset from multiple workers.
606+
self.info = self._load_info()
607+
self.download_post_processing_resources(dl_manager)
608+
return
584609
logger.info(f"Generating dataset {self.name} ({self._cache_dir})")
585-
if not is_remote_url(self._cache_dir_root): # if cache dir is local, check for available space
610+
if is_local: # if cache dir is local, check for available space
586611
if not has_sufficient_disk_space(self.info.size_in_bytes or 0, directory=self._cache_dir_root):
587612
raise OSError(
588613
f"Not enough disk space. Needed: {size_str(self.info.size_in_bytes or 0)} (download: {size_str(self.info.download_size or 0)}, generated: {size_str(self.info.dataset_size or 0)}, post-processed: {size_str(self.info.post_processing_size or 0)})"

0 commit comments

Comments (0)