diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py
index a378df23165..54c4737be11 100644
--- a/src/datasets/packaged_modules/cache/cache.py
+++ b/src/datasets/packaged_modules/cache/cache.py
@@ -2,13 +2,15 @@
 import os
 import shutil
 import time
+import warnings
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import pyarrow as pa
 
 import datasets
 import datasets.config
+import datasets.data_files
 from datasets.naming import filenames_for_dataset_split
 
 
@@ -79,15 +81,47 @@ def __init__(
         config_name: Optional[str] = None,
         version: Optional[str] = "0.0.0",
         hash: Optional[str] = None,
+        base_path: Optional[str] = None,
+        info: Optional[datasets.DatasetInfo] = None,
+        features: Optional[datasets.Features] = None,
+        token: Optional[Union[bool, str]] = None,
+        use_auth_token="deprecated",
         repo_id: Optional[str] = None,
-        **kwargs,
+        data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None,
+        data_dir: Optional[str] = None,
+        storage_options: Optional[dict] = None,
+        writer_batch_size: Optional[int] = None,
+        name="deprecated",
+        **config_kwargs,
     ):
+        if use_auth_token != "deprecated":
+            warnings.warn(
+                "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
+                f"You can remove this warning by passing 'token={use_auth_token}' instead.",
+                FutureWarning,
+            )
+            token = use_auth_token
+        if name != "deprecated":
+            warnings.warn(
+                "Parameter 'name' was renamed to 'config_name' in version 2.3.0 and will be removed in 3.0.0.",
+                category=FutureWarning,
+            )
+            config_name = name
         if repo_id is None and dataset_name is None:
             raise ValueError("repo_id or dataset_name is required for the Cache dataset builder")
+        if data_files is not None:
+            config_kwargs["data_files"] = data_files
+        if data_dir is not None:
+            config_kwargs["data_dir"] = data_dir
         if hash == "auto" and version == "auto":
+            # First we try to find a folder that takes the config_kwargs into account
+            # e.g. with "default-data_dir=data%2Ffortran" as config_id
+            config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id(
+                config_kwargs=config_kwargs, custom_features=features
+            )
             config_name, version, hash = _find_hash_in_cache(
                 dataset_name=repo_id or dataset_name,
-                config_name=config_name,
+                config_name=config_id,
                 cache_dir=cache_dir,
             )
         elif hash == "auto" or version == "auto":
@@ -98,8 +132,12 @@ def __init__(
             config_name=config_name,
             version=version,
             hash=hash,
+            base_path=base_path,
+            info=info,
+            token=token,
             repo_id=repo_id,
-            **kwargs,
+            storage_options=storage_options,
+            writer_batch_size=writer_batch_size,
         )
 
     def _info(self) -> datasets.DatasetInfo:
diff --git a/tests/packaged_modules/test_cache.py b/tests/packaged_modules/test_cache.py
index 8d37fb0a2d9..aed1496cee8 100644
--- a/tests/packaged_modules/test_cache.py
+++ b/tests/packaged_modules/test_cache.py
@@ -35,6 +35,21 @@ def test_cache_auto_hash(text_dir: Path):
     assert list(ds["train"]) == list(reloaded["train"])
 
 
+def test_cache_auto_hash_with_custom_config(text_dir: Path):
+    ds = load_dataset(str(text_dir), sample_by="paragraph")
+    another_ds = load_dataset(str(text_dir))
+    cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph")
+    another_cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
+    assert cache.config_id.endswith("paragraph")
+    assert not another_cache.config_id.endswith("paragraph")
+    reloaded = cache.as_dataset()
+    another_reloaded = another_cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert list(ds["train"]) == list(reloaded["train"])
+    assert list(another_ds) == list(another_reloaded)
+    assert list(another_ds["train"]) == list(another_reloaded["train"])
+
+
 def test_cache_missing(text_dir: Path):
     load_dataset(str(text_dir))
     Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()