diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 93ce534a928..0f758258e82 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -46,12 +46,12 @@ ReadInstruction, ) from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter, SchemaInferenceError -from .data_files import DataFilesDict, sanitize_patterns +from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns from .dataset_dict import DatasetDict, IterableDatasetDict from .download.download_config import DownloadConfig from .download.download_manager import DownloadManager, DownloadMode from .download.mock_download_manager import MockDownloadManager -from .download.streaming_download_manager import StreamingDownloadManager, xopen +from .download.streaming_download_manager import StreamingDownloadManager, xjoin, xopen from .exceptions import DatasetGenerationCastError, DatasetGenerationError, FileFormatError, ManualDownloadError from .features import Features from .filesystems import ( @@ -115,7 +115,7 @@ class BuilderConfig: name: str = "default" version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0") data_dir: Optional[str] = None - data_files: Optional[DataFilesDict] = None + data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None description: Optional[str] = None def __post_init__(self): @@ -126,7 +126,7 @@ def __post_init__(self): f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. " f"They could create issues when creating a directory for this config on Windows filesystem." ) - if self.data_files is not None and not isinstance(self.data_files, DataFilesDict): + if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)): raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}") def __eq__(self, o): @@ -200,6 +200,11 @@ def create_config_id( else: return self.name + def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None: + if isinstance(self.data_files, DataFilesPatternsDict): + base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path + self.data_files = self.data_files.resolve(base_path, download_config) + class DatasetBuilder: """Abstract base class for all datasets. 
@@ -504,7 +509,7 @@ def _create_builder_config(
         builder_config = None
 
         # try default config
-        if config_name is None and self.BUILDER_CONFIGS and not config_kwargs:
+        if config_name is None and self.BUILDER_CONFIGS:
             if self.DEFAULT_CONFIG_NAME is not None:
                 builder_config = self.builder_configs.get(self.DEFAULT_CONFIG_NAME)
                 logger.info(f"No config specified, defaulting to: {self.dataset_name}/{builder_config.name}")
@@ -542,7 +547,7 @@ def _create_builder_config(
 
         # otherwise use the config_kwargs to overwrite the attributes
         else:
-            builder_config = copy.deepcopy(builder_config)
+            builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
             for key, value in config_kwargs.items():
                 if value is not None:
                     if not hasattr(builder_config, key):
@@ -552,6 +557,12 @@ def _create_builder_config(
         if not builder_config.name:
             raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")
 
+        # resolve data files if needed
+        builder_config._resolve_data_files(
+            base_path=self.base_path,
+            download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
+        )
+
         # compute the config id that is going to be used for caching
         config_id = builder_config.create_config_id(
             config_kwargs,
@@ -577,7 +588,7 @@ def _create_builder_config(
     @classproperty
     @classmethod
     @memoize()
-    def builder_configs(cls):
+    def builder_configs(cls) -> Dict[str, BuilderConfig]:
         """Dictionary of pre-defined configurations for this builder class."""
         configs = {config.name: config for config in cls.BUILDER_CONFIGS}
         if len(configs) != len(cls.BUILDER_CONFIGS):
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index 7999aa0a4dc..fa8c814de0d 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -702,3 +702,94 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
         for key, data_files_list in self.items():
             out[key] = data_files_list.filter_extensions(extensions)
         return out
+
+
+class DataFilesPatternsList(List[str]):
+    """
+    List of data files patterns (absolute local paths or URLs).
+    For each pattern there should also be a list of allowed extensions
+    to keep, or None to keep all the files for the pattern.
+ """ + + def __init__( + self, + patterns: List[str], + allowed_extensions: List[Optional[List[str]]], + ): + super().__init__(patterns) + self.allowed_extensions = allowed_extensions + + def __add__(self, other): + return DataFilesList([*self, *other], self.allowed_extensions + other.allowed_extensions) + + @classmethod + def from_patterns( + cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None + ) -> "DataFilesPatternsDict": + return cls(patterns, [allowed_extensions] * len(patterns)) + + def resolve( + self, + base_path: str, + download_config: Optional[DownloadConfig] = None, + ) -> "DataFilesList": + base_path = base_path if base_path is not None else Path().resolve().as_posix() + data_files = [] + for pattern, allowed_extensions in zip(self, self.allowed_extensions): + try: + data_files.extend( + resolve_pattern( + pattern, + base_path=base_path, + allowed_extensions=allowed_extensions, + download_config=download_config, + ) + ) + except FileNotFoundError: + if not has_magic(pattern): + raise + origin_metadata = _get_origin_metadata(data_files, download_config=download_config) + return DataFilesList(data_files, origin_metadata) + + def filter_extensions(self, extensions: List[str]) -> "DataFilesList": + return DataFilesPatternsList( + self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions] + ) + + +class DataFilesPatternsDict(Dict[str, DataFilesPatternsList]): + """ + Dict of split_name -> list of data files patterns (absolute local paths or URLs). + """ + + @classmethod + def from_patterns( + cls, patterns: Dict[str, List[str]], allowed_extensions: Optional[List[str]] = None + ) -> "DataFilesPatternsDict": + out = cls() + for key, patterns_for_key in patterns.items(): + out[key] = ( + DataFilesPatternsList.from_patterns( + patterns_for_key, + allowed_extensions=allowed_extensions, + ) + if not isinstance(patterns_for_key, DataFilesPatternsList) + else patterns_for_key + ) + return out + + def resolve( + self, + base_path: str, + download_config: Optional[DownloadConfig] = None, + ) -> "DataFilesDict": + out = DataFilesDict() + for key, data_files_patterns_list in self.items(): + out[key] = data_files_patterns_list.resolve(base_path, download_config) + return out + + def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsDict": + out = type(self)() + for key, data_files_patterns_list in self.items(): + out[key] = data_files_patterns_list.filter_extensions(extensions) + return out diff --git a/src/datasets/load.py b/src/datasets/load.py index a0c5848176f..75112b8bc82 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -15,6 +15,7 @@ # Lint as: python3 """Access datasets.""" import filecmp +import glob import importlib import inspect import json @@ -40,6 +41,8 @@ DEFAULT_PATTERNS_ALL, DataFilesDict, DataFilesList, + DataFilesPatternsDict, + DataFilesPatternsList, EmptyDatasetError, get_data_patterns, get_metadata_patterns, @@ -589,28 +592,12 @@ def infer_module_for_data_files( return module_name, default_builder_kwargs -def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str: - """ - Used to update hash of packaged modules which is used for creating unique cache directories to reflect - different config parameters which are passed in metadata from readme. 
- """ - params_to_exclude = {"config_name", "version", "description"} - params_to_add_to_hash = { - param: value for param, value in sorted(config_parameters.items()) if param not in params_to_exclude - } - m = Hasher() - m.update(hash) - m.update(params_to_add_to_hash) - return m.hexdigest() - - def create_builder_configs_from_metadata_configs( module_path: str, metadata_configs: MetadataConfigs, supports_metadata: bool, base_path: Optional[str] = None, default_builder_kwargs: Dict[str, Any] = None, - download_config: Optional[DownloadConfig] = None, ) -> Tuple[List[BuilderConfig], str]: builder_cls = import_main_class(module_path) builder_config_cls = builder_cls.BUILDER_CONFIG_CLASS @@ -622,18 +609,16 @@ def create_builder_configs_from_metadata_configs( for config_name, config_params in metadata_configs.items(): config_data_files = config_params.get("data_files") config_data_dir = config_params.get("data_dir") - config_base_path = base_path + "/" + config_data_dir if config_data_dir else base_path + config_base_path = xjoin(base_path, config_data_dir) if config_data_dir else base_path try: config_patterns = ( sanitize_patterns(config_data_files) if config_data_files is not None else get_data_patterns(config_base_path) ) - config_data_files_dict = DataFilesDict.from_patterns( + config_data_files_dict = DataFilesPatternsDict.from_patterns( config_patterns, - base_path=config_base_path, allowed_extensions=ALL_ALLOWED_EXTENSIONS, - download_config=download_config, ) except EmptyDatasetError as e: raise EmptyDatasetError( @@ -646,16 +631,13 @@ def create_builder_configs_from_metadata_configs( except FileNotFoundError: config_metadata_patterns = None if config_metadata_patterns is not None: - config_metadata_data_files_list = DataFilesList.from_patterns( - config_metadata_patterns, base_path=base_path + config_metadata_data_files_list = DataFilesPatternsList.from_patterns(config_metadata_patterns) + config_data_files_dict = DataFilesPatternsDict( + { + split: data_files_list + config_metadata_data_files_list + for split, data_files_list in config_data_files_dict.items() + } ) - if config_metadata_data_files_list: - config_data_files_dict = DataFilesDict( - { - split: data_files_list + config_metadata_data_files_list - for split, data_files_list in config_data_files_dict.items() - } - ) ignored_params = [ param for param in config_params if not hasattr(builder_config_cls, param) and param != "default" ] @@ -995,7 +977,7 @@ def get_module(self) -> DatasetModule: # make the new module to be noticed by the import system importlib.invalidate_caches() - builder_kwargs = {"hash": hash, "base_path": str(Path(self.path).parent)} + builder_kwargs = {"base_path": str(Path(self.path).parent)} return DatasetModule(module_path, hash, builder_kwargs) @@ -1060,7 +1042,7 @@ def get_module(self) -> DatasetModule: } ) - module_path, hash = _PACKAGED_DATASETS_MODULES[module_name] + module_path, _ = _PACKAGED_DATASETS_MODULES[module_name] if metadata_configs: builder_configs, default_config_name = create_builder_configs_from_metadata_configs( module_path, @@ -1070,15 +1052,17 @@ def get_module(self) -> DatasetModule: default_builder_kwargs=default_builder_kwargs, ) else: - builder_configs, default_config_name = None, None + builder_configs: List[BuilderConfig] = [ + import_main_class(module_path).BUILDER_CONFIG_CLASS( + data_files=data_files, + **default_builder_kwargs, + ) + ] + default_config_name = None builder_kwargs = { - "hash": hash, "base_path": self.path, "dataset_name": 
camelcase_to_snakecase(Path(self.path).name), } - if self.data_files is not None or not metadata_configs: - builder_kwargs["data_files"] = data_files - builder_kwargs.update(default_builder_kwargs) # from _EXTENSION_TO_MODULE # this file is deprecated and was created automatically in old versions of push_to_hub if os.path.isfile(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME)): with open(os.path.join(self.path, config.DATASETDICT_INFOS_FILENAME), encoding="utf-8") as f: @@ -1097,6 +1081,7 @@ def get_module(self) -> DatasetModule: if default_config_name is None and len(dataset_infos) == 1: default_config_name = next(iter(dataset_infos)) + hash = Hasher.hash({"dataset_infos": dataset_infos, "builder_configs": builder_configs}) return DatasetModule( module_path, hash, @@ -1157,7 +1142,6 @@ def get_module(self) -> DatasetModule: module_path, hash = _PACKAGED_DATASETS_MODULES[self.name] builder_kwargs = { - "hash": hash, "data_files": data_files, "dataset_name": self.name, } @@ -1252,7 +1236,7 @@ def get_module(self) -> DatasetModule: } ) - module_path, hash = _PACKAGED_DATASETS_MODULES[module_name] + module_path, _ = _PACKAGED_DATASETS_MODULES[module_name] if metadata_configs: builder_configs, default_config_name = create_builder_configs_from_metadata_configs( module_path, @@ -1260,26 +1244,27 @@ def get_module(self) -> DatasetModule: base_path=base_path, supports_metadata=supports_metadata, default_builder_kwargs=default_builder_kwargs, - download_config=self.download_config, ) else: - builder_configs, default_config_name = None, None + builder_configs: List[BuilderConfig] = [ + import_main_class(module_path).BUILDER_CONFIG_CLASS( + data_files=data_files, + **default_builder_kwargs, + ) + ] + default_config_name = None builder_kwargs = { - "hash": hash, - "base_path": hf_hub_url(self.name, "", revision=self.revision), + "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"), "repo_id": self.name, "dataset_name": camelcase_to_snakecase(Path(self.name).name), } - if self.data_files is not None or not metadata_configs: - builder_kwargs["data_files"] = data_files - builder_kwargs.update(default_builder_kwargs) # from _EXTENSION_TO_MODULE download_config = self.download_config.copy() if download_config.download_desc is None: download_config.download_desc = "Downloading metadata" try: # this file is deprecated and was created automatically in old versions of push_to_hub dataset_infos_path = cached_path( - hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision), + hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=revision), download_config=download_config, ) with open(dataset_infos_path, encoding="utf-8") as f: @@ -1300,6 +1285,7 @@ def get_module(self) -> DatasetModule: if default_config_name is None and len(dataset_infos) == 1: default_config_name = next(iter(dataset_infos)) + hash = revision return DatasetModule( module_path, hash, @@ -1336,27 +1322,30 @@ def get_module(self) -> DatasetModule: exported_dataset_infos = _datasets_server.get_exported_dataset_infos( dataset=self.name, revision=self.revision, token=self.download_config.token ) + dataset_infos = DatasetInfosDict( + { + config_name: DatasetInfo.from_dict(exported_dataset_infos[config_name]) + for config_name in exported_dataset_infos + } + ) hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info( self.name, revision="refs/convert/parquet", token=self.download_config.token, timeout=100.0, ) - # even if metadata_configs is not None (which means that we will resolve 
files for each config later) - # we cannot skip resolving all files because we need to infer module name by files extensions revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime - metadata_configs = MetadataConfigs._from_exported_parquet_files( - revision=revision, exported_parquet_files=exported_parquet_files + metadata_configs = MetadataConfigs._from_exported_parquet_files_and_dataset_infos( + revision=revision, exported_parquet_files=exported_parquet_files, dataset_infos=dataset_infos ) - module_path, hash = _PACKAGED_DATASETS_MODULES["parquet"] + module_path, _ = _PACKAGED_DATASETS_MODULES["parquet"] builder_configs, default_config_name = create_builder_configs_from_metadata_configs( module_path, metadata_configs, supports_metadata=False, - download_config=self.download_config, ) + hash = self.revision builder_kwargs = { - "hash": hash, "repo_id": self.name, "dataset_name": camelcase_to_snakecase(Path(self.name).name), } @@ -1365,12 +1354,7 @@ def get_module(self) -> DatasetModule: module_path, hash, builder_kwargs, - dataset_infos=DatasetInfosDict( - { - config_name: DatasetInfo.from_dict(exported_dataset_infos[config_name]) - for config_name in exported_dataset_infos - } - ), + dataset_infos=dataset_infos, builder_configs_parameters=BuilderConfigsParameters( metadata_configs=metadata_configs, builder_configs=builder_configs, @@ -1499,8 +1483,7 @@ def get_module(self) -> DatasetModule: # make the new module to be noticed by the import system importlib.invalidate_caches() builder_kwargs = { - "hash": hash, - "base_path": hf_hub_url(self.name, "", revision=self.revision), + "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), "repo_id": self.name, } return DatasetModule(module_path, hash, builder_kwargs) @@ -1515,9 +1498,11 @@ class CachedDatasetModuleFactory(_DatasetModuleFactory): def __init__( self, name: str, + cache_dir: Optional[str] = None, dynamic_modules_path: Optional[str] = None, ): self.name = name + self.cache_dir = cache_dir self.dynamic_modules_path = dynamic_modules_path assert self.name.count("/") <= 1 @@ -1529,38 +1514,61 @@ def get_module(self) -> DatasetModule: if os.path.isdir(importable_directory_path) else None ) - if not hashes: - raise FileNotFoundError(f"Dataset {self.name} is not cached in {dynamic_modules_path}") - # get most recent - - def _get_modification_time(module_hash): - return (Path(importable_directory_path) / module_hash / (self.name.split("/")[-1] + ".py")).stat().st_mtime + if hashes: + # get most recent + def _get_modification_time(module_hash): + return ( + (Path(importable_directory_path) / module_hash / (self.name.split("/")[-1] + ".py")) + .stat() + .st_mtime + ) - hash = sorted(hashes, key=_get_modification_time)[-1] - warning_msg = ( - f"Using the latest cached version of the module from {os.path.join(importable_directory_path, hash)} " - f"(last modified on {time.ctime(_get_modification_time(hash))}) since it " - f"couldn't be found locally at {self.name}." - ) - if not config.HF_DATASETS_OFFLINE: - warning_msg += ", or remotely on the Hugging Face Hub." 
- logger.warning(warning_msg) - # make the new module to be noticed by the import system - module_path = ".".join( - [ - os.path.basename(dynamic_modules_path), - "datasets", - self.name.replace("/", "--"), - hash, - self.name.split("/")[-1], - ] - ) - importlib.invalidate_caches() - builder_kwargs = { - "hash": hash, - "repo_id": self.name, - } - return DatasetModule(module_path, hash, builder_kwargs) + hash = sorted(hashes, key=_get_modification_time)[-1] + warning_msg = ( + f"Using the latest cached version of the module from {os.path.join(importable_directory_path, hash)} " + f"(last modified on {time.ctime(_get_modification_time(hash))}) since it " + f"couldn't be found locally at {self.name}" + ) + if not config.HF_DATASETS_OFFLINE: + warning_msg += ", or remotely on the Hugging Face Hub." + logger.warning(warning_msg) + # make the new module to be noticed by the import system + module_path = ".".join( + [ + os.path.basename(dynamic_modules_path), + "datasets", + self.name.replace("/", "--"), + hash, + self.name.split("/")[-1], + ] + ) + importlib.invalidate_caches() + builder_kwargs = { + "repo_id": self.name, + } + return DatasetModule(module_path, hash, builder_kwargs) + cache_dir = os.path.expanduser(str(self.cache_dir or config.HF_DATASETS_CACHE)) + cached_datasets_directory_path_root = os.path.join(cache_dir, self.name.replace("/", "___")) + cached_directory_paths = [ + cached_directory_path + for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", "*", "*")) + if os.path.isdir(cached_directory_path) + ] + if cached_directory_paths: + builder_kwargs = { + "repo_id": self.name, + "dataset_name": self.name.split("/")[-1], + } + warning_msg = f"Using the latest cached version of the dataset since {self.name} couldn't be found on the Hugging Face Hub" + if config.HF_DATASETS_OFFLINE: + warning_msg += " (offline mode is enabled)." + logger.warning(warning_msg) + return DatasetModule( + "datasets.packaged_modules.cache.cache", + "auto", + {**builder_kwargs, "version": "auto"}, + ) + raise FileNotFoundError(f"Dataset {self.name} is not cached in {self.cache_dir}") class CachedMetricModuleFactory(_MetricModuleFactory): @@ -1620,6 +1628,7 @@ def dataset_module_factory( dynamic_modules_path: Optional[str] = None, data_dir: Optional[str] = None, data_files: Optional[Union[Dict, List, str, DataFilesDict]] = None, + cache_dir: Optional[str] = None, trust_remote_code: Optional[bool] = None, _require_default_config_name=True, **download_kwargs, @@ -1663,6 +1672,10 @@ def dataset_module_factory( data_dir (:obj:`str`, optional): Directory with the data files. Used only if `data_files` is not specified, in which case it's equal to pass `os.path.join(data_dir, "**")` as `data_files`. data_files (:obj:`Union[Dict, List, str]`, optional): Defining the data_files of the dataset configuration. + cache_dir (`str`, *optional*): + Directory to read/write data. Defaults to `"~/.cache/huggingface/datasets"`. + + trust_remote_code (`bool`, defaults to `True`): Whether or not to allow for datasets defined on the Hub using a dataset script. 
This option should only be set to `True` for repositories you trust and in which you have read the code, as it will @@ -1810,7 +1823,9 @@ def dataset_module_factory( except Exception as e1: # All the attempts failed, before raising the error we should check if the module is already cached try: - return CachedDatasetModuleFactory(path, dynamic_modules_path=dynamic_modules_path).get_module() + return CachedDatasetModuleFactory( + path, dynamic_modules_path=dynamic_modules_path, cache_dir=cache_dir + ).get_module() except Exception: # If it's not in the cache, then it doesn't exist. if isinstance(e1, OfflineModeIsEnabled): @@ -2178,6 +2193,7 @@ def load_dataset_builder( download_mode=download_mode, data_dir=data_dir, data_files=data_files, + cache_dir=cache_dir, trust_remote_code=trust_remote_code, _require_default_config_name=_require_default_config_name, ) @@ -2189,17 +2205,13 @@ def load_dataset_builder( "config_name", name or dataset_module.builder_configs_parameters.default_config_name ) dataset_name = builder_kwargs.pop("dataset_name", None) - hash = builder_kwargs.pop("hash") info = dataset_module.dataset_infos.get(config_name) if dataset_module.dataset_infos else None + if ( - dataset_module.builder_configs_parameters.metadata_configs - and config_name in dataset_module.builder_configs_parameters.metadata_configs + path in _PACKAGED_DATASETS_MODULES + and data_files is None + and dataset_module.builder_configs_parameters.builder_configs[0].data_files is None ): - hash = update_hash_with_config_parameters( - hash, dataset_module.builder_configs_parameters.metadata_configs[config_name] - ) - - if path in _PACKAGED_DATASETS_MODULES and data_files is None: error_msg = f"Please specify the data files or data directory to load for the {path} dataset builder." 
example_extensions = [ extension for extension in _EXTENSION_TO_MODULE if _EXTENSION_TO_MODULE[extension] == path @@ -2216,7 +2228,7 @@ def load_dataset_builder( config_name=config_name, data_dir=data_dir, data_files=data_files, - hash=hash, + hash=dataset_module.hash, info=info, features=features, token=token, diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index a9fe30d7f5d..91181c4f550 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -6,6 +6,7 @@ from .arrow import arrow from .audiofolder import audiofolder +from .cache import cache # noqa F401 from .csv import csv from .imagefolder import imagefolder from .json import json diff --git a/src/datasets/packaged_modules/cache/__init__.py b/src/datasets/packaged_modules/cache/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/cache/cache.py b/src/datasets/packaged_modules/cache/cache.py new file mode 100644 index 00000000000..a378df23165 --- /dev/null +++ b/src/datasets/packaged_modules/cache/cache.py @@ -0,0 +1,149 @@ +import glob +import os +import shutil +import time +from pathlib import Path +from typing import List, Optional, Tuple + +import pyarrow as pa + +import datasets +import datasets.config +from datasets.naming import filenames_for_dataset_split + + +logger = datasets.utils.logging.get_logger(__name__) + + +def _get_modification_time(cached_directory_path): + return (Path(cached_directory_path)).stat().st_mtime + + +def _find_hash_in_cache( + dataset_name: str, config_name: Optional[str], cache_dir: Optional[str] +) -> Tuple[str, str, str]: + cache_dir = os.path.expanduser(str(cache_dir or datasets.config.HF_DATASETS_CACHE)) + cached_datasets_directory_path_root = os.path.join(cache_dir, dataset_name.replace("/", "___")) + cached_directory_paths = [ + cached_directory_path + for cached_directory_path in glob.glob( + os.path.join(cached_datasets_directory_path_root, config_name or "*", "*", "*") + ) + if os.path.isdir(cached_directory_path) + ] + if not cached_directory_paths: + if config_name is not None: + cached_directory_paths = [ + cached_directory_path + for cached_directory_path in glob.glob( + os.path.join(cached_datasets_directory_path_root, "*", "*", "*") + ) + if os.path.isdir(cached_directory_path) + ] + available_configs = sorted( + {Path(cached_directory_path).parts[-3] for cached_directory_path in cached_directory_paths} + ) + raise ValueError( + f"Couldn't find cache for {dataset_name}" + + (f" for config '{config_name}'" if config_name else "") + + (f"\nAvailable configs in the cache: {available_configs}" if available_configs else "") + ) + # get most recent + cached_directory_path = Path(sorted(cached_directory_paths, key=_get_modification_time)[-1]) + version, hash = cached_directory_path.parts[-2:] + other_configs = [ + Path(cached_directory_path).parts[-3] + for cached_directory_path in glob.glob(os.path.join(cached_datasets_directory_path_root, "*", version, hash)) + if os.path.isdir(cached_directory_path) + ] + if not config_name and len(other_configs) > 1: + raise ValueError( + f"There are multiple '{dataset_name}' configurations in the cache: {', '.join(other_configs)}" + f"\nPlease specify which configuration to reload from the cache, e.g." 
+ f"\n\tload_dataset('{dataset_name}', '{other_configs[0]}')" + ) + config_name = cached_directory_path.parts[-3] + warning_msg = ( + f"Found the latest cached dataset configuration '{config_name}' at {cached_directory_path} " + f"(last modified on {time.ctime(_get_modification_time(cached_directory_path))})." + ) + logger.warning(warning_msg) + return config_name, version, hash + + +class Cache(datasets.ArrowBasedBuilder): + def __init__( + self, + cache_dir: Optional[str] = None, + dataset_name: Optional[str] = None, + config_name: Optional[str] = None, + version: Optional[str] = "0.0.0", + hash: Optional[str] = None, + repo_id: Optional[str] = None, + **kwargs, + ): + if repo_id is None and dataset_name is None: + raise ValueError("repo_id or dataset_name is required for the Cache dataset builder") + if hash == "auto" and version == "auto": + config_name, version, hash = _find_hash_in_cache( + dataset_name=repo_id or dataset_name, + config_name=config_name, + cache_dir=cache_dir, + ) + elif hash == "auto" or version == "auto": + raise NotImplementedError("Pass both hash='auto' and version='auto' instead") + super().__init__( + cache_dir=cache_dir, + dataset_name=dataset_name, + config_name=config_name, + version=version, + hash=hash, + repo_id=repo_id, + **kwargs, + ) + + def _info(self) -> datasets.DatasetInfo: + return datasets.DatasetInfo() + + def download_and_prepare(self, output_dir: Optional[str] = None, *args, **kwargs): + if not os.path.exists(self.cache_dir): + raise ValueError(f"Cache directory for {self.dataset_name} doesn't exist at {self.cache_dir}") + if output_dir is not None and output_dir != self.cache_dir: + shutil.copytree(self.cache_dir, output_dir) + + def _split_generators(self, dl_manager): + # used to stream from cache + if isinstance(self.info.splits, datasets.SplitDict): + split_infos: List[datasets.SplitInfo] = list(self.info.splits.values()) + else: + raise ValueError(f"Missing splits info for {self.dataset_name} in cache directory {self.cache_dir}") + return [ + datasets.SplitGenerator( + name=split_info.name, + gen_kwargs={ + "files": filenames_for_dataset_split( + self.cache_dir, + dataset_name=self.dataset_name, + split=split_info.name, + filetype_suffix="arrow", + shard_lengths=split_info.shard_lengths, + ) + }, + ) + for split_info in split_infos + ] + + def _generate_tables(self, files): + # used to stream from cache + for file_idx, file in enumerate(files): + with open(file, "rb") as f: + try: + for batch_idx, record_batch in enumerate(pa.ipc.open_stream(f)): + pa_table = pa.Table.from_batches([record_batch]) + # Uncomment for debugging (will print the Arrow table size and elements) + # logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}") + # logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows))) + yield f"{file_idx}_{batch_idx}", pa_table + except ValueError as e: + logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}") + raise diff --git a/src/datasets/utils/metadata.py b/src/datasets/utils/metadata.py index 44b18d97919..4d2d6b059c7 100644 --- a/src/datasets/utils/metadata.py +++ b/src/datasets/utils/metadata.py @@ -9,6 +9,7 @@ from huggingface_hub import DatasetCardData from ..config import METADATA_CONFIGS_FIELD +from ..info import DatasetInfo, DatasetInfosDict from ..utils.logging import get_logger from .deprecation_utils import deprecated @@ -172,8 +173,11 @@ def _raise_if_data_files_field_not_valid(metadata_config: dict): raise ValueError(yaml_error_message) 
@classmethod - def _from_exported_parquet_files( - cls, revision: str, exported_parquet_files: List[Dict[str, Any]] + def _from_exported_parquet_files_and_dataset_infos( + cls, + revision: str, + exported_parquet_files: List[Dict[str, Any]], + dataset_infos: DatasetInfosDict, ) -> "MetadataConfigs": return cls( { @@ -189,7 +193,8 @@ def _from_exported_parquet_files( for split_name, parquet_files_for_split in groupby( parquet_files_for_config, itemgetter("split") ) - ] + ], + "version": str(dataset_infos.get(config_name, DatasetInfo()).version or "0.0.0"), } for config_name, parquet_files_for_config in groupby(exported_parquet_files, itemgetter("config")) } diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index da8d3efee48..76bf830676d 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -471,6 +471,16 @@ def text2_path(tmp_path_factory): return path +@pytest.fixture(scope="session") +def text_dir(tmp_path_factory): + data = ["0", "1", "2", "3"] + path = tmp_path_factory.mktemp("data_text_dir") / "dataset.txt" + with open(path, "w") as f: + for item in data: + f.write(item + "\n") + return path.parent + + @pytest.fixture(scope="session") def text_dir_with_unsupported_extension(tmp_path_factory): data = ["0", "1", "2", "3"] diff --git a/tests/packaged_modules/test_cache.py b/tests/packaged_modules/test_cache.py new file mode 100644 index 00000000000..8d37fb0a2d9 --- /dev/null +++ b/tests/packaged_modules/test_cache.py @@ -0,0 +1,61 @@ +from pathlib import Path + +import pytest + +from datasets import load_dataset +from datasets.packaged_modules.cache.cache import Cache + + +SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata" + + +def test_cache(text_dir: Path): + ds = load_dataset(str(text_dir)) + hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2] + cache = Cache(dataset_name=text_dir.name, hash=hash) + reloaded = cache.as_dataset() + assert list(ds) == list(reloaded) + assert list(ds["train"]) == list(reloaded["train"]) + + +def test_cache_streaming(text_dir: Path): + ds = load_dataset(str(text_dir)) + hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2] + cache = Cache(dataset_name=text_dir.name, hash=hash) + reloaded = cache.as_streaming_dataset() + assert list(ds) == list(reloaded) + assert list(ds["train"]) == list(reloaded["train"]) + + +def test_cache_auto_hash(text_dir: Path): + ds = load_dataset(str(text_dir)) + cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto") + reloaded = cache.as_dataset() + assert list(ds) == list(reloaded) + assert list(ds["train"]) == list(reloaded["train"]) + + +def test_cache_missing(text_dir: Path): + load_dataset(str(text_dir)) + Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare() + with pytest.raises(ValueError): + Cache(dataset_name="missing", version="auto", hash="auto").download_and_prepare() + with pytest.raises(ValueError): + Cache(dataset_name=text_dir.name, hash="missing").download_and_prepare() + with pytest.raises(ValueError): + Cache(dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto").download_and_prepare() + + +@pytest.mark.integration +def test_cache_multi_configs(): + repo_id = SAMPLE_DATASET_TWO_CONFIG_IN_METADATA + dataset_name = repo_id.split("/")[-1] + config_name = "v1" + ds = load_dataset(repo_id, config_name) + cache = Cache(dataset_name=dataset_name, repo_id=repo_id, config_name=config_name, version="auto", hash="auto") + reloaded = cache.as_dataset() + 
assert list(ds) == list(reloaded) + assert len(ds["train"]) == len(reloaded["train"]) + with pytest.raises(ValueError) as excinfo: + Cache(dataset_name=dataset_name, repo_id=repo_id, config_name="missing", version="auto", hash="auto") + assert config_name in str(excinfo.value) diff --git a/tests/test_builder.py b/tests/test_builder.py index 54bae47ae08..3722a728147 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -23,6 +23,7 @@ from datasets.features import Features, Value from datasets.info import DatasetInfo, PostProcessedInfo from datasets.iterable_dataset import IterableDataset +from datasets.load import configure_builder_class from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo from datasets.streaming import xjoin from datasets.utils.file_utils import is_local_path @@ -830,6 +831,20 @@ def test_cache_dir_for_data_dir(self): other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir) self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) + def test_cache_dir_for_configured_builder(self): + with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir: + builder_cls = configure_builder_class( + DummyBuilderWithManualDownload, + builder_configs=[BuilderConfig(data_dir=data_dir)], + default_config_name=None, + dataset_name="dummy", + ) + builder = builder_cls(cache_dir=tmp_dir, hash="abc") + other_builder = builder_cls(cache_dir=tmp_dir, hash="abc") + self.assertEqual(builder.cache_dir, other_builder.cache_dir) + other_builder = builder_cls(cache_dir=tmp_dir, hash="def") + self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) + def test_arrow_based_download_and_prepare(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) diff --git a/tests/test_data_files.py b/tests/test_data_files.py index 01b5e4dd15e..9ac8945c931 100644 --- a/tests/test_data_files.py +++ b/tests/test_data_files.py @@ -12,6 +12,8 @@ from datasets.data_files import ( DataFilesDict, DataFilesList, + DataFilesPatternsDict, + DataFilesPatternsList, _get_data_files_patterns, _get_metadata_files_patterns, _is_inside_unrequested_special_dir, @@ -487,6 +489,37 @@ def test_DataFilesDict_from_patterns_locally_or_remote_hashing(text_file): assert Hasher.hash(data_files1) != Hasher.hash(data_files2) +def test_DataFilesPatternsList(text_file): + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [text_file.as_posix()] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".txt"]]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [text_file.as_posix()] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([str(text_file).replace(".txt", ".tx*")], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [text_file.as_posix()] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([Path(text_file).name], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path=str(Path(text_file).parent)) + assert data_files == [text_file.as_posix()] + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".zip"]]) + with pytest.raises(FileNotFoundError): + data_files_patterns.resolve(base_path="") + + +def 
test_DataFilesPatternsDict(text_file): + data_files_patterns_dict = DataFilesPatternsDict( + {"train": DataFilesPatternsList([str(text_file)], allowed_extensions=[None])} + ) + data_files_dict = data_files_patterns_dict.resolve(base_path="") + assert data_files_dict == {"train": [text_file.as_posix()]} + assert isinstance(data_files_dict, DataFilesDict) + assert isinstance(data_files_dict["train"], DataFilesList) + + def mock_fs(file_paths: List[str]): """ Set up a mock filesystem for fsspec containing the provided files diff --git a/tests/test_load.py b/tests/test_load.py index 66321ef8c47..acc9a2680a5 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -21,7 +21,7 @@ from datasets.arrow_writer import ArrowWriter from datasets.builder import DatasetBuilder from datasets.config import METADATA_CONFIGS_FIELD -from datasets.data_files import DataFilesDict +from datasets.data_files import DataFilesDict, DataFilesPatternsDict from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_config import DownloadConfig from datasets.exceptions import DatasetNotFoundError @@ -469,34 +469,29 @@ def test_LocalDatasetModuleFactoryWithoutScript_with_data_dir(self): factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir2, data_dir=self._sub_data_dir) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + builder_config = module_factory_result.builder_configs_parameters.builder_configs[0] assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) == 1 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) == 1 + builder_config.data_files is not None + and len(builder_config.data_files["train"]) == 1 + and len(builder_config.data_files["test"]) == 1 ) assert all( self._sub_data_dir in Path(data_file).parts - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] - + module_factory_result.builder_kwargs["data_files"]["test"] + for data_file in builder_config.data_files["train"] + builder_config.data_files["test"] ) def test_LocalDatasetModuleFactoryWithoutScript_with_metadata(self): factory = LocalDatasetModuleFactoryWithoutScript(self._data_dir_with_metadata) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + builder_config = module_factory_result.builder_configs_parameters.builder_configs[0] assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) > 0 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) > 0 - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["test"] + builder_config.data_files is not None + and len(builder_config.data_files["train"]) > 0 + and len(builder_config.data_files["test"]) > 0 ) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"]) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["test"]) def test_LocalDatasetModuleFactoryWithoutScript_with_single_config_in_metadata(self): factory = LocalDatasetModuleFactoryWithoutScript( @@ -518,6 +513,8 @@ def 
test_LocalDatasetModuleFactoryWithoutScript_with_single_config_in_metadata(s assert isinstance(module_builder_configs[0], ImageFolderConfig) assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None + assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) + module_builder_configs[0]._resolve_data_files(self._data_dir_with_single_config_in_metadata, DownloadConfig()) assert isinstance(module_builder_configs[0].data_files, DataFilesDict) assert len(module_builder_configs[0].data_files) == 1 # one train split assert len(module_builder_configs[0].data_files["train"]) == 2 # two files @@ -553,6 +550,10 @@ def test_LocalDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(sel assert module_builder_config_v2.name == "v2" assert isinstance(module_builder_config_v1, ImageFolderConfig) assert isinstance(module_builder_config_v2, ImageFolderConfig) + assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict) + assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict) + module_builder_config_v1._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig()) + module_builder_config_v2._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig()) assert isinstance(module_builder_config_v1.data_files, DataFilesDict) assert isinstance(module_builder_config_v2.data_files, DataFilesDict) assert sorted(module_builder_config_v1.data_files) == ["train"] @@ -580,13 +581,10 @@ def test_PackagedDatasetModuleFactory_with_data_dir(self): factory = PackagedDatasetModuleFactory("json", data_dir=self._data_dir, download_config=self.download_config) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None - assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) > 0 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) > 0 - ) - assert Path(module_factory_result.builder_kwargs["data_files"]["train"][0]).parent.samefile(self._data_dir) - assert Path(module_factory_result.builder_kwargs["data_files"]["test"][0]).parent.samefile(self._data_dir) + data_files = module_factory_result.builder_kwargs.get("data_files") + assert data_files is not None and len(data_files["train"]) > 0 and len(data_files["test"]) > 0 + assert Path(data_files["train"][0]).parent.samefile(self._data_dir) + assert Path(data_files["test"][0]).parent.samefile(self._data_dir) def test_PackagedDatasetModuleFactory_with_data_dir_and_metadata(self): factory = PackagedDatasetModuleFactory( @@ -594,25 +592,12 @@ def test_PackagedDatasetModuleFactory_with_data_dir_and_metadata(self): ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None - assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) > 0 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) > 0 - ) - assert Path(module_factory_result.builder_kwargs["data_files"]["train"][0]).parent.samefile( - self._data_dir_with_metadata - ) - assert Path(module_factory_result.builder_kwargs["data_files"]["test"][0]).parent.samefile( - self._data_dir_with_metadata - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] - ) - assert any( 
- Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["test"] - ) + data_files = module_factory_result.builder_kwargs.get("data_files") + assert data_files is not None and len(data_files["train"]) > 0 and len(data_files["test"]) > 0 + assert Path(data_files["train"][0]).parent.samefile(self._data_dir_with_metadata) + assert Path(data_files["test"][0]).parent.samefile(self._data_dir_with_metadata) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in data_files["train"]) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in data_files["test"]) @pytest.mark.integration def test_HubDatasetModuleFactoryWithoutScript(self): @@ -631,16 +616,16 @@ def test_HubDatasetModuleFactoryWithoutScript_with_data_dir(self): ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + builder_config = module_factory_result.builder_configs_parameters.builder_configs[0] assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT) assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) == 1 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) == 1 + builder_config.data_files is not None + and len(builder_config.data_files["train"]) == 1 + and len(builder_config.data_files["test"]) == 1 ) assert all( data_dir in Path(data_file).parts - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] - + module_factory_result.builder_kwargs["data_files"]["test"] + for data_file in builder_config.data_files["train"] + builder_config.data_files["test"] ) @pytest.mark.integration @@ -650,36 +635,29 @@ def test_HubDatasetModuleFactoryWithoutScript_with_metadata(self): ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + builder_config = module_factory_result.builder_configs_parameters.builder_configs[0] assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT) assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and len(module_factory_result.builder_kwargs["data_files"]["train"]) > 0 - and len(module_factory_result.builder_kwargs["data_files"]["test"]) > 0 - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["test"] + builder_config.data_files is not None + and len(builder_config.data_files["train"]) > 0 + and len(builder_config.data_files["test"]) > 0 ) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"]) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["test"]) factory = HubDatasetModuleFactoryWithoutScript( SAMPLE_DATASET_IDENTIFIER5, download_config=self.download_config ) module_factory_result = factory.get_module() assert importlib.import_module(module_factory_result.module_path) is not None + builder_config = module_factory_result.builder_configs_parameters.builder_configs[0] assert module_factory_result.builder_kwargs["base_path"].startswith(config.HF_ENDPOINT) assert ( - module_factory_result.builder_kwargs["data_files"] is not None - and 
len(module_factory_result.builder_kwargs["data_files"]) == 1 - and len(module_factory_result.builder_kwargs["data_files"]["train"]) > 0 - ) - assert any( - Path(data_file).name == "metadata.jsonl" - for data_file in module_factory_result.builder_kwargs["data_files"]["train"] + builder_config.data_files is not None + and len(builder_config.data_files) == 1 + and len(builder_config.data_files["train"]) > 0 ) + assert any(Path(data_file).name == "metadata.jsonl" for data_file in builder_config.data_files["train"]) @pytest.mark.integration def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadata(self): @@ -704,6 +682,10 @@ def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadat assert isinstance(module_builder_configs[0], AudioFolderConfig) assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None + assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) + module_builder_configs[0]._resolve_data_files( + module_factory_result.builder_kwargs["base_path"], DownloadConfig() + ) assert isinstance(module_builder_configs[0].data_files, DataFilesDict) assert sorted(module_builder_configs[0].data_files) == ["test", "train"] assert len(module_builder_configs[0].data_files["train"]) == 3 @@ -741,6 +723,14 @@ def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self) assert module_builder_config_v2.name == "v2" assert isinstance(module_builder_config_v1, AudioFolderConfig) assert isinstance(module_builder_config_v2, AudioFolderConfig) + assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict) + assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict) + module_builder_config_v1._resolve_data_files( + module_factory_result.builder_kwargs["base_path"], DownloadConfig() + ) + module_builder_config_v2._resolve_data_files( + module_factory_result.builder_kwargs["base_path"], DownloadConfig() + ) assert isinstance(module_builder_config_v1.data_files, DataFilesDict) assert isinstance(module_builder_config_v2.data_files, DataFilesDict) assert sorted(module_builder_config_v1.data_files) == ["test", "train"] @@ -780,6 +770,9 @@ def test_HubDatasetModuleFactoryWithParquetExport(self): assert module_factory_result.module_path == "datasets.packaged_modules.parquet.parquet" assert module_factory_result.builder_configs_parameters.builder_configs assert isinstance(module_factory_result.builder_configs_parameters.builder_configs[0], ParquetConfig) + module_factory_result.builder_configs_parameters.builder_configs[0]._resolve_data_files( + base_path="", download_config=self.download_config + ) assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == { "train": [ "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/train/0000.parquet" @@ -805,7 +798,20 @@ def test_HubDatasetModuleFactoryWithParquetExport_errors_on_wrong_sha(self): with self.assertRaises(_datasets_server.DatasetsServerError): factory.get_module() + @pytest.mark.integration def test_CachedDatasetModuleFactory(self): + name = SAMPLE_DATASET_IDENTIFIER2 + load_dataset_builder(name, cache_dir=self.cache_dir).download_and_prepare() + for offline_mode in OfflineSimulationMode: + with offline(offline_mode): + factory = CachedDatasetModuleFactory( + name, + cache_dir=self.cache_dir, + ) + module_factory_result = factory.get_module() + assert importlib.import_module(module_factory_result.module_path) is 
not None + + def test_CachedDatasetModuleFactory_with_script(self): path = os.path.join(self._dataset_loading_script_dir, f"{DATASET_LOADING_SCRIPT_NAME}.py") factory = LocalDatasetModuleFactoryWithScript( path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path @@ -866,12 +872,14 @@ def inject_fixtures(self, caplog): def setUp(self): self.hf_modules_cache = tempfile.mkdtemp() + self.cache_dir = tempfile.mkdtemp() self.dynamic_modules_path = datasets.load.init_dynamic_modules( name="test_datasets_modules2", hf_modules_cache=self.hf_modules_cache ) def tearDown(self): shutil.rmtree(self.hf_modules_cache) + shutil.rmtree(self.cache_dir) def _dummy_module_dir(self, modules_dir, dummy_module_name, dummy_code): assert dummy_module_name.startswith("__") @@ -913,7 +921,20 @@ def test_dataset_module_factory(self): "__missing_dummy_module_name__", dynamic_modules_path=self.dynamic_modules_path ) + @pytest.mark.integration def test_offline_dataset_module_factory(self): + repo_id = SAMPLE_DATASET_IDENTIFIER2 + builder = load_dataset_builder(repo_id, cache_dir=self.cache_dir) + builder.download_and_prepare() + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): + self._caplog.clear() + # allow provide the repo id without an explicit path to remote or local actual file + dataset_module = datasets.load.dataset_module_factory(repo_id, cache_dir=self.cache_dir) + self.assertEqual(dataset_module.module_path, "datasets.packaged_modules.cache.cache") + self.assertIn("Using the latest cached version of the dataset", self._caplog.text) + + def test_offline_dataset_module_factory_with_script(self): with tempfile.TemporaryDirectory() as tmp_dir: dummy_code = "MY_DUMMY_VARIABLE = 'hello there'" module_dir = self._dummy_module_dir(tmp_dir, "__dummy_module_name2__", dummy_code) @@ -986,9 +1007,8 @@ def test_load_dataset_builder_with_metadata(): assert builder.config.name == "default" assert builder.config.data_files is not None assert builder.config.drop_metadata is None - builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4, "non-existing-config") - assert isinstance(builder, ImageFolder) - assert builder.config.name == "non-existing-config" + with pytest.raises(ValueError): + builder = datasets.load_dataset_builder(SAMPLE_DATASET_IDENTIFIER4, "non-existing-config") @pytest.mark.integration @@ -1138,13 +1158,21 @@ def test_load_dataset_builder_fail(): @pytest.mark.parametrize("keep_in_memory", [False, True]) -def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory, caplog): +def test_load_dataset_local_script(dataset_loading_script_dir, data_dir, keep_in_memory, caplog): with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir, keep_in_memory=keep_in_memory) assert isinstance(dataset, DatasetDict) assert all(isinstance(d, Dataset) for d in dataset.values()) assert len(dataset) == 2 assert isinstance(next(iter(dataset["train"])), dict) + + +def test_load_dataset_cached_local_script(dataset_loading_script_dir, data_dir, caplog): + dataset = load_dataset(dataset_loading_script_dir, data_dir=data_dir) + assert isinstance(dataset, DatasetDict) + assert all(isinstance(d, Dataset) for d in dataset.values()) + assert len(dataset) == 2 + assert isinstance(next(iter(dataset["train"])), dict) for offline_simulation_mode in list(OfflineSimulationMode): with offline(offline_simulation_mode): 
caplog.clear() @@ -1152,6 +1180,28 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory dataset = datasets.load_dataset(DATASET_LOADING_SCRIPT_NAME, data_dir=data_dir) assert len(dataset) == 2 assert "Using the latest cached version of the module" in caplog.text + assert isinstance(next(iter(dataset["train"])), dict) + with pytest.raises(DatasetNotFoundError) as exc_info: + datasets.load_dataset(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST) + assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize("stream_from_cache, ", [False, True]) +def test_load_dataset_cached_from_hub(stream_from_cache, caplog): + dataset = load_dataset(SAMPLE_DATASET_IDENTIFIER3) + assert isinstance(dataset, DatasetDict) + assert all(isinstance(d, Dataset) for d in dataset.values()) + assert len(dataset) == 2 + assert isinstance(next(iter(dataset["train"])), dict) + for offline_simulation_mode in list(OfflineSimulationMode): + with offline(offline_simulation_mode): + caplog.clear() + # Load dataset from cache + dataset = datasets.load_dataset(SAMPLE_DATASET_IDENTIFIER3, streaming=stream_from_cache) + assert len(dataset) == 2 + assert "Using the latest cached version of the dataset" in caplog.text + assert isinstance(next(iter(dataset["train"])), dict) with pytest.raises(DatasetNotFoundError) as exc_info: datasets.load_dataset(SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST) assert f"Dataset '{SAMPLE_DATASET_NAME_THAT_DOESNT_EXIST}' doesn't exist on the Hub" in str(exc_info.value) @@ -1453,6 +1503,18 @@ def test_load_dataset_then_move_then_reload(dataset_loading_script_dir, data_dir assert dataset._fingerprint != fingerprint1 +def test_load_dataset_builder_then_edit_then_load_again(tmp_path: Path): + dataset_dir = tmp_path / "test_load_dataset_then_edit_then_load_again" + dataset_dir.mkdir() + with open(dataset_dir / "train.txt", "w") as f: + f.write("Hello there") + dataset_builder = load_dataset_builder(str(dataset_dir)) + with open(dataset_dir / "train.txt", "w") as f: + f.write("General Kenobi !") + edited_dataset_builder = load_dataset_builder(str(dataset_dir)) + assert dataset_builder.cache_dir != edited_dataset_builder.cache_dir + + def test_load_dataset_readonly(dataset_loading_script_dir, dataset_loading_script_dir_readonly, data_dir, tmp_path): cache_dir1 = tmp_path / "cache1" cache_dir2 = tmp_path / "cache2"