Commit 2afbf78

Cache backward compatibility with 2.15.0 (#6514)
* cache backward compatibility
* fix win
* debug run action for win
* fix

1 parent e1b82ea commit 2afbf78

4 files changed (+116, -25 lines)

src/datasets/builder.py

Lines changed: 81 additions & 25 deletions
@@ -29,7 +29,8 @@
 from dataclasses import dataclass
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, Union
+from unittest.mock import patch

 import fsspec
 import pyarrow as pa
@@ -85,6 +86,10 @@
 from .utils.track import tracked_list


+if TYPE_CHECKING:
+    from .load import DatasetModule
+
+
 logger = logging.get_logger(__name__)


@@ -400,6 +405,10 @@ def __init__(
             if is_remote_url(self._cache_downloaded_dir)
             else os.path.expanduser(self._cache_downloaded_dir)
         )
+
+        # In case there exists a legacy cache directory
+        self._legacy_relative_data_dir = None
+
         self._cache_dir = self._build_cache_dir()
         if not is_remote_url(self._cache_dir_root):
             os.makedirs(self._cache_dir_root, exist_ok=True)
@@ -452,23 +461,71 @@ def __setstate__(self, d):
     def manual_download_instructions(self) -> Optional[str]:
         return None

-    def _has_legacy_cache(self) -> bool:
-        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name}"""
+    def _check_legacy_cache(self) -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{builder_name} from 2.13"""
         if (
             self.__module__.startswith("datasets.")
             and not is_remote_url(self._cache_dir_root)
             and self.config.name == "default"
         ):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+
+            namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
+            config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
+            config_id = config_name + self.config_id[len(self.config.name) :]
+            hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
+            legacy_relative_data_dir = posixpath.join(
+                self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
+                config_id,
+                "0.0.0",
+                hash,
+            )
+            legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
+            if os.path.isdir(legacy_cache_dir):
+                return legacy_relative_data_dir
+
+    def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]:
+        """Check for the old cache directory template {cache_dir}/{namespace}___{dataset_name}/{config_name}-xxx from 2.14 and 2.15"""
+        if self.__module__.startswith("datasets.") and not is_remote_url(self._cache_dir_root):
+            from .packaged_modules import _PACKAGED_DATASETS_MODULES
+            from .utils._dill import Pickler
+
+            def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
+                """
+                Used to update hash of packaged modules which is used for creating unique cache directories to reflect
+                different config parameters which are passed in metadata from readme.
+                """
+                params_to_exclude = {"config_name", "version", "description"}
+                params_to_add_to_hash = {
+                    param: value
+                    for param, value in sorted(config_parameters.items())
+                    if param not in params_to_exclude
+                }
+                m = Hasher()
+                m.update(hash)
+                m.update(params_to_add_to_hash)
+                return m.hexdigest()
+
             namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
-            legacy_config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
-            legacy_config_id = legacy_config_name + self.config_id[len(self.config.name) :]
-            legacy_cache_dir = os.path.join(
-                self._cache_dir_root,
-                self.name if namespace is None else f"{namespace}___{self.name}",
-                legacy_config_id,
+            with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
+                config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
+            hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
+            if (
+                dataset_module.builder_configs_parameters.metadata_configs
+                and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
+            ):
+                hash = update_hash_with_config_parameters(
+                    hash, dataset_module.builder_configs_parameters.metadata_configs[self.config.name]
+                )
+            legacy_relative_data_dir = posixpath.join(
+                self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}",
+                config_id,
+                "0.0.0",
+                hash,
             )
-            return os.path.isdir(legacy_cache_dir)
-        return False
+            legacy_cache_dir = posixpath.join(self._cache_dir_root, legacy_relative_data_dir)
+            if os.path.isdir(legacy_cache_dir):
+                return legacy_relative_data_dir

     @classmethod
     def get_all_exported_dataset_infos(cls) -> DatasetInfosDict:
@@ -600,6 +657,14 @@ def builder_configs(cls) -> Dict[str, BuilderConfig]:
     def cache_dir(self):
         return self._cache_dir

+    def _use_legacy_cache_dir_if_possible(self, dataset_module: "DatasetModule"):
+        # Check for the legacy cache directory template (datasets<3.0.0)
+        self._legacy_relative_data_dir = (
+            self._check_legacy_cache2(dataset_module) or self._check_legacy_cache() or None
+        )
+        self._cache_dir = self._build_cache_dir()
+        self._output_dir = self._cache_dir
+
     def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
         """Relative path of this dataset in cache_dir:
         Will be:
@@ -608,21 +673,12 @@ def _relative_data_dir(self, with_version=True, with_hash=True) -> str:
             self.namespace___self.dataset_name/self.config.version/self.hash/
         If any of these element is missing or if ``with_version=False`` the corresponding subfolders are dropped.
         """
-
-        # Check for the legacy cache directory template (datasets<3.0.0)
-        if self._has_legacy_cache():
-            # use legacy names
-            dataset_name = self.name
-            config_name = self.repo_id.replace("/", "--") if self.repo_id is not None else self.dataset_name
-            config_id = config_name + self.config_id[len(self.config.name) :]
-        else:
-            dataset_name = self.dataset_name
-            config_name = self.config.name
-            config_id = self.config_id
+        if self._legacy_relative_data_dir is not None and with_version and with_hash:
+            return self._legacy_relative_data_dir

         namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
-        builder_data_dir = dataset_name if namespace is None else f"{namespace}___{dataset_name}"
-        builder_data_dir = posixpath.join(builder_data_dir, config_id)
+        builder_data_dir = self.dataset_name if namespace is None else f"{namespace}___{self.dataset_name}"
+        builder_data_dir = posixpath.join(builder_data_dir, self.config_id)
         if with_version:
             builder_data_dir = posixpath.join(builder_data_dir, str(self.config.version))
         if with_hash and self.hash and isinstance(self.hash, str):
@@ -1285,7 +1341,7 @@ def _as_dataset(self, split: Union[ReadInstruction, Split] = Split.TRAIN, in_mem
         """
         cache_dir = self._fs._strip_protocol(self._output_dir)
         dataset_name = self.dataset_name
-        if self._has_legacy_cache():
+        if self._check_legacy_cache():
             dataset_name = self.name
         dataset_kwargs = ArrowReader(cache_dir, self.info).read(
             name=dataset_name,
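
The two `_check_legacy_cache*` probes encode the two historical on-disk layouts. The standalone sketch below is an illustration only, not the library's code: `module_hash` and `data_files_hash` stand in for values that `datasets` computes with its `Hasher`, and the 2.13 branch is simplified (the real check also appends any config-id suffix and only runs for the "default" config).

# Illustration only -- a simplified model of the legacy layouts, not library code.
import posixpath


def legacy_cache_candidates(
    cache_dir_root: str, repo_id: str, config_name: str, module_hash: str, data_files_hash: str
) -> list:
    namespace, _, dataset_name = repo_id.rpartition("/")
    base = f"{namespace}___{dataset_name}" if namespace else dataset_name
    return [
        # 2.14/2.15 layout: config dir is "{config_name}-{hash of data_files}"
        posixpath.join(cache_dir_root, base, f"{config_name}-{data_files_hash}", "0.0.0", module_hash),
        # 2.13 layout: config dir is the repo id with "/" replaced by "--"
        posixpath.join(cache_dir_root, base, repo_id.replace("/", "--"), "0.0.0", module_hash),
    ]


# Probed in the same order as _use_legacy_cache_dir_if_possible: 2.14/2.15 first, then 2.13.
print(legacy_cache_candidates("/tmp/hf_cache", "user/my_dataset", "default", "<module_hash>", "<data_files_hash>"))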

src/datasets/load.py

Lines changed: 1 addition & 0 deletions
@@ -2239,6 +2239,7 @@ def load_dataset_builder(
         **builder_kwargs,
         **config_kwargs,
     )
+    builder_instance._use_legacy_cache_dir_if_possible(dataset_module)

     return builder_instance
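
This single call is the only wiring needed: every builder created through `load_dataset_builder` (and hence `load_dataset`) now checks for a pre-existing legacy cache before settling on the new layout. A hedged usage sketch, with an illustrative repo id:

# Usage sketch -- "user/my_dataset" is an illustrative repo id, not from this commit.
from datasets import load_dataset_builder

builder = load_dataset_builder("user/my_dataset", cache_dir="/tmp/hf_cache")
# Resolves to the datasets 2.13/2.15 directory if one exists on disk,
# so previously cached data is reused instead of being rebuilt.
print(builder.cache_dir)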

src/datasets/utils/_dill.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@

 class Pickler(dill.Pickler):
     dispatch = dill._dill.MetaCatchingDict(dill.Pickler.dispatch.copy())
+    _legacy_no_dict_keys_sorting = False

     def save(self, obj, save_persistent_id=True):
         obj_type = type(obj)
@@ -68,6 +69,8 @@ def save(self, obj, save_persistent_id=True):
         dill.Pickler.save(self, obj, save_persistent_id=save_persistent_id)

     def _batch_setitems(self, items):
+        if self._legacy_no_dict_keys_sorting:
+            return super()._batch_setitems(items)
         # Ignore the order of keys in a dict
         try:
             # Faster, but fails for unorderable elements
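
Newer versions of this `Pickler` sort dict keys before pickling so hashes do not depend on insertion order; 2.15 did not, which is why `_check_legacy_cache2` flips `_legacy_no_dict_keys_sorting` on to reproduce the old `data_files` hash. A minimal demonstration of the underlying effect, using plain `pickle` for illustration (the library hashes through `dill`):

# Illustration only: insertion order changes the pickled bytes of a dict.
import pickle

d = {"b": 1, "a": 2}
d_sorted = dict(sorted(d.items()))
assert d == d_sorted  # equal mappings...
assert pickle.dumps(d) != pickle.dumps(d_sorted)  # ...but different bytes, hence different hashes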

tests/test_load.py

Lines changed: 31 additions & 0 deletions
@@ -1657,3 +1657,34 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected):
     else:
         with pytest.raises(expected):
             resolve_trust_remote_code(trust_remote_code, repo_id="dummy")
+
+
+@pytest.mark.integration
+def test_reload_old_cache_from_2_15(tmp_path: Path):
+    cache_dir = tmp_path / "test_reload_old_cache_from_2_15"
+    builder_cache_dir = (
+        cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d"
+    )
+    builder_cache_dir.mkdir(parents=True)
+    arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow"
+    dataset_info_path = builder_cache_dir / "dataset_info.json"
+    with dataset_info_path.open("w") as f:
+        f.write("{}")
+    arrow_path.touch()
+    builder = load_dataset_builder(
+        "polinaeterna/audiofolder_two_configs_in_metadata",
+        "v2",
+        data_files="v2/train/*",
+        cache_dir=cache_dir.as_posix(),
+    )
+    assert builder.cache_dir == builder_cache_dir.as_posix()  # old cache from 2.15
+
+    builder = load_dataset_builder(
+        "polinaeterna/audiofolder_two_configs_in_metadata", "v2", cache_dir=cache_dir.as_posix()
+    )
+    assert (
+        builder.cache_dir
+        == (
+            cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata" / "v2" / "0.0.0" / str(builder.hash)
+        ).as_posix()
+    )  # new cache
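
The fabricated directory name encodes the 2.15 layout that `_check_legacy_cache2` reconstructs. A hedged sketch decomposing it (the two hex strings are fixed values captured from a real 2.15 cache, not computed here):

# Illustration only: decomposing the fabricated 2.15-style cache path.
legacy = "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d"
namespace_dir, config_id, version, module_hash = legacy.split("/")
assert namespace_dir == "polinaeterna___audiofolder_two_configs_in_metadata"  # {namespace}___{dataset_name}
assert config_id.startswith("v2-")  # {config_name}-{data_files hash}
assert version == "0.0.0"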
