From 02d2d6fe91cab46470bc2ef3f881ff01d7d5057a Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 30 Jan 2024 19:40:28 +0100 Subject: [PATCH 1/3] fix reload cache with data_dir --- src/datasets/packaged_modules/cache/cache.py | 60 +++++++++++++++++--- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index a378df23165..ad34df4fceb 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -2,13 +2,15 @@ import os import shutil import time +import warnings from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import pyarrow as pa import datasets import datasets.config +import datasets.data_files from datasets.naming import filenames_for_dataset_split @@ -79,17 +81,57 @@ def __init__( config_name: Optional[str] = None, version: Optional[str] = "0.0.0", hash: Optional[str] = None, + base_path: Optional[str] = None, + info: Optional[datasets.DatasetInfo] = None, + features: Optional[datasets.Features] = None, + token: Optional[Union[bool, str]] = None, + use_auth_token="deprecated", repo_id: Optional[str] = None, - **kwargs, + data_files: Optional[Union[str, list, dict, datasets.data_files.DataFilesDict]] = None, + data_dir: Optional[str] = None, + storage_options: Optional[dict] = None, + writer_batch_size: Optional[int] = None, + name="deprecated", + **config_kwargs, ): + if use_auth_token != "deprecated": + warnings.warn( + "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n" + f"You can remove this warning by passing 'token={use_auth_token}' instead.", + FutureWarning, + ) + token = use_auth_token + if name != "deprecated": + warnings.warn( + "Parameter 'name' was renamed to 'config_name' in version 2.3.0 and will be removed in 3.0.0.", + category=FutureWarning, + ) + config_name = name if 
repo_id is None and dataset_name is None: raise ValueError("repo_id or dataset_name is required for the Cache dataset builder") + if data_files is not None: + config_kwargs["data_files"] = data_files + if data_dir is not None: + config_kwargs["data_dir"] = data_dir if hash == "auto" and version == "auto": - config_name, version, hash = _find_hash_in_cache( - dataset_name=repo_id or dataset_name, - config_name=config_name, - cache_dir=cache_dir, + # First we try to find a folder that takes the config_kwargs into account + # e.g. with "default-data_dir=data%2Ffortran" as config_id + config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id( + config_kwargs=config_kwargs, custom_features=features ) + try: + config_name, version, hash = _find_hash_in_cache( + dataset_name=repo_id or dataset_name, + config_name=config_id, + cache_dir=cache_dir, + ) + except ValueError: + # Otherwise we fall back on the requested config name + config_name, version, hash = _find_hash_in_cache( + dataset_name=repo_id or dataset_name, + config_name=config_name, + cache_dir=cache_dir, + ) elif hash == "auto" or version == "auto": raise NotImplementedError("Pass both hash='auto' and version='auto' instead") super().__init__( @@ -98,8 +140,12 @@ def __init__( config_name=config_name, version=version, hash=hash, + base_path=base_path, + info=info, + token=token, repo_id=repo_id, - **kwargs, + storage_options=storage_options, + writer_batch_size=writer_batch_size, ) def _info(self) -> datasets.DatasetInfo: From 8f46544dd01670794f7ddae65ca0c87a8a107ac6 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Tue, 30 Jan 2024 19:52:02 +0100 Subject: [PATCH 2/3] add test --- tests/packaged_modules/test_cache.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/packaged_modules/test_cache.py b/tests/packaged_modules/test_cache.py index 8d37fb0a2d9..aed1496cee8 100644 --- a/tests/packaged_modules/test_cache.py +++ b/tests/packaged_modules/test_cache.py @@ 
-35,6 +35,21 @@ def test_cache_auto_hash(text_dir: Path): assert list(ds["train"]) == list(reloaded["train"]) +def test_cache_auto_hash_with_custom_config(text_dir: Path): + ds = load_dataset(str(text_dir), sample_by="paragraph") + another_ds = load_dataset(str(text_dir)) + cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph") + another_cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto") + assert cache.config_id.endswith("paragraph") + assert not another_cache.config_id.endswith("paragraph") + reloaded = cache.as_dataset() + another_reloaded = another_cache.as_dataset() + assert list(ds) == list(reloaded) + assert list(ds["train"]) == list(reloaded["train"]) + assert list(another_ds) == list(another_reloaded) + assert list(another_ds["train"]) == list(another_reloaded["train"]) + + def test_cache_missing(text_dir: Path): load_dataset(str(text_dir)) Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare() From 89f7eb7a49f97716ae1ea802094cd1947e824857 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 31 Jan 2024 09:24:25 +0100 Subject: [PATCH 3/3] remove unused code --- src/datasets/packaged_modules/cache/cache.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py index ad34df4fceb..54c4737be11 100644 --- a/src/datasets/packaged_modules/cache/cache.py +++ b/src/datasets/packaged_modules/cache/cache.py @@ -119,19 +119,11 @@ def __init__( config_id = self.BUILDER_CONFIG_CLASS(config_name or "default").create_config_id( config_kwargs=config_kwargs, custom_features=features ) - try: - config_name, version, hash = _find_hash_in_cache( - dataset_name=repo_id or dataset_name, - config_name=config_id, - cache_dir=cache_dir, - ) - except ValueError: - # Otherwise we fall back on the requested config name - config_name, version, hash = _find_hash_in_cache( - 
dataset_name=repo_id or dataset_name, - config_name=config_name, - cache_dir=cache_dir, - ) + config_name, version, hash = _find_hash_in_cache( + dataset_name=repo_id or dataset_name, + config_name=config_id, + cache_dir=cache_dir, + ) elif hash == "auto" or version == "auto": raise NotImplementedError("Pass both hash='auto' and version='auto' instead") super().__init__(