2 changes: 1 addition & 1 deletion src/datasets/builder.py
@@ -517,7 +517,7 @@ def _create_builder_config(
builder_config = None

# try default config
- if config_name is None and self.BUILDER_CONFIGS and not config_kwargs:
+ if config_name is None and self.BUILDER_CONFIGS:
if self.DEFAULT_CONFIG_NAME is not None:
builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
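A hedged illustration of what this relaxed guard allows (the repo id and the `sep` kwarg below are hypothetical): as I read the one-line change, a dataset that declares a `DEFAULT_CONFIG_NAME` can now be loaded together with extra config kwargs, whereas previously any config kwarg made the guard fail and the default config was never looked up.

```python
from datasets import load_dataset

# Hypothetical Hub repo that declares a DEFAULT_CONFIG_NAME backed by CSV files.
# Before this change, passing a config kwarg such as `sep` skipped the
# default-config lookup; with `not config_kwargs` removed from the guard,
# the default config is selected and the kwarg is applied on top of it.
ds = load_dataset("username/csv_dataset_with_default_config", sep=";")
```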
149 changes: 109 additions & 40 deletions src/datasets/load.py
@@ -15,6 +15,7 @@
# Lint as: python3
"""Access datasets."""
import filecmp
+ import glob
import importlib
import inspect
import json
@@ -1070,15 +1071,18 @@ def get_module(self) -> DatasetModule:
default_builder_kwargs=default_builder_kwargs,
)
else:
- builder_configs, default_config_name = None, None
+ builder_configs: List[BuilderConfig] = [
+ import_main_class(module_path).BUILDER_CONFIG_CLASS(
+ data_files=data_files,
+ **default_builder_kwargs,
+ )
+ ]
+ default_config_name = None
builder_kwargs = {
"hash": hash,
"base_path": self.path,
"dataset_name": camelcase_to_snakecase(Path(self.path).name),
}
- if self.data_files is not None or not metadata_configs:
- builder_kwargs["data_files"] = data_files
- builder_kwargs.update(default_builder_kwargs)  # from _EXTENSION_TO_MODULE
# this file is deprecated and was created automatically in old versions of push_to_hub
if os.path.isfile(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME)):
with open(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME), encoding="utf-8") as f:
@@ -1263,16 +1267,19 @@ def get_module(self) -> DatasetModule:
download_config=self.download_config,
)
else:
- builder_configs, default_config_name = None, None
+ builder_configs: List[BuilderConfig] = [
+ import_main_class(module_path).BUILDER_CONFIG_CLASS(
+ data_files=data_files,
+ **default_builder_kwargs,
+ )
+ ]
+ default_config_name = None
builder_kwargs = {
"hash": hash,
"base_path": hf_hub_url(self.name, "", revision=self.revision),
"repo_id": self.name,
"dataset_name": camelcase_to_snakecase(Path(self.name).name),
}
- if self.data_files is not None or not metadata_configs:
- builder_kwargs["data_files"] = data_files
- builder_kwargs.update(default_builder_kwargs)  # from _EXTENSION_TO_MODULE
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = "Downloading metadata"
@@ -1515,9 +1522,11 @@ class CachedDatasetModuleFactory(_DatasetModuleFactory):
def __init__(
self,
name: str,
+ cache_dir: Optional[str] = None,
dynamic_modules_path: Optional[str] = None,
):
self.name = name
+ self.cache_dir = cache_dir
self.dynamic_modules_path = dynamic_modules_path
assert self.name.count("/") <= 1

@@ -1529,38 +1538,86 @@ def get_module(self) -> DatasetModule:
if os.path.isdir(importable_directory_path)
else None
)
- if not hashes:
- raise FileNotFoundError(f"Dataset {self.name} is not cached in {dynamic_modules_path}")
- # get most recent

- def _get_modification_time(module_hash):
- return (Path(importable_directory_path) / module_hash / (self.name.split("/")[-1] + ".py")).stat().st_mtime
+ if hashes:
+ # get most recent
+ def _get_modification_time(module_hash):
+ return (
+ (Path(importable_directory_path) / module_hash / (self.name.split("/")[-1] + ".py"))
+ .stat()
+ .st_mtime
+ )

- hash = sorted(hashes, key=_get_modification_time)[-1]
- warning_msg = (
- f"Using the latest cached version of the module from {os.path.join(importable_directory_path, hash)} "
- f"(last modified on {time.ctime(_get_modification_time(hash))}) since it "
- f"couldn't be found locally at {self.name}."
- )
- if not config.HF_DATASETS_OFFLINE:
- warning_msg += ", or remotely on the Hugging Face Hub."
- logger.warning(warning_msg)
- # make the new module to be noticed by the import system
- module_path = ".".join(
- [
- os.path.basename(dynamic_modules_path),
- "datasets",
- self.name.replace("/", "--"),
- hash,
- self.name.split("/")[-1],
+ hash = sorted(hashes, key=_get_modification_time)[-1]
+ warning_msg = (
+ f"Using the latest cached version of the module from {os.path.join(importable_directory_path, hash)} "
+ f"(last modified on {time.ctime(_get_modification_time(hash))}) since it "
+ f"couldn't be found locally at {self.name}."
+ )
+ if not config.HF_DATASETS_OFFLINE:
+ warning_msg += ", or remotely on the Hugging Face Hub."
+ logger.warning(warning_msg)
+ # make the new module to be noticed by the import system
+ module_path = ".".join(
+ [
+ os.path.basename(dynamic_modules_path),
+ "datasets",
+ self.name.replace("/", "--"),
+ hash,
+ self.name.split("/")[-1],
+ ]
+ )
+ importlib.invalidate_caches()
+ builder_kwargs = {
+ "hash": hash,
+ "repo_id": self.name,
+ }
+ return DatasetModule(module_path, hash, builder_kwargs)
+ cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE))
+ cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___"))
+ cached_directory_paths = [
+ cached_directory_path
+ for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*"))
+ if os.path.isdir(cached_directory_path)
+ ]
+ if cached_directory_paths:
+ # get most recent
+ def _get_modification_time(cached_directory_path):
+ return (Path(cached_directory_path)).stat().st_mtime

+ hash = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1]).name
+ cached_directory_paths = [
+ cached_directory_path
+ for cached_directory_path in cached_directory_paths
+ if Path(cached_directory_path).name == hash
]
- )
- importlib.invalidate_caches()
- builder_kwargs = {
- "hash": hash,
- "repo_id": self.name,
- }
- return DatasetModule(module_path, hash, builder_kwargs)
+ warning_msg = (
+ f"Using the latest cached version of the dataset from {cached_datasets_directory_path_root}/*/*/{hash}"
+ f"(last modified on {time.ctime(_get_modification_time(cached_directory_paths[0]))}) since it "
+ f"couldn't be found locally at {self.name}."
+ )
+ if not config.HF_DATASETS_OFFLINE:
+ warning_msg += ", or remotely on the Hugging Face Hub."
+ logger.warning(warning_msg)
+ module_path = "datasets.packaged_modules.cache.cache"
+ builder_kwargs = {
+ "hash": hash,
+ "repo_id": self.name,
+ "dataset_name": self.name.split("/")[-1],
+ }
+ dataset_infos = DatasetInfosDict()
+ builder_configs = []
+ for cached_directory_path in cached_directory_paths:
+ config_name = Path(cached_directory_path).parts[-3]
+ dataset_infos[config_name] = DatasetInfo.from_directory(cached_directory_path)
+ builder_configs.append(import_main_class(module_path).BUILDER_CONFIG_CLASS(name=config_name))
+ return DatasetModule(
+ module_path,
+ hash,
+ builder_kwargs,
+ dataset_infos=dataset_infos,
+ builder_configs_parameters=BuilderConfigsParameters(builder_configs=builder_configs),
+ )
+ raise FileNotFoundError(f"Dataset {self.name} is not cached in {self.cache_dir}")


class CachedMetricModuleFactory(_MetricModuleFactory):
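For orientation, a sketch of the cache layout the new fallback above scans (the concrete path below is illustrative, not taken from the PR): prepared datasets live under `<cache_dir>/<namespace>___<dataset_name>/<config_name>/<version>/<hash>/`, so the `glob(..., "*", "*", "*")` call yields one directory per prepared config, `Path(...).name` is the hash and `Path(...).parts[-3]` the config name used in `builder_kwargs` and `dataset_infos`.

```python
from pathlib import Path

# Illustrative cached-dataset directory: namespace___name / config / version / hash
p = Path.home() / ".cache/huggingface/datasets/username___my_dataset/default/0.0.0/0123456789abcdef"

assert p.name == "0123456789abcdef"  # used as the builder `hash`
assert p.parts[-3] == "default"      # used as the builder config name
```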
@@ -1620,6 +1677,7 @@ def dataset_module_factory(
dynamic_modules_path: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[Dict, List, str, DataFilesDict]] = None,
+ cache_dir: Optional[str] = None,
trust_remote_code: Optional[bool] = None,
**download_kwargs,
) -> DatasetModule:
@@ -1662,6 +1720,10 @@ def dataset_module_factory(
data_dir (:obj:`str`, optional): Directory with the data files. Used only if `data_files` is not specified,
in which case it's equal to pass `os.path.join(data_dir, "**")` as `data_files`.
data_files (:obj:`Union[Dict, List, str]`, optional): Defining the data_files of the dataset configuration.
+ cache_dir (`str`, *optional*):
+ Directory to read/write data. Defaults to `"~/.cache/huggingface/datasets"`.

+ <Added version="2.16.0"/>
trust_remote_code (`bool`, defaults to `True`):
Whether or not to allow for datasets defined on the Hub using a dataset script. This option
should only be set to `True` for repositories you trust and in which you have read the code, as it will
@@ -1809,7 +1871,9 @@ def dataset_module_factory(
except Exception as e1:
# All the attempts failed, before raising the error we should check if the module is already cached
try:
- return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module()
+ return CachedDatasetModuleFactory(
+ path, dynamic_modules_path=dynamic_modules_path, cache_dir=cache_dir
+ ).get_module()
except Exception:
# If it's not in the cache, then it doesn't exist.
if isinstance(e1, OfflineModeIsEnabled):
@@ -2176,6 +2240,7 @@ def load_dataset_builder(
download_mode=download_mode,
data_dir=data_dir,
data_files=data_files,
+ cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
)
# Get dataset builder class from the processing script
@@ -2196,7 +2261,11 @@ def load_dataset_builder(
hash, dataset_module.builder_configs_parameters.metadata_configs[config_name]
)

- if path in _PACKAGED_DATASETS_MODULES and data_files is None:
+ if (
+ path in _PACKAGED_DATASETS_MODULES
+ and data_files is None
+ and dataset_module.builder_configs_parameters.builder_configs[0].data_files is None
+ ):
error_msg = f"Please specify the data files or data directory to load for the {path} dataset builder."
example_extensions = [
extension for extension in _EXTENSION_TO_MODULE if _EXTENSION_TO_MODULE[extension] == path
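Taken together, the `load.py` changes let a dataset that was already downloaded and prepared be reloaded without network access: when every remote resolution attempt fails, `dataset_module_factory` falls back to `CachedDatasetModuleFactory`, which now also scans the datasets `cache_dir` and resolves the new `cache` packaged module instead of raising immediately. A minimal sketch, assuming `username/my_dataset` was prepared earlier (the repo id and custom cache path are hypothetical; `HF_DATASETS_OFFLINE` is the standard environment toggle):

```python
import os

# Must be set before `datasets` is imported so config.HF_DATASETS_OFFLINE picks it up.
os.environ["HF_DATASETS_OFFLINE"] = "1"

from datasets import load_dataset

# No network: the module factory falls back to the cache and reloads the
# already-prepared Arrow data, optionally from a custom cache_dir.
ds = load_dataset("username/my_dataset")
# ds = load_dataset("username/my_dataset", cache_dir="/path/to/my_cache")
```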
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -6,6 +6,7 @@

from .arrow import arrow
from .audiofolder import audiofolder
+ from .cache import cache  # noqa F401
from .csv import csv
from .imagefolder import imagefolder
from .json import json
Empty file.
59 changes: 59 additions & 0 deletions src/datasets/packaged_modules/cache/cache.py
@@ -0,0 +1,59 @@
import os
import shutil
from typing import List, Optional

import pyarrow as pa

import datasets
from datasets.naming import filenames_for_dataset_split


logger = datasets.utils.logging.get_logger(__name__)


class Cache(datasets.ArrowBasedBuilder):
def _info(self) -> datasets.DatasetInfo:
return datasets.DatasetInfo()

def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs):
if not os.path.exists(self.cache_dir):
raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}")
if output_dir is not None:
shutil.copytree(self.cache_dir, output_dir)

def _split_generators(self, dl_manager):
# used to stream from cache
if isinstance(self.info.splits, datasets.SplitDict):
split_infos: List[datasets.SplitInfo] = list(self.info.splits.values())
else:
raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}")
return [
datasets.SplitGenerator(
name=split_info.name,
gen_kwargs={
"files": filenames_for_dataset_split(
self.cache_dir,
dataset_name=self.dataset_name,
split=split_info.name,
filetype_suffix="arrow",
shard_lengths=split_info.shard_lengths,
)
},
)
for split_info in split_infos
]

def _generate_tables(self, files):
# used to stream from cache
for file_idx, file in enumerate(files):
with open(file, "rb") as f:
try:
for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", pa_table
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
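A usage sketch for the new builder, mirroring `tests/packaged_modules/test_cache.py` below (the text directory path is hypothetical): once a dataset has been prepared in the cache, `Cache` re-exposes the cached Arrow shards either as a regular dataset or as a streaming dataset.

```python
from pathlib import Path

from datasets import load_dataset
from datasets.packaged_modules.cache.cache import Cache

ds = load_dataset("/path/to/text_dir")  # prepares the cache for this local dataset
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]

cache = Cache(dataset_name="text_dir", hash=hash)
reloaded = cache.as_dataset()            # reads the cached Arrow files
streamed = cache.as_streaming_dataset()  # goes through _split_generators/_generate_tables
```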
10 changes: 10 additions & 0 deletions tests/fixtures/files.py
@@ -471,6 +471,16 @@ def text2_path(tmp_path_factory):
return path


@pytest.fixture(scope="session")
def text_dir(tmp_path_factory):
data = ["0", "1", "2", "3"]
path = tmp_path_factory.mktemp("data_text_dir") / "dataset.txt"
with open(path, "w") as f:
for item in data:
f.write(item + "\n")
return path.parent


@pytest.fixture(scope="session")
def text_dir_with_unsupported_extension(tmp_path_factory):
data = ["0", "1", "2", "3"]
62 changes: 62 additions & 0 deletions tests/packaged_modules/test_cache.py
@@ -0,0 +1,62 @@
from pathlib import Path

import pytest

from datasets import load_dataset
from datasets.load import configure_builder_class
from datasets.packaged_modules.cache.cache import Cache


SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"


def test_cache(text_dir: Path):
ds = load_dataset(str(text_dir))
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
cache = Cache(dataset_name=text_dir.name, hash=hash)
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_streaming(text_dir: Path):
ds = load_dataset(str(text_dir))
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
cache = Cache(dataset_name=text_dir.name, hash=hash)
reloaded = cache.as_streaming_dataset()
assert list(ds) == list(reloaded)
assert list(ds["train"]) == list(reloaded["train"])


def test_cache_missing(text_dir: Path):
ds = load_dataset(str(text_dir))
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
Cache(dataset_name=text_dir.name, hash=hash).download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name="missing", hash=hash).download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name=text_dir.name, hash="missing").download_and_prepare()
with pytest.raises(ValueError):
Cache(dataset_name=text_dir.name, config_name="missing", hash=hash).download_and_prepare()


@pytest.mark.integration
def test_cache_multi_configs():
repo_id = SAMPLE_DATASET_TWO_CONFIG_IN_METADATA
dataset_name = repo_id.split("/")[-1]
config_name = "v1"
ds = load_dataset(repo_id, config_name)
hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
builder_cls = configure_builder_class(
Cache,
builder_configs=[Cache.BUILDER_CONFIG_CLASS(name="v1"), Cache.BUILDER_CONFIG_CLASS(name="v2")],
default_config_name=None,
dataset_name=dataset_name,
)
cache = builder_cls(dataset_name=dataset_name, repo_id=repo_id, config_name=config_name, hash=hash)
reloaded = cache.as_dataset()
assert list(ds) == list(reloaded)
assert len(ds["train"]) == len(reloaded["train"])
with pytest.raises(ValueError) as excinfo:
builder_cls(dataset_name=dataset_name, repo_id=repo_id, config_name="missing", hash=hash)
assert config_name in str(excinfo.value)