diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 0f758258e82..3def9a98755 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -29,7 +29,8 @@
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, Union
+from unittest.mock import patch
 
 import fsspec
 import pyarrow as pa
@@ -85,6 +86,10 @@ from .utils.track import tracked_list
 
+if TYPE_CHECKING:
+    from .load import DatasetModule
+
+
 logger = logging.get_logger(__name__)
 
@@ -400,6 +405,10 @@ def __init__(
             if is_remote_url(self._cache_downloaded_dir)
             else os.path.expanduser(self._cache_downloaded_dir)
         )
+
+        # Relative path of the legacy cache directory, if one exists
+        self._legacy_relative_data_dir = None
+
         self._cache_dir = self._build_cache_dir()
         if not is_remote_url(self._cache_dir_root):
             os.makedirs(self._cache_dir_root, exist_ok=True)
@@ -452,23 +461,71 @@ def __setstate__(self, d):
     def manual_download_instructions(self) -> Optional[str]:
         return None
 
-    def _has_legacy_cache(self) -> bool:
-        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name}"""
+    def _check_legacy_cache(self) -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
         if (
             self.__module__.startswith("datasets.")
             and not is_remote_url(self._cache_dir_root)
             and self.config.name == "default"
         ):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+
+            namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
+            config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
+            config_id = config_name + self.config_id[len(self.config.name) :]
+            hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
+            legacy_relative_data_dir = posixpath.join(
+                self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
+                config_id,
+                "0.0.0",
+                hash,
+            )
+            legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
+            if os.path.isdir(legacy_cache_dir):
+                return legacy_relative_data_dir
+
+    def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
+        if self.__module__.startswith("datasets.") and not is_remote_url(self._cache_dir_root):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+            from .utils._dill import Pickler
+
+            def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
+                """
+                Update the hash of a packaged module to reflect the config parameters passed in the README
+                metadata. The resulting hash is used to create a unique cache directory for each configuration.
+ """ + params_to_exclude = {"config_name", "version", "description"} + params_to_add_to_hash = { + param: value + for param, value in sorted(config_parameters.items()) + if param not in params_to_exclude + } + m = Hasher() + m.update(hash) + m.update(params_to_add_to_hash) + return m.hexdigest() + namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None - legacy_config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name - legacy_config_id = legacy_config_name + self.config_id[len(self.config.name) :] - legacy_cache_dir = os.path.join( - self._cache_dir_root, - self.name if namespace is None else f"{namespace}___{self.name}", - legacy_config_id, + with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True): + config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files}) + hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1] + if ( + dataset_module.builder_configs_parameters.metadata_configs + and self.config.name in dataset_module.builder_configs_parameters.metadata_configs + ): + hash = update_hash_with_config_parameters( + hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name] + ) + legacy_relative_data_dir = posixpath.join( + self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}", + config_id, + "0.0.0", + hash, ) - return os.path.isdir(legacy_cache_dir) - return False + legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir) + if os.path.isdir(legacy_cache_dir): + return legacy_relative_data_dir @classmethod def get_all_exported_dataset_infos(cls) -> DatasetInfosDict: @@ -600,6 +657,14 @@ def builder_configs(cls) -> Dict[str, BuilderConfig]: def cache_dir(self): return self._cache_dir + def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"): + # Check for the legacy cache directory template (datasets<3.0.0) + self._legacy_relative_data_dir = ( + self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None + ) + self._cache_dir = self._build_cache_dir() + self._output_dir = self._cache_dir + def _relative_data_dir(self, with_version=True, with_hash=True) -> str: """Relative path of this dataset in cache_dir: Will be: @@ -608,21 +673,12 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str: self.namespace___self.dataset_name/self.config.version/self.hash/ If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped. 
""" - - # Check for the legacy cache directory template (datasets<3.0.0) - if self._has_legacy_cache(): - # use legacy names - dataset_name = self.name - config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name - config_id = config_name + self.config_id[len(self.config.name) :] - else: - dataset_name = self.dataset_name - config_name = self.config.name - config_id = self.config_id + if self._legacy_relative_data_dir is not None and with_version and with_hash: + return self._legacy_relative_data_dir namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None - builder_data_dir = dataset_name if namespace is None else f"{namespace}___{dataset_name}" - builder_data_dir = posixpath.join(builder_data_dir, config_id) + builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}" + builder_data_dir = posixpath.join(builder_data_dir, self.config_id) if with_version: builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version)) if with_hash and self.hash and isinstance(self.hash, str): @@ -1285,7 +1341,7 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem """ cache_dir = self._fs._strip_protocol(self._output_dir) dataset_name = self.dataset_name - if self._has_legacy_cache(): + if self._check_legacy_cache(): dataset_name = self.name dataset_kwargs = ArrowReader(cache_dir, self.info).read( name=dataset_name, diff --git a/src/datasets/load.py b/src/datasets/load.py index f953534447a..0f76b37ab4d 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -2239,6 +2239,7 @@ def load_dataset_builder( **builder_kwargs, **config_kwargs, ) + builder_instance._use_legacy_cache_dir_if_possible(dataset_module) return builder_instance diff --git a/src/datasets/utils/_dill.py b/src/datasets/utils/_dill.py index 39d0bb67ea5..cdbb1f41341 100644 --- a/src/datasets/utils/_dill.py +++ b/src/datasets/utils/_dill.py @@ -25,6 +25,7 @@ class Pickler(dill.Pickler): dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy()) + _legacy_no_dict_keys_sorting = False def save(self, obj, save_persistent_id=True): obj_type = type(obj) @@ -68,6 +69,8 @@ def save(self, obj, save_persistent_id=True): dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id) def _batch_setitems(self, items): + if self._legacy_no_dict_keys_sorting: + return super()._batch_setitems(items) # Ignore the order of keys in a dict try: # Faster, but fails for unorderable elements diff --git a/tests/test_load.py b/tests/test_load.py index acc9a2680a5..cf2bddd367b 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1657,3 +1657,34 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected): else: with pytest.raises(expected): resolve_trust_remote_code(trust_remote_code, repo_id="dummy") + + +@pytest.mark.integration +def test_reload_old_cache_from_2_15(tmp_path: Path): + cache_dir = tmp_path / "test_reload_old_cache_from_2_15" + builder_cache_dir = ( + cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d" + ) + builder_cache_dir.mkdir(parents=True) + arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow" + dataset_info_path = builder_cache_dir / "dataset_info.json" + with dataset_info_path.open("w") as f: + f.write("{}") + arrow_path.touch() + builder = load_dataset_builder( + "polinaeterna/audiofolder_two_configs_in_metadata", + "v2", + data_files="v2/train/*", + 
+        cache_dir=cache_dir.as_posix(),
+    )
+    assert builder.cache_dir == builder_cache_dir.as_posix()  # old cache from 2.15
+
+    builder = load_dataset_builder(
+        "polinaeterna/audiofolder_two_configs_in_metadata", "v2", cache_dir=cache_dir.as_posix()
+    )
+    assert (
+        builder.cache_dir
+        == (
+            cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata" / "v2" / "0.0.0" / str(builder.hash)
+        ).as_posix()
+    )  # new cache
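
For reviewers: the legacy-cache detection above boils down to rebuilding the pre-3.0 relative path `{namespace}___{dataset_name}/{config_id}/0.0.0/{hash}` and reusing it only if that directory already exists on disk. The standalone sketch below illustrates that check in isolation; `find_legacy_cache_dir` and its parameter names are hypothetical, for illustration only, and not part of the `datasets` API (the real code also derives `config_id` and the module hash via `Hasher` and `_PACKAGED_DATASETS_MODULES`, as shown in the patch).

```python
# Illustrative sketch only (hypothetical helper, not the datasets API):
# rebuild the pre-3.0 relative cache path and reuse it only if it exists.
import os
import posixpath
from typing import Optional


def find_legacy_cache_dir(
    cache_dir_root: str,
    dataset_name: str,
    config_id: str,
    module_hash: str,
    namespace: Optional[str] = None,
) -> Optional[str]:
    """Return the legacy relative data dir if it already exists, else None."""
    builder_dir = dataset_name if namespace is None else f"{namespace}___{dataset_name}"
    # Legacy template: {namespace}___{dataset_name}/{config_id}/0.0.0/{hash}
    legacy_relative_data_dir = posixpath.join(builder_dir, config_id, "0.0.0", module_hash)
    if os.path.isdir(posixpath.join(cache_dir_root, legacy_relative_data_dir)):
        return legacy_relative_data_dir
    return None


# Mirrors the 2.15-style directory created in the test above:
print(
    find_legacy_cache_dir(
        cache_dir_root=os.path.expanduser("~/.cache/huggingface/datasets"),
        dataset_name="audiofolder_two_configs_in_metadata",
        config_id="v2-374bfde4f55442bc",
        module_hash="7896925d64deea5d",
        namespace="polinaeterna",
    )
)
```

If such a directory is found, the builder keeps reading and writing at the old location, so caches created by 2.13-2.15 are not re-downloaded; otherwise `_build_cache_dir` falls back to the new `{dataset_name}/{config_id}/{version}/{hash}` layout, exactly as exercised by the two asserts in the test.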