Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions src/datasets/packaged_modules/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import os
import shutil
import time
import warnings
from pathlib import Path
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import pyarrow as pa

import datasets
import datasets.config
import datasets.data_files
from datasets.naming import filenames_for_dataset_split


Expand Down Expand Up @@ -79,15 +81,47 @@ def __init__(
config_name: Optional[str] = None,
version: Optional[str] = "0.0.0",
hash: Optional[str] = None,
base_path: Optional[str] = None,
info: Optional[datasets.DatasetInfo] = None,
features: Optional[datasets.Features] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token="deprecated",
repo_id: Optional[str] = None,
**kwargs,
data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
data_dir: Optional[str] = None,
storage_options: Optional[dict] = None,
writer_batch_size: Optional[int] = None,
name="deprecated",
**config_kwargs,
):
if use_auth_token != "deprecated":
warnings.warn(
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
f"You can remove this warning by passing 'token={use_auth_token}' instead.",
FutureWarning,
)
token = use_auth_token
if name != "deprecated":
warnings.warn(
"Parameter 'name' was renamed to 'config_name' in version 2.3.0 and will be removed in 3.0.0.",
category=FutureWarning,
)
config_name = name
if repo_id is None and dataset_name is None:
raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
if data_files is not None:
config_kwargs["data_files"] = data_files
if data_dir is not None:
config_kwargs["data_dir"] = data_dir
if hash == "auto" and version == "auto":
# First we try to find a folder that takes the config_kwargs into account
# e.g. with "default-data_dir=data%2Ffortran" as config_id
config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id(
config_kwargs=config_kwargs, custom_features=features
)
config_name, version, hash = _find_hash_in_cache(
dataset_name=repo_id or dataset_name,
config_name=config_name,
config_name=config_id,
cache_dir=cache_dir,
)
elif hash == "auto" or version == "auto":
Expand All @@ -98,8 +132,12 @@ def __init__(
config_name=config_name,
version=version,
hash=hash,
base_path=base_path,
info=info,
token=token,
repo_id=repo_id,
**kwargs,
storage_options=storage_options,
writer_batch_size=writer_batch_size,
)

def _info(self) -> datasets.DatasetInfo:
Expand Down
15 changes: 15 additions & 0 deletions tests/packaged_modules/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ def test_cache_auto_hash(text_dir: Path):
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_auto_hash_with_custom_config(text_dir: Path):
    """Reload a text dataset through ``Cache`` with and without custom config kwargs.

    The directory is loaded once with ``sample_by="paragraph"`` and once with
    default parameters. Two ``Cache`` builders using ``version="auto"`` and
    ``hash="auto"`` are then expected to each resolve to their own config
    (checked via the ``config_id`` suffix) and to give back the same data.
    """
    # Both loads run before any Cache builder is created; the builders below
    # are presumably resolved from what these calls prepared — confirm in Cache.
    paragraph_ds = load_dataset(str(text_dir), sample_by="paragraph")
    default_ds = load_dataset(str(text_dir))

    paragraph_cache = Cache(
        dataset_name=text_dir.name,
        version="auto",
        hash="auto",
        sample_by="paragraph",
    )
    default_cache = Cache(
        dataset_name=text_dir.name,
        version="auto",
        hash="auto",
    )

    # Only the builder given the custom kwarg should carry it in its config id.
    assert paragraph_cache.config_id.endswith("paragraph")
    assert not default_cache.config_id.endswith("paragraph")

    paragraph_reloaded = paragraph_cache.as_dataset()
    default_reloaded = default_cache.as_dataset()

    # Compare the top-level iteration first, then the "train" split rows,
    # in the same order as the individual assertions would run.
    pairs = (
        (paragraph_ds, paragraph_reloaded),
        (default_ds, default_reloaded),
    )
    for expected, actual in pairs:
        assert list(expected) == list(actual)
        assert list(expected["train"]) == list(actual["train"])


def test_cache_missing(text_dir: Path):
load_dataset(str(text_dir))
Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
Expand Down