diff --git a/src/datasets/load.py b/src/datasets/load.py
index 685e5d543c2..55c91f3e1df 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -27,6 +27,7 @@
 import time
 import warnings
 from collections import Counter
+from contextlib import nullcontext
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
@@ -70,7 +71,6 @@
 )
 from .splits import Split
 from .utils import _datasets_server
-from .utils._filelock import FileLock
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import (
     OfflineModeIsEnabled,
@@ -87,7 +87,7 @@
 from .utils.info_utils import VerificationMode, is_small_dataset
 from .utils.logging import get_logger
 from .utils.metadata import MetadataConfigs
-from .utils.py_utils import get_imports
+from .utils.py_utils import get_imports, lock_importable_file
 from .utils.version import Version
@@ -244,7 +244,10 @@ def __reduce__(self): # to make dynamically created class pickable, see _Initia
 def get_dataset_builder_class(
     dataset_module: "DatasetModule", dataset_name: Optional[str] = None
 ) -> Type[DatasetBuilder]:
-    builder_cls = import_main_class(dataset_module.module_path)
+    with lock_importable_file(
+        dataset_module.importable_file_path
+    ) if dataset_module.importable_file_path else nullcontext():
+        builder_cls = import_main_class(dataset_module.module_path)
     if dataset_module.builder_configs_parameters.builder_configs:
         dataset_name = dataset_name or dataset_module.builder_kwargs.get("dataset_name")
         if dataset_name is None:
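The conditional `with` above takes the inter-process lock only when the dataset module actually comes from an importable script; packaged builders such as `csv` or `parquet` have no `importable_file_path` and fall through to a no-op context. A minimal sketch of that idiom, using a hypothetical `maybe_lock` helper and the third-party `filelock` package rather than the library's internal wrapper:

```python
from contextlib import nullcontext
from typing import Optional

from filelock import FileLock  # stand-in for datasets' internal FileLock wrapper


def maybe_lock(script_path: Optional[str]):
    # Take a real lock only when there is a script on disk to protect;
    # otherwise return a no-op context manager.
    return FileLock(script_path + ".lock") if script_path else nullcontext()


importable_file_path: Optional[str] = None  # None for packaged builders like csv

with maybe_lock(importable_file_path):
    pass  # importing the builder class would happen here
```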
file path": importable_file} # the filename is *.py in our case, so better rename to filename.json instead of filename.py.json with open(meta_path, "w", encoding="utf-8") as meta_file: json.dump(meta, meta_file) @@ -437,7 +438,7 @@ def _copy_script_and_other_resources_in_importable_dir( original_path, destination_additional_path ): shutil.copyfile(original_path, destination_additional_path) - return importable_local_file + return importable_file def _get_importable_file_path( @@ -447,7 +448,7 @@ def _get_importable_file_path( name: str, ) -> str: importable_directory_path = os.path.join(dynamic_modules_path, module_namespace, name.replace("/", "--")) - return os.path.join(importable_directory_path, subdirectory_name, name + ".py") + return os.path.join(importable_directory_path, subdirectory_name, name.split("/")[-1] + ".py") def _create_importable_file( @@ -692,6 +693,7 @@ class DatasetModule: builder_kwargs: dict builder_configs_parameters: BuilderConfigsParameters = field(default_factory=BuilderConfigsParameters) dataset_infos: Optional[DatasetInfosDict] = None + importable_file_path: Optional[str] = None @dataclass @@ -983,7 +985,7 @@ def get_module(self) -> DatasetModule: # make the new module to be noticed by the import system importlib.invalidate_caches() builder_kwargs = {"base_path": str(Path(self.path).parent)} - return DatasetModule(module_path, hash, builder_kwargs) + return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path) class LocalDatasetModuleFactoryWithoutScript(_DatasetModuleFactory): @@ -1536,7 +1538,7 @@ def get_module(self) -> DatasetModule: "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), "repo_id": self.name, } - return DatasetModule(module_path, hash, builder_kwargs) + return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path) class CachedDatasetModuleFactory(_DatasetModuleFactory): @@ -1582,21 +1584,24 @@ def _get_modification_time(module_hash): if not config.HF_DATASETS_OFFLINE: warning_msg += ", or remotely on the Hugging Face Hub." 
             logger.warning(warning_msg)
-            # make the new module to be noticed by the import system
-            module_path = ".".join(
-                [
-                    os.path.basename(dynamic_modules_path),
-                    "datasets",
-                    self.name.replace("/", "--"),
-                    hash,
-                    self.name.split("/")[-1],
-                ]
+            importable_file_path = _get_importable_file_path(
+                dynamic_modules_path=dynamic_modules_path,
+                module_namespace="datasets",
+                subdirectory_name=hash,
+                name=self.name,
+            )
+            module_path, hash = _load_importable_file(
+                dynamic_modules_path=dynamic_modules_path,
+                module_namespace="datasets",
+                subdirectory_name=hash,
+                name=self.name,
             )
+            # make the new module to be noticed by the import system
             importlib.invalidate_caches()
             builder_kwargs = {
                 "repo_id": self.name,
             }
-            return DatasetModule(module_path, hash, builder_kwargs)
+            return DatasetModule(module_path, hash, builder_kwargs, importable_file_path=importable_file_path)
         cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
         cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
         cached_directory_paths = [
diff --git a/src/datasets/streaming.py b/src/datasets/streaming.py
index d9e7e185a95..053375bfb8c 100644
--- a/src/datasets/streaming.py
+++ b/src/datasets/streaming.py
@@ -31,7 +31,7 @@
 )
 from .utils.logging import get_logger
 from .utils.patching import patch_submodule
-from .utils.py_utils import get_imports
+from .utils.py_utils import get_imports, lock_importable_file


 logger = get_logger(__name__)
@@ -120,11 +120,13 @@ def extend_dataset_builder_for_streaming(builder: "DatasetBuilder"):
     extend_module_for_streaming(builder.__module__, download_config=download_config)
     # if needed, we also have to extend additional internal imports (like wmt14 -> wmt_utils)
     if not builder.__module__.startswith("datasets."):  # check that it's not a packaged builder like csv
-        for imports in get_imports(inspect.getfile(builder.__class__)):
-            if imports[0] == "internal":
-                internal_import_name = imports[1]
-                internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
-                extend_module_for_streaming(internal_module_name, download_config=download_config)
+        importable_file = inspect.getfile(builder.__class__)
+        with lock_importable_file(importable_file):
+            for imports in get_imports(importable_file):
+                if imports[0] == "internal":
+                    internal_import_name = imports[1]
+                    internal_module_name = ".".join(builder.__module__.split(".")[:-1] + [internal_import_name])
+                    extend_module_for_streaming(internal_module_name, download_config=download_config)

     # builders can inherit from other builders that might use streaming functionality
     # (for example, ImageFolder and AudioFolder inherit from FolderBuilder which implements examples generation)
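The `lock_importable_file` helper added to `py_utils.py` below derives the lock path from the grandparent directory of the script, so it resolves to the same `<importable_directory_path>.lock` file that `_copy_script_and_other_resources_in_importable_dir` previously built by hand. A small sketch of that derivation, with a made-up cache path:

```python
from pathlib import Path


def lock_path_for(importable_file: str) -> str:
    # The script lives in <...>/datasets/<name>/<hash>/<name>.py;
    # the lock sits next to the per-dataset directory, two levels up.
    return str(Path(importable_file).resolve().parent.parent) + ".lock"


example = "/home/user/.cache/huggingface/modules/datasets_modules/datasets/my_dataset/0123abcd/my_dataset.py"
print(lock_path_for(example))
# /home/user/.cache/huggingface/modules/datasets_modules/datasets/my_dataset.lock
```

Because every process loading any version of the same script derives the same path, readers and writers contend on one lock per dataset name.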
diff --git a/src/datasets/utils/py_utils.py b/src/datasets/utils/py_utils.py
index db63d5058c1..779d7dabc24 100644
--- a/src/datasets/utils/py_utils.py
+++ b/src/datasets/utils/py_utils.py
@@ -27,6 +27,7 @@
 from contextlib import contextmanager
 from dataclasses import fields, is_dataclass
 from multiprocessing import Manager
+from pathlib import Path
 from queue import Empty
 from shutil import disk_usage
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
@@ -47,6 +48,7 @@
     dumps,
     pklregister,
 )
+from ._filelock import FileLock


 try:  # pragma: no branch
@@ -537,6 +539,15 @@ def _convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
     return url_path, sub_directory


+def lock_importable_file(importable_local_file: str) -> FileLock:
+    # Check the directory with a unique name in our dataset folder
+    # path is: ./datasets/dataset_name/hash_from_code/script.py
+    # we use a hash as subdirectory_name to be able to have multiple versions of a dataset/metric processing file together
+    importable_directory_path = str(Path(importable_local_file).resolve().parent.parent)
+    lock_path = importable_directory_path + ".lock"
+    return FileLock(lock_path)
+
+
 def get_imports(file_path: str) -> Tuple[str, str, str, str]:
     """Find whether we should import or clone additional files for a given processing script.
         And list the import.
diff --git a/tests/test_load.py b/tests/test_load.py
index bd81f018403..9261ed7db2f 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -54,6 +54,7 @@
     assert_arrow_memory_doesnt_increase,
     assert_arrow_memory_increases,
     offline,
+    require_not_windows,
     require_pil,
     require_sndfile,
     set_current_working_directory_to_temp_dir,
@@ -1674,6 +1675,28 @@ def test_load_dataset_distributed(tmp_path, csv_path):
     assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)


+def distributed_load_dataset_with_script(args):
+    data_name, tmp_dir, download_mode = args
+    dataset = load_dataset(data_name, cache_dir=tmp_dir, download_mode=download_mode)
+    return dataset
+
+
+@require_not_windows  # windows doesn't support overwriting Arrow files from other processes
+@pytest.mark.parametrize("download_mode", [None, "force_redownload"])
+def test_load_dataset_distributed_with_script(tmp_path, download_mode):
+    # we need to check in the "force_redownload" case
+    # since in `_copy_script_and_other_resources_in_importable_dir()` we might delete the directory
+    # containing the .py file while the other processes use it
+    num_workers = 5
+    args = (SAMPLE_DATASET_IDENTIFIER, str(tmp_path), download_mode)
+    with Pool(processes=num_workers) as pool:  # start num_workers processes
+        datasets = pool.map(distributed_load_dataset_with_script, [args] * num_workers)
+        assert len(datasets) == num_workers
+        assert all(len(dataset) == len(datasets[0]) > 0 for dataset in datasets)
+        assert len(datasets[0].cache_files) > 0
+        assert all(dataset.cache_files == datasets[0].cache_files for dataset in datasets)
+
+
 def test_load_dataset_with_storage_options(mockfs):
     with mockfs.open("data.txt", "w") as f:
         f.write("Hello there\n")
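For context on what the new test exercises: with `download_mode="force_redownload"`, one worker can delete and recreate the importable directory while another worker is still reading the script from it, and serializing both sides on the same lock removes the window where the file is missing. A stripped-down, standalone sketch of that pattern (it assumes the `filelock` package and a made-up directory layout, and is not the library's actual code):

```python
import os
import shutil
import tempfile

from filelock import FileLock

root = tempfile.mkdtemp()
dataset_dir = os.path.join(root, "my_dataset")  # hypothetical per-dataset directory
script_file = os.path.join(dataset_dir, "0123abcd", "my_dataset.py")
lock = FileLock(dataset_dir + ".lock")  # one lock per dataset name


def force_redownload() -> None:
    # Writer: wipe and recreate the importable directory under the lock.
    with lock:
        shutil.rmtree(dataset_dir, ignore_errors=True)
        os.makedirs(os.path.dirname(script_file), exist_ok=True)
        with open(script_file, "w") as f:
            f.write("ANSWER = 42\n")


def read_script() -> str:
    # Reader: while the lock is held, no writer can delete the file mid-read.
    with lock:
        with open(script_file) as f:
            return f.read()


force_redownload()
print(read_script())  # ANSWER = 42
```

In the test itself, each worker process builds its own lock object on the same `.lock` path, which is how `filelock` coordinates access across processes.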