Skip to content
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
51002cb
lazy data files resolution
lhoestq Nov 29, 2023
5a5bb38
fix tests
lhoestq Nov 29, 2023
214a3e6
minor
lhoestq Nov 29, 2023
b7a9674
don't use expand_info=False yet
lhoestq Nov 29, 2023
32e0960
fix
lhoestq Nov 29, 2023
d3c0694
retrieve cached datasets that were pushed to hub
lhoestq Nov 29, 2023
8214ff2
minor
lhoestq Nov 29, 2023
68099ca
style
lhoestq Nov 29, 2023
5dd4698
Merge branch 'main' into lazy-data_files_resolution
lhoestq Nov 30, 2023
cf86d48
Merge branch 'main' into lazy-data_files_resolution
lhoestq Nov 30, 2023
b3fc428
tests
lhoestq Nov 30, 2023
796a47e
fix win test
lhoestq Nov 30, 2023
21209d2
Merge branch 'main' into lazy-data_files_resolution
lhoestq Dec 1, 2023
a924e94
fix tests
lhoestq Dec 1, 2023
adc07dd
Merge branch 'main' into lazy-data_files_resolution
lhoestq Dec 4, 2023
c9ecfca
fix tests again
lhoestq Dec 4, 2023
ddb488b
remove unused code
lhoestq Dec 4, 2023
50cec07
Merge branch 'main' into retrieve-cached-no-script-datasets
lhoestq Dec 6, 2023
a15ea9d
allow load from cache in streaming mode
lhoestq Dec 6, 2023
809b5cc
remove comment
lhoestq Dec 6, 2023
90f3d34
more tests
lhoestq Dec 7, 2023
1e3074a
fix tests
lhoestq Dec 7, 2023
5cc79dd
fix more tests
lhoestq Dec 7, 2023
ba892e1
fix tests
lhoestq Dec 7, 2023
b9d3b0a
fix tests
lhoestq Dec 8, 2023
565c294
fix cache on config change
lhoestq Dec 11, 2023
e01938d
simpler
lhoestq Dec 11, 2023
09516db
fix tests
lhoestq Dec 11, 2023
c77fb22
Merge branch 'retrieve-cached-no-script-datasets' into test-trm
lhoestq Dec 12, 2023
056b34b
Merge branch 'lazy-data_files_resolution' into test-trm
lhoestq Dec 12, 2023
46ae73d
make both PRs compatible
lhoestq Dec 12, 2023
64af455
style
lhoestq Dec 12, 2023
0050b20
fix tests
lhoestq Dec 12, 2023
c141bfa
fix tests
lhoestq Dec 12, 2023
70494f1
fix tests
lhoestq Dec 12, 2023
81b3ccf
fix test
lhoestq Dec 13, 2023
4608447
update hash when loading from parquet export too
lhoestq Dec 13, 2023
c3ddbd0
fix modify files
lhoestq Dec 13, 2023
5c9a6a2
fix base_path
lhoestq Dec 13, 2023
579c785
just use the commit sha as hash
lhoestq Dec 13, 2023
81d161f
use commit sha in parquet export dataset cache directories too
lhoestq Dec 13, 2023
e776c88
use version from parquet export dataset info
lhoestq Dec 14, 2023
d430de8
Merge branch 'main' into lazy-resolve-and-cache-reload
lhoestq Dec 15, 2023
6e4df3d
Merge branch 'main' into lazy-resolve-and-cache-reload
lhoestq Dec 19, 2023
2b88ad4
fix cache reload when config name and version are not the default ones
lhoestq Dec 19, 2023
0762ddb
fix tests
lhoestq Dec 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@
ReadInstruction,
)
from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter, SchemaInferenceError
from .data_files import DataFilesDict, sanitize_patterns
from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns
from .dataset_dict import DatasetDict, IterableDatasetDict
from .download.download_config import DownloadConfig
from .download.download_manager import DownloadManager, DownloadMode
from .download.mock_download_manager import MockDownloadManager
from .download.streaming_download_manager import StreamingDownloadManager, xopen
from .download.streaming_download_manager import StreamingDownloadManager, xjoin, xopen
from .features import Features
from .filesystems import (
is_remote_filesystem,
Expand Down Expand Up @@ -128,7 +128,7 @@ class BuilderConfig:
name: str = "default"
version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0")
data_dir: Optional[str] = None
data_files: Optional[DataFilesDict] = None
data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None
description: Optional[str] = None

def __post_init__(self):
Expand All @@ -139,7 +139,7 @@ def __post_init__(self):
f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. "
f"They could create issues when creating a directory for this config on Windows filesystem."
)
if self.data_files is not None and not isinstance(self.data_files, DataFilesDict):
if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)):
raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}")

def __eq__(self, o):
Expand Down Expand Up @@ -213,6 +213,11 @@ def create_config_id(
else:
return self.name

def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
    """Resolve lazily-stored data file patterns into concrete data files, in place.

    No-op unless `self.data_files` is a `DataFilesPatternsDict`; in that case the
    patterns are resolved against `base_path` (extended with `self.data_dir` when set)
    and `self.data_files` is replaced with the resulting `DataFilesDict`.
    """
    if not isinstance(self.data_files, DataFilesPatternsDict):
        return
    # data_dir narrows the search root inside the repository / local directory.
    effective_base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path
    self.data_files = self.data_files.resolve(effective_base_path, download_config)


class DatasetBuilder:
"""Abstract base class for all datasets.
Expand Down Expand Up @@ -517,7 +522,7 @@ def _create_builder_config(
builder_config = None

# try default config
if config_name is None and self.BUILDER_CONFIGS and not config_kwargs:
if config_name is None and self.BUILDER_CONFIGS:
if self.DEFAULT_CONFIG_NAME is not None:
builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
Expand Down Expand Up @@ -555,7 +560,7 @@ def _create_builder_config(

# otherwise use the config_kwargs to overwrite the attributes
else:
builder_config = copy.deepcopy(builder_config)
builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
for key, value in config_kwargs.items():
if value is not None:
if not hasattr(builder_config, key):
Expand All @@ -565,6 +570,12 @@ def _create_builder_config(
if not builder_config.name:
raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")

# resolve data files if needed
builder_config._resolve_data_files(
base_path=self.base_path,
download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
)

# compute the config id that is going to be used for caching
config_id = builder_config.create_config_id(
config_kwargs,
Expand All @@ -590,7 +601,7 @@ def _create_builder_config(
@classproperty
@classmethod
@memoize()
def builder_configs(cls):
def builder_configs(cls) -> Dict[str, BuilderConfig]:
"""Dictionary of pre-defined configurations for this builder class."""
configs = {config.name: config for config in cls.BUILDER_CONFIGS}
if len(configs) != len(cls.BUILDER_CONFIGS):
Expand Down
91 changes: 91 additions & 0 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,3 +698,94 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
for key, data_files_list in self.items():
out[key] = data_files_list.filter_extensions(extensions)
return out


class DataFilesPatternsList(List[str]):
    """
    List of data files patterns (absolute local paths or URLs).
    For each pattern there should also be a list of allowed extensions
    to keep, or None to keep all the files for the pattern.
    """

    def __init__(
        self,
        patterns: List[str],
        allowed_extensions: List[Optional[List[str]]],
    ):
        super().__init__(patterns)
        # allowed_extensions[i] applies to patterns[i]; None means "keep every file".
        self.allowed_extensions = allowed_extensions

    def __add__(self, other: "DataFilesPatternsList") -> "DataFilesPatternsList":
        # Fix: concatenating two patterns lists must yield a patterns list.
        # The previous code returned a DataFilesList, passing allowed_extensions
        # where that class expects origin metadata.
        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)

    @classmethod
    def from_patterns(
        cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
    ) -> "DataFilesPatternsList":
        """Build a patterns list where every pattern shares the same allowed extensions."""
        return cls(patterns, [allowed_extensions] * len(patterns))

    def resolve(
        self,
        base_path: str,
        download_config: Optional["DownloadConfig"] = None,
    ) -> "DataFilesList":
        """Resolve the patterns against `base_path` into a concrete `DataFilesList`.

        Raises:
            FileNotFoundError: only for literal (non-glob) patterns that match nothing;
                unmatched glob patterns are silently skipped.
        """
        base_path = base_path if base_path is not None else Path().resolve().as_posix()
        data_files = []
        for pattern, allowed_extensions in zip(self, self.allowed_extensions):
            try:
                data_files.extend(
                    resolve_pattern(
                        pattern,
                        base_path=base_path,
                        allowed_extensions=allowed_extensions,
                        download_config=download_config,
                    )
                )
            except FileNotFoundError:
                if not has_magic(pattern):
                    raise
        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
        return DataFilesList(data_files, origin_metadata)

    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
        """Return a new patterns list whose allowed extensions are extended with `extensions`.

        Fix: a keep-all entry (`None`) previously raised `TypeError` on `None + extensions`;
        it now becomes restricted to exactly `extensions`.
        """
        return DataFilesPatternsList(
            self,
            [
                extensions if allowed_extensions is None else allowed_extensions + extensions
                for allowed_extensions in self.allowed_extensions
            ],
        )


class DataFilesPatternsDict(Dict[str, DataFilesPatternsList]):
    """
    Dict of split_name -> list of data files patterns (absolute local paths or URLs).
    """

    @classmethod
    def from_patterns(
        cls, patterns: Dict[str, List[str]], allowed_extensions: Optional[List[str]] = None
    ) -> "DataFilesPatternsDict":
        """Build a per-split patterns dict, wrapping raw pattern lists as `DataFilesPatternsList`."""
        out = cls()
        for split_name, split_patterns in patterns.items():
            if isinstance(split_patterns, DataFilesPatternsList):
                # Already wrapped: keep as-is (preserves its allowed_extensions).
                out[split_name] = split_patterns
            else:
                out[split_name] = DataFilesPatternsList.from_patterns(
                    split_patterns,
                    allowed_extensions=allowed_extensions,
                )
        return out

    def resolve(
        self,
        base_path: str,
        download_config: Optional["DownloadConfig"] = None,
    ) -> "DataFilesDict":
        """Resolve every split's patterns into a `DataFilesDict` of concrete data files."""
        resolved = DataFilesDict()
        for split_name, split_patterns in self.items():
            resolved[split_name] = split_patterns.resolve(base_path, download_config)
        return resolved

    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsDict":
        """Apply `filter_extensions` to each split and return a new dict of the same type."""
        filtered = type(self)()
        for split_name, split_patterns in self.items():
            filtered[split_name] = split_patterns.filter_extensions(extensions)
        return filtered
Loading