Commit dcd0104

Fix offline mode with single config (#6741)
* fix offline mode with single config
* fix tests
* style
* mario's suggestion
1 parent d546883 commit dcd0104

2 files changed: 121 additions & 45 deletions


src/datasets/packaged_modules/cache/cache.py

Lines changed: 38 additions & 21 deletions
@@ -1,4 +1,5 @@
 import glob
+import json
 import os
 import shutil
 import time
@@ -22,43 +23,62 @@ def _get_modification_time(cached_directory_path):


 def _find_hash_in_cache(
-    dataset_name: str, config_name: Optional[str], cache_dir: Optional[str]
+    dataset_name: str,
+    config_name: Optional[str],
+    cache_dir: Optional[str],
+    config_kwargs: dict,
+    custom_features: Optional[datasets.Features],
 ) -> Tuple[str, str, str]:
+    if config_name or config_kwargs or custom_features:
+        config_id = datasets.BuilderConfig(config_name or "default").create_config_id(
+            config_kwargs=config_kwargs, custom_features=custom_features
+        )
+    else:
+        config_id = None
     cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE))
     cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___"))
     cached_directory_paths = [
         cached_directory_path
         for cached_directory_path in glob.glob(
-            os.path.join(cached_datasets_directory_path_root, config_name or "*", "*", "*")
+            os.path.join(cached_datasets_directory_path_root, config_id or "*", "*", "*")
         )
         if os.path.isdir(cached_directory_path)
+        and (
+            config_kwargs
+            or custom_features
+            or json.loads(Path(cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
+            == Path(cached_directory_path).parts[-3]  # no extra params => config_id == config_name
+        )
     ]
     if not cached_directory_paths:
-        if config_name is not None:
-            cached_directory_paths = [
-                cached_directory_path
-                for cached_directory_path in glob.glob(
-                    os.path.join(cached_datasets_directory_path_root, "*", "*", "*")
-                )
-                if os.path.isdir(cached_directory_path)
-            ]
+        cached_directory_paths = [
+            cached_directory_path
+            for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
+            if os.path.isdir(cached_directory_path)
+        ]
         available_configs = sorted(
             {Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths}
         )
         raise ValueError(
             f"Couldn't find cache for {dataset_name}"
-            + (f" for config '{config_name}'" if config_name else "")
+            + (f" for config '{config_id}'" if config_id else "")
             + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "")
         )
     # get most recent
     cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1])
     version, hash = cached_directory_path.parts[-2:]
     other_configs = [
-        Path(cached_directory_path).parts[-3]
-        for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
-        if os.path.isdir(cached_directory_path)
+        Path(_cached_directory_path).parts[-3]
+        for _cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash))
+        if os.path.isdir(_cached_directory_path)
+        and (
+            config_kwargs
+            or custom_features
+            or json.loads(Path(_cached_directory_path, "dataset_info.json").read_text(encoding="utf-8"))["config_name"]
+            == Path(_cached_directory_path).parts[-3]  # no extra params => config_id == config_name
+        )
     ]
-    if not config_name and len(other_configs) > 1:
+    if not config_id and len(other_configs) > 1:
         raise ValueError(
             f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}"
             f"\nPlease specify which configuration to reload from the cache, e.g."
@@ -114,15 +134,12 @@ def __init__(
         if data_dir is not None:
             config_kwargs["data_dir"] = data_dir
         if hash == "auto" and version == "auto":
-            # First we try to find a folder that takes the config_kwargs into account
-            # e.g. with "default-data_dir=data%2Ffortran" as config_id
-            config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id(
-                config_kwargs=config_kwargs, custom_features=features
-            )
             config_name, version, hash = _find_hash_in_cache(
                 dataset_name=repo_id or dataset_name,
-                config_name=config_id,
+                config_name=config_name,
                 cache_dir=cache_dir,
+                config_kwargs=config_kwargs,
+                custom_features=features,
             )
         elif hash == "auto" or version == "auto":
             raise NotImplementedError("Pass both hash='auto' and version='auto' instead")
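
The lookup above walks the cache layout <cache_dir>/<dataset_name>/<config_id>/<version>/<hash>, which is why Path(...).parts[-3] recovers the config directory name. The key change is that cached directories are now matched by config_id (the config name plus any encoded builder parameters), with a dataset_info.json comparison as fallback, instead of by the raw config_name. A minimal sketch of how create_config_id (the same call used in the diff) derives that id, assuming a local `datasets` install; printed values are illustrative:

    import datasets

    # With no extra builder parameters, the config_id is just the config name,
    # so a cached directory can be matched by name alone.
    plain_id = datasets.BuilderConfig("default").create_config_id(config_kwargs={}, custom_features=None)
    print(plain_id)  # "default"

    # Extra parameters are encoded into a suffix, so the cache directory name
    # no longer equals the plain config name (exact suffix format may vary).
    custom_id = datasets.BuilderConfig("default").create_config_id(
        config_kwargs={"sample_by": "paragraph"}, custom_features=None
    )
    print(custom_id)  # e.g. "default-sample_by=paragraph"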

tests/packaged_modules/test_cache.py

Lines changed: 83 additions & 24 deletions
@@ -6,40 +6,47 @@
 from datasets.packaged_modules.cache.cache import Cache


+SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
 SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"


-def test_cache(text_dir: Path):
-    ds = load_dataset(str(text_dir))
+def test_cache(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
     hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
-    cache = Cache(dataset_name=text_dir.name, hash=hash)
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])


-def test_cache_streaming(text_dir: Path):
-    ds = load_dataset(str(text_dir))
+def test_cache_streaming(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_streaming"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
     hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
-    cache = Cache(dataset_name=text_dir.name, hash=hash)
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
     reloaded = cache.as_streaming_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])


-def test_cache_auto_hash(text_dir: Path):
-    ds = load_dataset(str(text_dir))
-    cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
+def test_cache_auto_hash(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_auto_hash"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])


-def test_cache_auto_hash_with_custom_config(text_dir: Path):
-    ds = load_dataset(str(text_dir), sample_by="paragraph")
-    another_ds = load_dataset(str(text_dir))
-    cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph")
-    another_cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
+def test_cache_auto_hash_with_custom_config(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_auto_hash_with_custom_config"
+    ds = load_dataset(str(text_dir), sample_by="paragraph", cache_dir=str(cache_dir))
+    another_ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    cache = Cache(
+        cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph"
+    )
+    another_cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
     assert cache.config_id.endswith("paragraph")
     assert not another_cache.config_id.endswith("paragraph")
     reloaded = cache.as_dataset()
@@ -50,27 +57,79 @@ def test_cache_auto_hash_with_custom_config(text_dir: Path):
     assert list(another_ds["train"]) == list(another_reloaded["train"])


-def test_cache_missing(text_dir: Path):
-    load_dataset(str(text_dir))
-    Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
+def test_cache_missing(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_missing"
+    load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name="missing", version="auto", hash="auto").download_and_prepare()
+        Cache(cache_dir=str(cache_dir), dataset_name="missing", version="auto", hash="auto").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name=text_dir.name, hash="missing").download_and_prepare()
+        Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash="missing").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto").download_and_prepare()
+        Cache(
+            cache_dir=str(cache_dir), dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto"
+        ).download_and_prepare()


 @pytest.mark.integration
-def test_cache_multi_configs():
+def test_cache_multi_configs(tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_multi_configs"
     repo_id = SAMPLE_DATASET_TWO_CONFIG_IN_METADATA
     dataset_name = repo_id.split("/")[-1]
     config_name = "v1"
-    ds = load_dataset(repo_id, config_name)
-    cache = Cache(dataset_name=dataset_name, repo_id=repo_id, config_name=config_name, version="auto", hash="auto")
+    ds = load_dataset(repo_id, config_name, cache_dir=str(cache_dir))
+    cache = Cache(
+        cache_dir=str(cache_dir),
+        dataset_name=dataset_name,
+        repo_id=repo_id,
+        config_name=config_name,
+        version="auto",
+        hash="auto",
+    )
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert len(ds["train"]) == len(reloaded["train"])
     with pytest.raises(ValueError) as excinfo:
-        Cache(dataset_name=dataset_name, repo_id=repo_id, config_name="missing", version="auto", hash="auto")
+        Cache(
+            cache_dir=str(cache_dir),
+            dataset_name=dataset_name,
+            repo_id=repo_id,
+            config_name="missing",
+            version="auto",
+            hash="auto",
+        )
+    assert config_name in str(excinfo.value)
+
+
+@pytest.mark.integration
+def test_cache_single_config(tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_single_config"
+    repo_id = SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA
+    dataset_name = repo_id.split("/")[-1]
+    config_name = "custom"
+    ds = load_dataset(repo_id, cache_dir=str(cache_dir))
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=dataset_name, repo_id=repo_id, version="auto", hash="auto")
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
+    cache = Cache(
+        cache_dir=str(cache_dir),
+        dataset_name=dataset_name,
+        config_name=config_name,
+        repo_id=repo_id,
+        version="auto",
+        hash="auto",
+    )
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
+    with pytest.raises(ValueError) as excinfo:
+        Cache(
+            cache_dir=str(cache_dir),
+            dataset_name=dataset_name,
+            repo_id=repo_id,
+            config_name="missing",
+            version="auto",
+            hash="auto",
+        )
     assert config_name in str(excinfo.value)
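
The new test_cache_single_config covers the user-facing scenario this commit fixes: a cached dataset whose only config is named something other than "default" (here "custom") can now be reloaded without passing a config name. A hedged end-to-end sketch, assuming the dataset was downloaded once while online and that the offline flag is set before `datasets` is imported:

    import os

    # Simulate offline mode; `datasets` reads this variable at import time.
    os.environ["HF_DATASETS_OFFLINE"] = "1"

    from datasets import load_dataset

    # Only the local cache is consulted here. Before this fix, the lookup
    # computed the config_id "default" and failed to match the cached
    # "custom" directory; it now resolves the single cached config.
    ds = load_dataset("hf-internal-testing/audiofolder_single_config_in_metadata")
    print(ds["train"])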
