106 changes: 81 additions & 25 deletions src/datasets/builder.py
@@ -29,7 +29,8 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, Union
from unittest.mock import patch

import fsspec
import pyarrow as pa
@@ -85,6 +86,10 @@
from .utils.track import tracked_list


if TYPE_CHECKING:
from .load import DatasetModule


logger = logging.get_logger(__name__)


@@ -400,6 +405,10 @@ def __init__(
if is_remote_url(self._cache_downloaded_dir)
else os.path.expanduser(self._cache_downloaded_dir)
)

# In case there exists a legacy cache directory
self._legacy_relative_data_dir = None

self._cache_dir = self._build_cache_dir()
if not is_remote_url(self._cache_dir_root):
os.makedirs(self._cache_dir_root, exist_ok=True)
Expand Down Expand Up @@ -452,23 +461,71 @@ def __setstate__(self, d):
def manual_download_instructions(self) -> Optional[str]:
return None

def _has_legacy_cache(self) -> bool:
"""Check for the old cache directory template {cache_dir}/{namespace}___{builder_name}"""
def _check_legacy_cache(self) -> Optional[str]:
"""Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
if (
self.__module__.startswith("datasets.")
and not is_remote_url(self._cache_dir_root)
and self.config.name == "default"
):
from .packaged_modules import _PACKAGED_DATASETS_MODULES

namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
config_id = config_name + self.config_id[len(self.config.name) :]
hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
legacy_relative_data_dir = posixpath.join(
self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
config_id,
"0.0.0",
hash,
)
legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
if os.path.isdir(legacy_cache_dir):
return legacy_relative_data_dir

def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
"""Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
if self.__module__.startswith("datasets.") and not is_remote_url(self._cache_dir_root):
from .packaged_modules import _PACKAGED_DATASETS_MODULES
from .utils._dill import Pickler

def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
"""
Used to update the hash of packaged modules (which is used for creating unique cache directories) so that it
reflects the config parameters passed in the metadata from the README.
"""
params_to_exclude = {"config_name", "version", "description"}
params_to_add_to_hash = {
param: value
for param, value in sorted(config_parameters.items())
if param not in params_to_exclude
}
m = Hasher()
m.update(hash)
m.update(params_to_add_to_hash)
return m.hexdigest()

namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
legacy_config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
legacy_config_id = legacy_config_name + self.config_id[len(self.config.name) :]
legacy_cache_dir = os.path.join(
self._cache_dir_root,
self.name if namespace is None else f"{namespace}___{self.name}",
legacy_config_id,
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
if (
dataset_module.builder_configs_parameters.metadata_configs
and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
):
hash = update_hash_with_config_parameters(
hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name]
)
legacy_relative_data_dir = posixpath.join(
self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
config_id,
"0.0.0",
hash,
)
return os.path.isdir(legacy_cache_dir)
return False
legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
if os.path.isdir(legacy_cache_dir):
return legacy_relative_data_dir

@classmethod
def get_all_exported_dataset_infos(cls) -> DatasetInfosDict:
@@ -600,6 +657,14 @@ def builder_configs(cls) -> Dict[str, BuilderConfig]:
def cache_dir(self):
return self._cache_dir

def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"):
# Check for the legacy cache directory template (datasets<3.0.0)
self._legacy_relative_data_dir = (
self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
)
self._cache_dir = self._build_cache_dir()
self._output_dir = self._cache_dir

def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
"""Relative path of this dataset in cache_dir:
Will be:
@@ -608,21 +673,12 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
self.namespace___self.dataset_name/self.config.version/self.hash/
If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
"""

# Check for the legacy cache directory template (datasets<3.0.0)
if self._has_legacy_cache():
# use legacy names
dataset_name = self.name
config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
config_id = config_name + self.config_id[len(self.config.name) :]
else:
dataset_name = self.dataset_name
config_name = self.config.name
config_id = self.config_id
if self._legacy_relative_data_dir is not None and with_version and with_hash:
return self._legacy_relative_data_dir

namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
builder_data_dir = dataset_name if namespace is None else f"{namespace}___{dataset_name}"
builder_data_dir = posixpath.join(builder_data_dir, config_id)
builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}"
builder_data_dir = posixpath.join(builder_data_dir, self.config_id)
if with_version:
builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version))
if with_hash and self.hash and isinstance(self.hash, str):
@@ -1285,7 +1341,7 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem
"""
cache_dir = self._fs._strip_protocol(self._output_dir)
dataset_name = self.dataset_name
if self._has_legacy_cache():
if self._check_legacy_cache():
dataset_name = self.name
dataset_kwargs = ArrowReader(cache_dir, self.info).read(
name=dataset_name,
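Note on what the two helpers above probe for: _check_legacy_cache targets the datasets 2.13 layout {cache_dir}/{namespace}___{builder_name}/{config_id}/0.0.0/{hash}, while _check_legacy_cache2 targets the 2.14/2.15 layout in which the config id is derived from a hash of data_files. The following is a minimal sketch (not part of the diff) of how those relative paths are assembled; the repo id and hash values are hypothetical placeholders, not taken from this PR.

import posixpath

# Hypothetical values, for illustration only
cache_dir_root = "~/.cache/huggingface/datasets"
repo_id = "username/my_dataset"
namespace, dataset_name = repo_id.split("/")
module_hash = "0123456789abcdef"  # stands in for the _PACKAGED_DATASETS_MODULES hash

# 2.13-style relative dir checked by _check_legacy_cache:
# the config id starts from the repo id with "/" replaced by "--"
legacy_2_13 = posixpath.join(
    f"{namespace}___{dataset_name}",
    repo_id.replace("/", "--"),
    "0.0.0",
    module_hash,
)

# 2.14/2.15-style relative dir checked by _check_legacy_cache2:
# the config id is "{config_name}-{hash of data_files}"
legacy_2_15 = posixpath.join(
    f"{namespace}___{dataset_name}",
    "default-374bfde4f55442bc",  # placeholder data_files hash
    "0.0.0",
    module_hash,
)

print(posixpath.join(cache_dir_root, legacy_2_13))
print(posixpath.join(cache_dir_root, legacy_2_15))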
1 change: 1 addition & 0 deletions src/datasets/load.py
@@ -2239,6 +2239,7 @@ def load_dataset_builder(
**builder_kwargs,
**config_kwargs,
)
builder_instance._use_legacy_cache_dir_if_possible(dataset_module)

return builder_instance

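The only change on the loading side is to run the legacy-cache check right after the builder is instantiated, so callers see the adjusted cache_dir before any preparation step. A rough sketch of how the reuse is visible from user code; the repo id is hypothetical, and the printed path depends on whether a matching pre-3.0 cache directory actually exists on disk.

from datasets import load_dataset_builder

# Hypothetical dataset previously cached with datasets 2.14/2.15
builder = load_dataset_builder("username/my_dataset", cache_dir="~/.cache/huggingface/datasets")

# If a matching legacy directory is found, this points at the old layout
# ({namespace}___{dataset_name}/{config_id}/0.0.0/{hash}); otherwise at the new one
print(builder.cache_dir)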
3 changes: 3 additions & 0 deletions src/datasets/utils/_dill.py
@@ -25,6 +25,7 @@

class Pickler(dill.Pickler):
dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
_legacy_no_dict_keys_sorting = False

def save(self, obj, save_persistent_id=True):
obj_type = type(obj)
@@ -68,6 +69,8 @@ def save(self, obj, save_persistent_id=True):
dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)

def _batch_setitems(self, items):
if self._legacy_no_dict_keys_sorting:
return super()._batch_setitems(items)
# Ignore the order of keys in a dict
try:
# Faster, but fails for unorderable elements
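For context on the new _legacy_no_dict_keys_sorting flag: it skips the key sorting that _batch_setitems normally applies, so _check_legacy_cache2 can reproduce hashes computed before dict keys were sorted, while the sorted behaviour stays the default. A small sketch of the pattern used in builder.py above; it assumes Hasher is importable from datasets.fingerprint and Pickler from datasets.utils._dill, as the PR code does, and the data_files mapping is hypothetical.

from unittest.mock import patch

from datasets.fingerprint import Hasher
from datasets.utils._dill import Pickler

# Hypothetical data_files mapping; its key order is what the legacy hash was sensitive to
data_files = {"train": ["v2/train/*"], "test": ["v2/test/*"]}

current_hash = Hasher.hash({"data_files": data_files})

# Temporarily restore the older no-sorting behaviour to recompute the hash
# a 2.14/2.15-era cache directory would have been created with
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
    legacy_hash = Hasher.hash({"data_files": data_files})

# The two hashes can differ whenever insertion order differs from sorted key order
print(current_hash, legacy_hash)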
31 changes: 31 additions & 0 deletions tests/test_load.py
@@ -1657,3 +1657,34 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected):
else:
with pytest.raises(expected):
resolve_trust_remote_code(trust_remote_code, repo_id="dummy")


@pytest.mark.integration
def test_reload_old_cache_from_2_15(tmp_path: Path):
cache_dir = tmp_path / "test_reload_old_cache_from_2_15"
builder_cache_dir = (
cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d"
)
builder_cache_dir.mkdir(parents=True)
arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow"
dataset_info_path = builder_cache_dir / "dataset_info.json"
with dataset_info_path.open("w") as f:
f.write("{}")
arrow_path.touch()
builder = load_dataset_builder(
"polinaeterna/audiofolder_two_configs_in_metadata",
"v2",
data_files="v2/train/*",
cache_dir=cache_dir.as_posix(),
)
assert builder.cache_dir == builder_cache_dir.as_posix() # old cache from 2.15

builder = load_dataset_builder(
"polinaeterna/audiofolder_two_configs_in_metadata", "v2", cache_dir=cache_dir.as_posix()
)
assert (
builder.cache_dir
== (
cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata" / "v2" / "0.0.0" / str(builder.hash)
).as_posix()
) # new cache