
Commit f45bc6c

lhoestq and mariosasko authored
Fix concurrent script loading with force_redownload (#6718)
* fix concurrent script loading with force_redownload
* support various dynamic modules paths
* fix test
* fix tests
* disable on windows
* Update tests/test_load.py

Co-authored-by: Mario Šaško <[email protected]>
1 parent e52f4d0 commit f45bc6c

4 files changed (+73, -32 lines)
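For context, here is a minimal sketch (not part of the commit) of the race this fix targets, mirroring the new test added below: several worker processes load the same script-based dataset with download_mode="force_redownload", and without locking, one worker can delete and re-create the importable script directory while another is importing from it. The dataset name here is a placeholder.

from multiprocessing import Pool

from datasets import load_dataset


def load_it(_):
    # Each worker may wipe and re-copy the importable script directory;
    # before this fix, another worker could be importing the script from
    # that directory at the same moment.
    return load_dataset("user/my_script_dataset", download_mode="force_redownload")


if __name__ == "__main__":
    with Pool(processes=5) as pool:
        loaded = pool.map(load_it, range(5))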

src/datasets/load.py

Lines changed: 31 additions & 26 deletions
@@ -27,6 +27,7 @@
 import time
 import warnings
 from collections import Counter
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
@@ -70,7 +71,6 @@
 )
 from .splits import Split
 from .utils import _datasets_server
-from .utils._filelock import FileLock
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import (
     OfflineModeIsEnabled,
@@ -87,7 +87,7 @@
 from .utils.info_utils import VerificationMode, is_small_dataset
 from .utils.logging import get_logger
 from .utils.metadata import MetadataConfigs
-from .utils.py_utils import get_imports
+from .utils.py_utils import get_imports, lock_importable_file
 from .utils.version import Version


@@ -244,7 +244,10 @@ def __reduce__(self): # to make dynamically created class pickable, see _Initia
 def get_dataset_builder_class(
     dataset_module: "DatasetModule", dataset_name: Optional[str] = None
 ) -> Type[DatasetBuilder]:
-    builder_cls = import_main_class(dataset_module.module_path)
+    with lock_importable_file(
+        dataset_module.importable_file_path
+    ) if dataset_module.importable_file_path else nullcontext():
+        builder_cls = import_main_class(dataset_module.module_path)
     if dataset_module.builder_configs_parameters.builder_configs:
         dataset_name = dataset_name or dataset_module.builder_kwargs.get("dataset_name")
         if dataset_name is None:
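The conditional with statement above locks the importable script file only when the dataset module was actually loaded from one (importable_file_path defaults to None otherwise), falling back to a no-op context. A standalone sketch of that idiom, using only the standard library:

from contextlib import nullcontext
from threading import Lock
from typing import Optional


def double(value: int, lock: Optional[Lock] = None) -> int:
    # Acquire the lock only if one was provided; nullcontext() is a no-op stand-in.
    with lock if lock is not None else nullcontext():
        return value * 2


print(double(21))          # no lock needed
print(double(21, Lock()))  # same code path, now guarded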
@@ -375,17 +378,15 @@ def _copy_script_and_other_resources_in_importable_dir(
         download_mode (Optional[Union[DownloadMode, str]]): download mode

     Return:
-        importable_local_file: path to an importable module with importlib.import_module
+        importable_file: path to an importable module with importlib.import_module
     """
-
     # Define a directory with a unique name in our dataset or metric folder
     # path is: ./datasets|metrics/dataset|metric_name/hash_from_code/script.py
     # we use a hash as subdirectory_name to be able to have multiple versions of a dataset/metric processing file together
     importable_subdirectory = os.path.join(importable_directory_path, subdirectory_name)
-    importable_local_file = os.path.join(importable_subdirectory, name + ".py")
+    importable_file = os.path.join(importable_subdirectory, name + ".py")
     # Prevent parallel disk operations
-    lock_path = importable_directory_path + ".lock"
-    with FileLock(lock_path):
+    with lock_importable_file(importable_file):
         # Create main dataset/metrics folder if needed
         if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(importable_directory_path):
             shutil.rmtree(importable_directory_path)
@@ -406,13 +407,13 @@ def _copy_script_and_other_resources_in_importable_dir(
             pass

        # Copy dataset.py file in hash folder if needed
-        if not os.path.exists(importable_local_file):
-            shutil.copyfile(original_local_path, importable_local_file)
+        if not os.path.exists(importable_file):
+            shutil.copyfile(original_local_path, importable_file)
         # Record metadata associating original dataset path with local unique folder
         # Use os.path.splitext to split extension from importable_local_file
-        meta_path = os.path.splitext(importable_local_file)[0] + ".json"
+        meta_path = os.path.splitext(importable_file)[0] + ".json"
         if not os.path.exists(meta_path):
-            meta = {"original file path": original_local_path, "local file path": importable_local_file}
+            meta = {"original file path": original_local_path, "local file path": importable_file}
             # the filename is *.py in our case, so better rename to filename.json instead of filename.py.json
             with open(meta_path, "w", encoding="utf-8") as meta_file:
                 json.dump(meta, meta_file)
@@ -437,7 +438,7 @@ def _copy_script_and_other_resources_in_importable_dir(
                 original_path, destination_additional_path
             ):
                 shutil.copyfile(original_path, destination_additional_path)
-    return importable_local_file
+    return importable_file


 def _get_importable_file_path(
@@ -447,7 +448,7 @@
     name: str,
 ) -> str:
     importable_directory_path = os.path.join(dynamic_modules_path, module_namespace, name.replace("/", "--"))
-    return os.path.join(importable_directory_path, subdirectory_name, name + ".py")
+    return os.path.join(importable_directory_path, subdirectory_name, name.split("/")[-1] + ".py")


 def _create_importable_file(
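The name.split("/")[-1] change matters for namespaced repositories: the namespace is already encoded in the directory name via name.replace("/", "--"), so only the last path component should become the .py filename. A quick illustration with made-up values:

import os

# Hypothetical inputs, for illustration only.
dynamic_modules_path = "/tmp/modules"
name = "user/my_dataset"
subdirectory_name = "abc123"

importable_directory_path = os.path.join(dynamic_modules_path, "datasets", name.replace("/", "--"))

old_path = os.path.join(importable_directory_path, subdirectory_name, name + ".py")
new_path = os.path.join(importable_directory_path, subdirectory_name, name.split("/")[-1] + ".py")

print(old_path)  # /tmp/modules/datasets/user--my_dataset/abc123/user/my_dataset.py (extra directory level)
print(new_path)  # /tmp/modules/datasets/user--my_dataset/abc123/my_dataset.py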
@@ -692,6 +693,7 @@ class DatasetModule:
     builder_kwargs: dict
     builder_configs_parameters: BuilderConfigsParameters = field(default_factory=BuilderConfigsParameters)
     dataset_infos: Optional[DatasetInfosDict] = None
+    importable_file_path: Optional[str] = None


 @dataclass
@@ -983,7 +985,7 @@ def get_module(self) -> DatasetModule:
         # make the new module to be noticed by the import system
         importlib.invalidate_caches()
         builder_kwargs = {"base_path": str(Path(self.path).parent)}
-        return DatasetModule(module_path, hash, builder_kwargs)
+        return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)


 class LocalDatasetModuleFactoryWithoutScript(_DatasetModuleFactory):
@@ -1536,7 +1538,7 @@ def get_module(self) -> DatasetModule:
             "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
             "repo_id": self.name,
         }
-        return DatasetModule(module_path, hash, builder_kwargs)
+        return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)


 class CachedDatasetModuleFactory(_DatasetModuleFactory):
@@ -1582,21 +1584,24 @@ def _get_modification_time(module_hash):
             if not config.HF_DATASETS_OFFLINE:
                 warning_msg += ", or remotely on the Hugging Face Hub."
             logger.warning(warning_msg)
-            # make the new module to be noticed by the import system
-            module_path = ".".join(
-                [
-                    os.path.basename(dynamic_modules_path),
-                    "datasets",
-                    self.name.replace("/", "--"),
-                    hash,
-                    self.name.split("/")[-1],
-                ]
+            importable_file_path = _get_importable_file_path(
+                dynamic_modules_path=dynamic_modules_path,
+                module_namespace="datasets",
+                subdirectory_name=hash,
+                name=self.name,
+            )
+            module_path, hash = _load_importable_file(
+                dynamic_modules_path=dynamic_modules_path,
+                module_namespace="datasets",
+                subdirectory_name=hash,
+                name=self.name,
             )
+            # make the new module to be noticed by the import system
             importlib.invalidate_caches()
             builder_kwargs = {
                 "repo_id": self.name,
             }
-            return DatasetModule(module_path, hash, builder_kwargs)
+            return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
         cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
         cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
         cached_directory_paths = [
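To see how the two values passed to DatasetModule relate, here is a sketch with hypothetical values: the dotted module_path (built inline before this change, now produced via _load_importable_file) and the importable_file_path from _get_importable_file_path point at the same script, one as a module name and one as a file path. The cache location shown is only an example.

import os

# Hypothetical values, for illustration only.
dynamic_modules_path = "/home/user/.cache/huggingface/modules/datasets_modules"
name = "user/my_dataset"
hash = "abc123"

module_path = ".".join(
    [os.path.basename(dynamic_modules_path), "datasets", name.replace("/", "--"), hash, name.split("/")[-1]]
)
importable_file_path = os.path.join(
    dynamic_modules_path, "datasets", name.replace("/", "--"), hash, name.split("/")[-1] + ".py"
)

print(module_path)          # datasets_modules.datasets.user--my_dataset.abc123.my_dataset
print(importable_file_path) # .../datasets_modules/datasets/user--my_dataset/abc123/my_dataset.py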

src/datasets/streaming.py

Lines changed: 8 additions & 6 deletions
@@ -31,7 +31,7 @@
 )
 from .utils.logging import get_logger
 from .utils.patching import patch_submodule
-from .utils.py_utils import get_imports
+from .utils.py_utils import get_imports, lock_importable_file


 logger = get_logger(__name__)
@@ -120,11 +120,13 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
     extend_module_for_streaming(builder.__module__, download_config=download_config)
     # if needed, we also have to extend additional internal imports (like wmt14 -> wmt_utils)
     if not builder.__module__.startswith("datasets."): # check that it's not a packaged builder like csv
-        for imports in get_imports(inspect.getfile(builder.__class__)):
-            if imports[0] == "internal":
-                internal_import_name = imports[1]
-                internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
-                extend_module_for_streaming(internal_module_name, download_config=download_config)
+        importable_file = inspect.getfile(builder.__class__)
+        with lock_importable_file(importable_file):
+            for imports in get_imports(importable_file):
+                if imports[0] == "internal":
+                    internal_import_name = imports[1]
+                    internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
+                    extend_module_for_streaming(internal_module_name, download_config=download_config)

     # builders can inherit from other builders that might use streaming functionality
     # (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation)
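As a concrete (hypothetical) example of the internal-import handling above: a loading script that does `from .wmt_utils import ...` is reported by get_imports as an "internal" import, and the sibling module name is rebuilt from the builder's own dotted module path:

# Hypothetical module names, illustrating the computation in the loop above.
builder_module = "datasets_modules.datasets.wmt14.0123abc.wmt14"  # builder.__module__
internal_import_name = "wmt_utils"  # from a `from .wmt_utils import ...` line in the script

internal_module_name = ".".join(builder_module.split(".")[:-1] + [internal_import_name])
print(internal_module_name)  # datasets_modules.datasets.wmt14.0123abc.wmt_utils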

src/datasets/utils/py_utils.py

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,7 @@
 from contextlib import contextmanager
 from dataclasses import fields, is_dataclass
 from multiprocessing import Manager
+from pathlib import Path
 from queue import Empty
 from shutil import disk_usage
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
@@ -47,6 +48,7 @@
     dumps,
     pklregister,
 )
+from ._filelock import FileLock


 try: # pragma: no branch
@@ -537,6 +539,15 @@ def _convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
     return url_path, sub_directory


+def lock_importable_file(importable_local_file: str) -> FileLock:
+    # Check the directory with a unique name in our dataset folder
+    # path is: ./datasets/dataset_name/hash_from_code/script.py
+    # we use a hash as subdirectory_name to be able to have multiple versions of a dataset/metric processing file together
+    importable_directory_path = str(Path(importable_local_file).resolve().parent.parent)
+    lock_path = importable_directory_path + ".lock"
+    return FileLock(lock_path)
+
+
 def get_imports(file_path: str) -> Tuple[str, str, str, str]:
     """Find whether we should import or clone additional files for a given processing script.
     And list the import.
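lock_importable_file locks the directory two levels above the script, i.e. the per-dataset folder shared by every hash subdirectory, so concurrent copy/delete/import operations on any version of the script serialize on one lock file. A small sketch of the same derivation, assuming the filelock package for the lock itself (datasets ships its own FileLock in utils/_filelock.py):

from pathlib import Path
from tempfile import mkdtemp

from filelock import FileLock  # assumption: the `filelock` package is installed

# Hypothetical importable-file layout: <modules>/datasets/<dataset_name>/<hash>/<script>.py
modules_root = Path(mkdtemp())
importable_file = modules_root / "datasets" / "user--my_dataset" / "abc123" / "my_dataset.py"
importable_file.parent.mkdir(parents=True)

# Two levels above the script is the per-dataset directory, shared by all hash
# subdirectories, so every version is guarded by the same <dataset_dir>.lock file.
dataset_directory = importable_file.resolve().parent.parent
with FileLock(str(dataset_directory) + ".lock"):
    pass  # copy, delete, or import the script files for this dataset safely here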

tests/test_load.py

Lines changed: 23 additions & 0 deletions
@@ -54,6 +54,7 @@
     assert_arrow_memory_doesnt_increase,
     assert_arrow_memory_increases,
     offline,
+    require_not_windows,
     require_pil,
     require_sndfile,
     set_current_working_directory_to_temp_dir,
@@ -1674,6 +1675,28 @@ def test_load_dataset_distributed(tmp_path, csv_path):
     assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)


+def distributed_load_dataset_with_script(args):
+    data_name, tmp_dir, download_mode = args
+    dataset = load_dataset(data_name, cache_dir=tmp_dir, download_mode=download_mode)
+    return dataset
+
+
+@require_not_windows # windows doesn't support overwriting Arrow files from other processes
+@pytest.mark.parametrize("download_mode", [None, "force_redownload"])
+def test_load_dataset_distributed_with_script(tmp_path, download_mode):
+    # we need to check in the "force_redownload" case
+    # since in `_copy_script_and_other_resources_in_importable_dir()` we might delete the directory
+    # containing the .py file while the other processes use it
+    num_workers = 5
+    args = (SAMPLE_DATASET_IDENTIFIER, str(tmp_path), download_mode)
+    with Pool(processes=num_workers) as pool: # start num_workers processes
+        datasets = pool.map(distributed_load_dataset_with_script, [args] * num_workers)
+        assert len(datasets) == num_workers
+        assert all(len(dataset) == len(datasets[0]) > 0 for dataset in datasets)
+        assert len(datasets[0].cache_files) > 0
+        assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)
+
+
 def test_load_dataset_with_storage_options(mockfs):
     with mockfs.open("data.txt", "w") as f:
         f.write("Hello there\n")
