Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import time
import urllib
import warnings
from dataclasses import dataclass
from dataclasses import dataclass, fields
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, Mapping, Optional, Tuple, Union
Expand All @@ -46,12 +46,12 @@
ReadInstruction,
)
from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter, SchemaInferenceError
from .data_files import DataFilesDict, sanitize_patterns
from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download.download_config import DownloadConfig
from .download.download_manager import DownloadManager, DownloadMode
from .download.mock_download_manager import MockDownloadManager
from .download.streaming_download_manager import StreamingDownloadManager, xopen
from .download.streaming_download_manager import StreamingDownloadManager, xjoin, xopen
from .features import Features
from .filesystems import (
is_remote_filesystem,
Expand Down Expand Up @@ -128,7 +128,7 @@ class BuilderConfig:
name: str = "default"
version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
data_dir: Optional[str] = None
data_files: Optional[DataFilesDict] = None
data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None
description: Optional[str] = None

def __post_init__(self):
Expand All @@ -139,7 +139,7 @@ def __post_init__(self):
f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. "
f"They could create issues when creating a directory for this config on Windows filesystem."
)
if self.data_files is not None and not isinstance(self.data_files, DataFilesDict):
if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)):
raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}")

def __eq__(self, o):
Expand Down Expand Up @@ -213,6 +213,25 @@ def create_config_id(
else:
return self.name

def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
    """Resolve ``self.data_files`` in place when it still holds unresolved patterns.

    No-op when ``data_files`` is already a resolved ``DataFilesDict`` or ``None``.
    ``data_dir``, when set, is joined onto ``base_path`` before resolution.
    """
    if not isinstance(self.data_files, DataFilesPatternsDict):
        return
    if self.data_dir:
        base_path = xjoin(base_path, self.data_dir)
    self.data_files = self.data_files.resolve(base_path, download_config)

def _update_hash(self, hash: str) -> str:
    """
    Used to update hash of packaged modules which is used for creating unique cache directories to reflect
    different config parameters which are passed in metadata from readme.
    """
    hasher = Hasher()
    hasher.update(hash)
    # Sort field names so the resulting hash is independent of declaration order.
    field_names = sorted(f.name for f in fields(self))
    for field_name in field_names:
        try:
            hasher.update(getattr(self, field_name))
        except TypeError:
            # Skip values the hasher cannot serialize.
            continue
    return hasher.hexdigest()


class DatasetBuilder:
"""Abstract base class for all datasets.
Expand Down Expand Up @@ -376,6 +395,8 @@ def __init__(
custom_features=features,
**config_kwargs,
)
if self.hash is not None:
self.hash = self.config._update_hash(hash)

# prepare info: DatasetInfo are a standardized dataclass across all datasets
# Prefill datasetinfo
Expand Down Expand Up @@ -555,7 +576,7 @@ def _create_builder_config(

# otherwise use the config_kwargs to overwrite the attributes
else:
builder_config = copy.deepcopy(builder_config)
builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
for key, value in config_kwargs.items():
if value is not None:
if not hasattr(builder_config, key):
Expand All @@ -565,6 +586,12 @@ def _create_builder_config(
if not builder_config.name:
raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")

# resolve data files if needed
builder_config._resolve_data_files(
base_path=self.base_path,
download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
)

# compute the config id that is going to be used for caching
config_id = builder_config.create_config_id(
config_kwargs,
Expand All @@ -590,7 +617,7 @@ def _create_builder_config(
@classproperty
@classmethod
@memoize()
def builder_configs(cls):
def builder_configs(cls) -> Dict[str, BuilderConfig]:
"""Dictionary of pre-defined configurations for this builder class."""
configs = {config.name: config for config in cls.BUILDER_CONFIGS}
if len(configs) != len(cls.BUILDER_CONFIGS):
Expand Down
91 changes: 91 additions & 0 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,94 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
for key, data_files_list in self.items():
out[key] = data_files_list.filter_extensions(extensions)
return out


class DataFilesPatternsList(List[str]):
    """
    List of data files patterns (absolute local paths or URLs).
    For each pattern there should also be a list of allowed extensions
    to keep, or a None to keep all the files for the pattern.
    """

    def __init__(
        self,
        patterns: List[str],
        allowed_extensions: List[Optional[List[str]]],
    ):
        # allowed_extensions is parallel to the patterns: one entry per pattern.
        super().__init__(patterns)
        self.allowed_extensions = allowed_extensions

    def __add__(self, other: "DataFilesPatternsList") -> "DataFilesPatternsList":
        # Fixed: this used to build a DataFilesList, which both loses the
        # "still unresolved patterns" semantics and passes extension lists
        # where DataFilesList expects origin metadata.
        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)

    @classmethod
    def from_patterns(
        cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
    ) -> "DataFilesPatternsList":
        """Build a patterns list where every pattern shares the same allowed extensions."""
        return cls(patterns, [allowed_extensions] * len(patterns))

    def resolve(
        self,
        base_path: str,
        download_config: Optional["DownloadConfig"] = None,
    ) -> "DataFilesList":
        """Resolve the patterns against ``base_path`` into a concrete ``DataFilesList``.

        Raises:
            FileNotFoundError: for a non-glob pattern that matches nothing;
                glob patterns matching zero files are silently skipped.
        """
        base_path = base_path if base_path is not None else Path().resolve().as_posix()
        data_files = []
        for pattern, allowed_extensions in zip(self, self.allowed_extensions):
            try:
                data_files.extend(
                    resolve_pattern(
                        pattern,
                        base_path=base_path,
                        allowed_extensions=allowed_extensions,
                        download_config=download_config,
                    )
                )
            except FileNotFoundError:
                if not has_magic(pattern):
                    raise
        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
        return DataFilesList(data_files, origin_metadata)

    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
        """Return a copy with ``extensions`` added to each pattern's allowed extensions.

        A ``None`` entry (meaning "keep all files") becomes ``extensions`` alone;
        previously this raised ``TypeError`` (``None + list``).
        """
        return DataFilesPatternsList(
            self,
            [
                (allowed_extensions + extensions) if allowed_extensions is not None else list(extensions)
                for allowed_extensions in self.allowed_extensions
            ],
        )


class DataFilesPatternsDict(Dict[str, DataFilesPatternsList]):
    """
    Dict of split_name -> list of data files patterns (absolute local paths or URLs).
    """

    @classmethod
    def from_patterns(
        cls, patterns: Dict[str, List[str]], allowed_extensions: Optional[List[str]] = None
    ) -> "DataFilesPatternsDict":
        """Wrap each split's raw pattern list in a ``DataFilesPatternsList``."""
        out = cls()
        for split, split_patterns in patterns.items():
            if isinstance(split_patterns, DataFilesPatternsList):
                # Already wrapped: keep as-is.
                out[split] = split_patterns
            else:
                out[split] = DataFilesPatternsList.from_patterns(
                    split_patterns,
                    allowed_extensions=allowed_extensions,
                )
        return out

    def resolve(
        self,
        base_path: str,
        download_config: Optional[DownloadConfig] = None,
    ) -> "DataFilesDict":
        """Resolve every split's patterns into a concrete ``DataFilesDict``."""
        resolved = DataFilesDict()
        for split, patterns_list in self.items():
            resolved[split] = patterns_list.resolve(base_path, download_config)
        return resolved

    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsDict":
        """Return a copy with ``filter_extensions`` applied to each split's patterns."""
        filtered = type(self)()
        for split, patterns_list in self.items():
            filtered[split] = patterns_list.filter_extensions(extensions)
        return filtered
55 changes: 13 additions & 42 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
DEFAULT_PATTERNS_ALL,
DataFilesDict,
DataFilesList,
DataFilesPatternsDict,
DataFilesPatternsList,
EmptyDatasetError,
get_data_patterns,
get_metadata_patterns,
Expand All @@ -51,7 +53,6 @@
from .download.streaming_download_manager import StreamingDownloadManager, xbasename, xglob, xjoin
from .exceptions import DataFilesNotFoundError, DatasetNotFoundError
from .features import Features
from .fingerprint import Hasher
from .info import DatasetInfo, DatasetInfosDict
from .iterable_dataset import IterableDataset
from .metric import Metric
Expand Down Expand Up @@ -589,28 +590,12 @@ def infer_module_for_data_files(
return module_name, default_builder_kwargs


def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
"""
Used to update hash of packaged modules which is used for creating unique cache directories to reflect
different config parameters which are passed in metadata from readme.
"""
params_to_exclude = {"config_name", "version", "description"}
params_to_add_to_hash = {
param: value for param, value in sorted(config_parameters.items()) if param not in params_to_exclude
}
m = Hasher()
m.update(hash)
m.update(params_to_add_to_hash)
return m.hexdigest()


def create_builder_configs_from_metadata_configs(
module_path: str,
metadata_configs: MetadataConfigs,
supports_metadata: bool,
base_path: Optional[str] = None,
default_builder_kwargs: Dict[str, Any] = None,
download_config: Optional[DownloadConfig] = None,
) -> Tuple[List[BuilderConfig], str]:
builder_cls = import_main_class(module_path)
builder_config_cls = builder_cls.BUILDER_CONFIG_CLASS
Expand All @@ -622,18 +607,16 @@ def create_builder_configs_from_metadata_configs(
for config_name, config_params in metadata_configs.items():
config_data_files = config_params.get("data_files")
config_data_dir = config_params.get("data_dir")
config_base_path = base_path + "/" + config_data_dir if config_data_dir else base_path
config_base_path = xjoin(base_path, config_data_dir) if config_data_dir else base_path
try:
config_patterns = (
sanitize_patterns(config_data_files)
if config_data_files is not None
else get_data_patterns(config_base_path)
)
config_data_files_dict = DataFilesDict.from_patterns(
config_data_files_dict = DataFilesPatternsDict.from_patterns(
config_patterns,
base_path=config_base_path,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
download_config=download_config,
)
except EmptyDatasetError as e:
raise EmptyDatasetError(
Expand All @@ -646,16 +629,13 @@ def create_builder_configs_from_metadata_configs(
except FileNotFoundError:
config_metadata_patterns = None
if config_metadata_patterns is not None:
config_metadata_data_files_list = DataFilesList.from_patterns(
config_metadata_patterns, base_path=base_path
config_metadata_data_files_list = DataFilesPatternsList.from_patterns(config_metadata_patterns)
config_data_files_dict = DataFilesPatternsDict(
{
split: data_files_list + config_metadata_data_files_list
for split, data_files_list in config_data_files_dict.items()
}
)
if config_metadata_data_files_list:
config_data_files_dict = DataFilesDict(
{
split: data_files_list + config_metadata_data_files_list
for split, data_files_list in config_data_files_dict.items()
}
)
ignored_params = [
param for param in config_params if not hasattr(builder_config_cls, param) and param != "default"
]
Expand Down Expand Up @@ -1260,13 +1240,12 @@ def get_module(self) -> DatasetModule:
base_path=base_path,
supports_metadata=supports_metadata,
default_builder_kwargs=default_builder_kwargs,
download_config=self.download_config,
)
else:
builder_configs, default_config_name = None, None
builder_kwargs = {
"hash": hash,
"base_path": hf_hub_url(self.name, "", revision=self.revision),
"base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
"repo_id": self.name,
"dataset_name": camelcase_to_snakecase(Path(self.name).name),
}
Expand All @@ -1279,7 +1258,7 @@ def get_module(self) -> DatasetModule:
try:
# this file is deprecated and was created automatically in old versions of push_to_hub
dataset_infos_path = cached_path(
hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision),
hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=revision),
download_config=download_config,
)
with open(dataset_infos_path, encoding="utf-8") as f:
Expand Down Expand Up @@ -1349,7 +1328,6 @@ def get_module(self) -> DatasetModule:
module_path,
metadata_configs,
supports_metadata=False,
download_config=self.download_config,
)
builder_kwargs = {
"hash": hash,
Expand Down Expand Up @@ -1496,7 +1474,7 @@ def get_module(self) -> DatasetModule:
importlib.invalidate_caches()
builder_kwargs = {
"hash": hash,
"base_path": hf_hub_url(self.name, "", revision=self.revision),
"base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
"repo_id": self.name,
}
return DatasetModule(module_path, hash, builder_kwargs)
Expand Down Expand Up @@ -2184,13 +2162,6 @@ def load_dataset_builder(
dataset_name = builder_kwargs.pop("dataset_name", None)
hash = builder_kwargs.pop("hash")
info = dataset_module.dataset_infos.get(config_name) if dataset_module.dataset_infos else None
if (
dataset_module.builder_configs_parameters.metadata_configs
and config_name in dataset_module.builder_configs_parameters.metadata_configs
):
hash = update_hash_with_config_parameters(
hash, dataset_module.builder_configs_parameters.metadata_configs[config_name]
)

if path in _PACKAGED_DATASETS_MODULES and data_files is None:
error_msg = f"Please specify the data files or data directory to load for the {path} dataset builder."
Expand Down
27 changes: 27 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from datasets.features import Features, Value
from datasets.info import DatasetInfo, PostProcessedInfo
from datasets.iterable_dataset import IterableDataset
from datasets.load import configure_builder_class
from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo
from datasets.streaming import xjoin
from datasets.utils.file_utils import is_local_path
Expand Down Expand Up @@ -830,6 +831,32 @@ def test_cache_dir_for_data_dir(self):
other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir)
self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)

def test_cache_dir_for_configured_builder(self):
    # Two builder classes configured from equal BuilderConfigs (same data_dir)
    # must resolve to the same cache directory, while changing a config
    # parameter (data_dir below) must produce a different one.
    with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir:
        builder_cls = configure_builder_class(
            DummyBuilderWithManualDownload,
            builder_configs=[BuilderConfig(data_dir=data_dir)],
            default_config_name=None,
            dataset_name="dummy",
        )
        other_builder_cls = configure_builder_class(
            DummyBuilderWithManualDownload,
            builder_configs=[BuilderConfig(data_dir=data_dir)],
            default_config_name=None,
            dataset_name="dummy",
        )
        # Identical config parameters and identical hash -> shared cache dir.
        builder = builder_cls(cache_dir=tmp_dir, hash="abc")
        other_builder = other_builder_cls(cache_dir=tmp_dir, hash="abc")
        self.assertEqual(builder.cache_dir, other_builder.cache_dir)
        other_builder_cls = configure_builder_class(
            DummyBuilderWithManualDownload,
            builder_configs=[BuilderConfig(data_dir=tmp_dir)],
            default_config_name=None,
            dataset_name="dummy",
        )
        # Different data_dir -> the cache dirs must diverge even with the same hash.
        other_builder = other_builder_cls(cache_dir=tmp_dir, hash="abc")
        self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)


def test_arrow_based_download_and_prepare(tmp_path):
builder = DummyArrowBasedBuilder(cache_dir=tmp_path)
Expand Down
Loading