From 51002cb0325772adaf46d6f3ce01d41c01b51079 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 29 Nov 2023 14:16:57 +0100 Subject: [PATCH 01/14] lazy data files resolution --- src/datasets/builder.py | 19 ++++++-- src/datasets/data_files.py | 98 ++++++++++++++++++++++++++++++++++++++ src/datasets/load.py | 23 ++++----- 3 files changed, 122 insertions(+), 18 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index a078cb4c2c8..e8de242829f 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -46,7 +46,7 @@ ReadInstruction, ) from .arrow_writer import ArrowWriter, BeamWriter, ParquetWriter, SchemaInferenceError -from .data_files import DataFilesDict, sanitize_patterns +from .data_files import DataFilesDict, DataFilesPatternsDict, sanitize_patterns from .dataset_dict import DatasetDict, IterableDatasetDict from .download.download_config import DownloadConfig from .download.download_manager import DownloadManager, DownloadMode @@ -128,7 +128,7 @@ class BuilderConfig: name: str = "default" version: Optional[Union[utils.Version, str]] = utils.Version("0.0.0") data_dir: Optional[str] = None - data_files: Optional[DataFilesDict] = None + data_files: Optional[Union[DataFilesDict, DataFilesPatternsDict]] = None description: Optional[str] = None def __post_init__(self): @@ -139,7 +139,7 @@ def __post_init__(self): f"Bad characters from black list '{INVALID_WINDOWS_CHARACTERS_IN_PATH}' found in '{self.name}'. " f"They could create issues when creating a directory for this config on Windows filesystem." ) - if self.data_files is not None and not isinstance(self.data_files, DataFilesDict): + if self.data_files is not None and not isinstance(self.data_files, (DataFilesDict, DataFilesPatternsDict)): raise ValueError(f"Expected a DataFilesDict in data_files but got {self.data_files}") def __eq__(self, o): @@ -213,6 +213,11 @@ def create_config_id( else: return self.name + def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None: + if self.data_files is not None and isinstance(self.data_files, DataFilesPatternsDict): + base_path = base_path + "/" + self.data_dir if self.data_dir else base_path + self.data_files = self.data_files.resolve(base_path, download_config) + class DatasetBuilder: """Abstract base class for all datasets. 
@@ -555,7 +560,7 @@ def _create_builder_config(
 
         # otherwise use the config_kwargs to overwrite the attributes
         else:
-            builder_config = copy.deepcopy(builder_config)
+            builder_config = copy.deepcopy(builder_config) if config_kwargs else builder_config
             for key, value in config_kwargs.items():
                 if value is not None:
                     if not hasattr(builder_config, key):
@@ -565,6 +570,12 @@ def _create_builder_config(
         if not builder_config.name:
             raise ValueError(f"BuilderConfig must have a name, got {builder_config.name}")
 
+        # resolve data files if needed
+        builder_config._resolve_data_files(
+            base_path=self.base_path,
+            download_config=DownloadConfig(token=self.token, storage_options=self.storage_options),
+        )
+
         # compute the config id that is going to be used for caching
         config_id = builder_config.create_config_id(
             config_kwargs,
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index f318e893738..02aa594ce42 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -340,6 +340,15 @@ def resolve_pattern(
         base_path = os.path.splitdrive(pattern)[0] + os.sep
     else:
         base_path = ""
+    if pattern.startswith(config.HF_ENDPOINT) and "/resolve/" in pattern:
+        pattern = "hf://" + pattern[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
+        pattern = (
+            pattern.split("@", 1)[0]
+            + "@"
+            + pattern.split("@", 1)[1].split("/", 1)[0].replace("%2F", "/")
+            + "/"
+            + pattern.split("@", 1)[1].split("/", 1)[1]
+        )
     pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config)
     fs, _, _ = get_fs_token_paths(pattern, storage_options=storage_options)
     fs_base_path = base_path.split("::")[0].split("://")[-1] or fs.root_marker
@@ -698,3 +707,92 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
         for key, data_files_list in self.items():
             out[key] = data_files_list.filter_extensions(extensions)
         return out
+
+
+class DataFilesPatternsList(List[str]):
+    """
+    List of data files patterns (absolute local paths or URLs).
+    """
+
+    def __init__(
+        self,
+        patterns: List[str],
+        allowed_extensions: List[Optional[List[str]]],
+    ):
+        super().__init__(patterns)
+        self.allowed_extensions = allowed_extensions
+
+    def __add__(self, other):
+        return DataFilesPatternsList([*self, *other], self.allowed_extensions + other.allowed_extensions)
+
+    @classmethod
+    def from_patterns(
+        cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
+    ) -> "DataFilesPatternsList":
+        return cls(patterns, [allowed_extensions] * len(patterns))
+
+    def resolve(
+        self,
+        base_path: str,
+        download_config: Optional[DownloadConfig] = None,
+    ) -> "DataFilesList":
+        base_path = base_path if base_path is not None else Path().resolve().as_posix()
+        data_files = []
+        for pattern, allowed_extensions in zip(self, self.allowed_extensions):
+            try:
+                data_files.extend(
+                    resolve_pattern(
+                        pattern,
+                        base_path=base_path,
+                        allowed_extensions=allowed_extensions,
+                        download_config=download_config,
+                    )
+                )
+            except FileNotFoundError:
+                if not has_magic(pattern):
+                    raise
+        origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
+        return DataFilesList(data_files, origin_metadata)
+
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
+        return DataFilesPatternsList(
+            self, [(allowed_extensions or []) + extensions for allowed_extensions in self.allowed_extensions]
+        )
+
+
+class DataFilesPatternsDict(Dict[str, DataFilesPatternsList]):
+    """
+    Dict of split_name -> list of data files patterns (absolute local paths or URLs).
+ """ + + @classmethod + def from_patterns( + cls, patterns: Dict[str, List[str]], allowed_extensions: Optional[List[str]] = None + ) -> "DataFilesPatternsDict": + out = cls() + for key, patterns_for_key in patterns.items(): + out[key] = ( + DataFilesPatternsList.from_patterns( + patterns_for_key, + allowed_extensions=allowed_extensions, + ) + if not isinstance(patterns_for_key, DataFilesPatternsList) + else patterns_for_key + ) + return out + + def resolve( + self, + base_path: str, + download_config: Optional[DownloadConfig] = None, + ) -> "DataFilesDict": + out = DataFilesDict() + for key, data_files_patterns_list in self.items(): + out[key] = data_files_patterns_list.resolve(base_path, download_config) + return out + + def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsDict": + out = type(self)() + for key, data_files_patterns_list in self.items(): + out[key] = data_files_patterns_list.filter_extensions(extensions) + return out diff --git a/src/datasets/load.py b/src/datasets/load.py index 6893e94242a..f0f3572cf5b 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -40,6 +40,8 @@ DEFAULT_PATTERNS_ALL, DataFilesDict, DataFilesList, + DataFilesPatternsDict, + DataFilesPatternsList, EmptyDatasetError, get_data_patterns, get_metadata_patterns, @@ -606,7 +608,6 @@ def create_builder_configs_from_metadata_configs( supports_metadata: bool, base_path: Optional[str] = None, default_builder_kwargs: Dict[str, Any] = None, - download_config: Optional[DownloadConfig] = None, ) -> Tuple[List[BuilderConfig], str]: builder_cls = import_main_class(module_path) builder_config_cls = builder_cls.BUILDER_CONFIG_CLASS @@ -624,11 +625,9 @@ def create_builder_configs_from_metadata_configs( if config_data_files is not None else get_data_patterns(config_base_path) ) - config_data_files_dict = DataFilesDict.from_patterns( + config_data_files_dict = DataFilesPatternsDict.from_patterns( config_patterns, - base_path=config_base_path, allowed_extensions=ALL_ALLOWED_EXTENSIONS, - download_config=download_config, ) except EmptyDatasetError as e: raise EmptyDatasetError( @@ -641,16 +640,13 @@ def create_builder_configs_from_metadata_configs( except FileNotFoundError: config_metadata_patterns = None if config_metadata_patterns is not None: - config_metadata_data_files_list = DataFilesList.from_patterns( - config_metadata_patterns, base_path=base_path + config_metadata_data_files_list = DataFilesPatternsList.from_patterns(config_metadata_patterns) + config_data_files_dict = DataFilesPatternsDict( + { + split: data_files_list + config_metadata_data_files_list + for split, data_files_list in config_data_files_dict.items() + } ) - if config_metadata_data_files_list: - config_data_files_dict = DataFilesDict( - { - split: data_files_list + config_metadata_data_files_list - for split, data_files_list in config_data_files_dict.items() - } - ) ignored_params = [ param for param in config_params if not hasattr(builder_config_cls, param) and param != "default" ] @@ -1255,7 +1251,6 @@ def get_module(self) -> DatasetModule: base_path=base_path, supports_metadata=supports_metadata, default_builder_kwargs=default_builder_kwargs, - download_config=self.download_config, ) else: builder_configs, default_config_name = None, None From 5a5bb38bcc71ea21f2d7304aab374fdb81ded463 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 29 Nov 2023 16:22:42 +0100 Subject: [PATCH 02/14] fix tests --- src/datasets/builder.py | 4 ++-- src/datasets/data_files.py | 2 +- src/datasets/load.py | 2 +- tests/test_load.py | 
14 +++++++++++++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index e8de242829f..e24baf611a3 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -51,7 +51,7 @@ from .download.download_config import DownloadConfig from .download.download_manager import DownloadManager, DownloadMode from .download.mock_download_manager import MockDownloadManager -from .download.streaming_download_manager import StreamingDownloadManager, xopen +from .download.streaming_download_manager import StreamingDownloadManager, xjoin, xopen from .features import Features from .filesystems import ( is_remote_filesystem, @@ -215,7 +215,7 @@ def create_config_id( def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None: if self.data_files is not None and isinstance(self.data_files, DataFilesPatternsDict): - base_path = base_path + "/" + self.data_dir if self.data_dir else base_path + base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path self.data_files = self.data_files.resolve(base_path, download_config) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 02aa594ce42..b7143a2ade5 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -358,7 +358,7 @@ def resolve_pattern( protocol_prefix = protocol + "://" if protocol != "file" else "" matched_paths = [ filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath - for filepath, info in fs.glob(pattern, detail=True).items() + for filepath, info in fs.glob(pattern, detail=True, expand_info=False).items() if info["type"] == "file" and (xbasename(filepath) not in files_to_ignore) and not _is_inside_unrequested_special_dir( diff --git a/src/datasets/load.py b/src/datasets/load.py index f0f3572cf5b..2e307d113c2 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -618,7 +618,7 @@ def create_builder_configs_from_metadata_configs( for config_name, config_params in metadata_configs.items(): config_data_files = config_params.get("data_files") config_data_dir = config_params.get("data_dir") - config_base_path = base_path + "/" + config_data_dir if config_data_dir else base_path + config_base_path = xjoin(base_path, config_data_dir) if config_data_dir else base_path try: config_patterns = ( sanitize_patterns(config_data_files) diff --git a/tests/test_load.py b/tests/test_load.py index 48ed31e0476..11fd99b09b1 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -21,7 +21,7 @@ from datasets.arrow_writer import ArrowWriter from datasets.builder import DatasetBuilder from datasets.config import METADATA_CONFIGS_FIELD -from datasets.data_files import DataFilesDict +from datasets.data_files import DataFilesDict, DataFilesPatternsDict from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_config import DownloadConfig from datasets.exceptions import DatasetNotFoundError @@ -515,6 +515,8 @@ def test_LocalDatasetModuleFactoryWithoutScript_with_single_config_in_metadata(s assert isinstance(module_builder_configs[0], ImageFolderConfig) assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None + assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) + module_builder_configs[0]._resolve_data_files(self._data_dir_with_single_config_in_metadata, DownloadConfig()) assert isinstance(module_builder_configs[0].data_files, DataFilesDict) assert 
len(module_builder_configs[0].data_files) == 1 # one train split assert len(module_builder_configs[0].data_files["train"]) == 2 # two files @@ -550,6 +552,10 @@ def test_LocalDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(sel assert module_builder_config_v2.name == "v2" assert isinstance(module_builder_config_v1, ImageFolderConfig) assert isinstance(module_builder_config_v2, ImageFolderConfig) + assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict) + assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict) + module_builder_config_v1._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig()) + module_builder_config_v2._resolve_data_files(self._data_dir_with_two_config_in_metadata, DownloadConfig()) assert isinstance(module_builder_config_v1.data_files, DataFilesDict) assert isinstance(module_builder_config_v2.data_files, DataFilesDict) assert sorted(module_builder_config_v1.data_files) == ["train"] @@ -701,6 +707,8 @@ def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadat assert isinstance(module_builder_configs[0], AudioFolderConfig) assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None + assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) + module_builder_configs[0]._resolve_data_files() assert isinstance(module_builder_configs[0].data_files, DataFilesDict) assert sorted(module_builder_configs[0].data_files) == ["test", "train"] assert len(module_builder_configs[0].data_files["train"]) == 3 @@ -738,6 +746,10 @@ def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self) assert module_builder_config_v2.name == "v2" assert isinstance(module_builder_config_v1, AudioFolderConfig) assert isinstance(module_builder_config_v2, AudioFolderConfig) + assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict) + assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict) + module_builder_config_v1._resolve_data_files() + module_builder_config_v2._resolve_data_files() assert isinstance(module_builder_config_v1.data_files, DataFilesDict) assert isinstance(module_builder_config_v2.data_files, DataFilesDict) assert sorted(module_builder_config_v1.data_files) == ["test", "train"] From 214a3e6dcb66e9c1a8ff586553e8eee0f1c70710 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 29 Nov 2023 16:24:51 +0100 Subject: [PATCH 03/14] minor --- src/datasets/load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 2e307d113c2..ed7867f93ed 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1256,7 +1256,7 @@ def get_module(self) -> DatasetModule: builder_configs, default_config_name = None, None builder_kwargs = { "hash": hash, - "base_path": hf_hub_url(self.name, "", revision=self.revision), + "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), "repo_id": self.name, "dataset_name": camelcase_to_snakecase(Path(self.name).name), } @@ -1424,7 +1424,7 @@ def get_module(self) -> DatasetModule: importlib.invalidate_caches() builder_kwargs = { "hash": hash, - "base_path": hf_hub_url(self.name, "", revision=self.revision), + "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), "repo_id": self.name, } return DatasetModule(module_path, hash, builder_kwargs) From b7a9674e17156ff10124632ba705125288de7442 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest 
Date: Wed, 29 Nov 2023 16:29:39 +0100 Subject: [PATCH 04/14] don't use expand_info=False yet --- src/datasets/data_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index b7143a2ade5..02aa594ce42 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -358,7 +358,7 @@ def resolve_pattern( protocol_prefix = protocol + "://" if protocol != "file" else "" matched_paths = [ filepath if filepath.startswith(protocol_prefix) else protocol_prefix + filepath - for filepath, info in fs.glob(pattern, detail=True, expand_info=False).items() + for filepath, info in fs.glob(pattern, detail=True).items() if info["type"] == "file" and (xbasename(filepath) not in files_to_ignore) and not _is_inside_unrequested_special_dir( From 32e0960ea165a9481b1ff6eed31771475120cb38 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 29 Nov 2023 16:48:38 +0100 Subject: [PATCH 05/14] fix --- tests/test_load.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_load.py b/tests/test_load.py index 11fd99b09b1..c165762b261 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -708,7 +708,7 @@ def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadat assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) - module_builder_configs[0]._resolve_data_files() + module_builder_configs[0]._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig()) assert isinstance(module_builder_configs[0].data_files, DataFilesDict) assert sorted(module_builder_configs[0].data_files) == ["test", "train"] assert len(module_builder_configs[0].data_files["train"]) == 3 @@ -748,8 +748,8 @@ def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self) assert isinstance(module_builder_config_v2, AudioFolderConfig) assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict) assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict) - module_builder_config_v1._resolve_data_files() - module_builder_config_v2._resolve_data_files() + module_builder_config_v1._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig()) + module_builder_config_v2._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig()) assert isinstance(module_builder_config_v1.data_files, DataFilesDict) assert isinstance(module_builder_config_v2.data_files, DataFilesDict) assert sorted(module_builder_config_v1.data_files) == ["test", "train"] From 68099ca55294bfc12a34781835dd73c533a764bd Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 29 Nov 2023 18:03:44 +0100 Subject: [PATCH 06/14] style --- tests/test_load.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_load.py b/tests/test_load.py index c165762b261..1c33998aeac 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -708,7 +708,9 @@ def test_HubDatasetModuleFactoryWithoutScript_with_one_default_config_in_metadat assert module_builder_configs[0].name == "custom" assert module_builder_configs[0].data_files is not None assert isinstance(module_builder_configs[0].data_files, DataFilesPatternsDict) - module_builder_configs[0]._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig()) + 
module_builder_configs[0]._resolve_data_files(
+            module_factory_result.builder_kwargs["base_path"], DownloadConfig()
+        )
         assert isinstance(module_builder_configs[0].data_files, DataFilesDict)
         assert sorted(module_builder_configs[0].data_files) == ["test", "train"]
         assert len(module_builder_configs[0].data_files["train"]) == 3
@@ -748,8 +750,12 @@ def test_HubDatasetModuleFactoryWithoutScript_with_two_configs_in_metadata(self)
         assert isinstance(module_builder_config_v2, AudioFolderConfig)
         assert isinstance(module_builder_config_v1.data_files, DataFilesPatternsDict)
         assert isinstance(module_builder_config_v2.data_files, DataFilesPatternsDict)
-        module_builder_config_v1._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig())
-        module_builder_config_v2._resolve_data_files(module_factory_result.builder_kwargs["base_path"], DownloadConfig())
+        module_builder_config_v1._resolve_data_files(
+            module_factory_result.builder_kwargs["base_path"], DownloadConfig()
+        )
+        module_builder_config_v2._resolve_data_files(
+            module_factory_result.builder_kwargs["base_path"], DownloadConfig()
+        )
         assert isinstance(module_builder_config_v1.data_files, DataFilesDict)
         assert isinstance(module_builder_config_v2.data_files, DataFilesDict)
         assert sorted(module_builder_config_v1.data_files) == ["test", "train"]

From b3fc42882a2d84d7482c27063f1e19539e99b9d3 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Thu, 30 Nov 2023 19:31:18 +0100
Subject: [PATCH 07/14] tests

---
 src/datasets/builder.py    |  2 +-
 src/datasets/data_files.py |  2 ++
 tests/test_data_files.py   | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index e24baf611a3..15f4a0b7d63 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -214,7 +214,7 @@ def create_config_id(
             return self.name
 
     def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) -> None:
-        if self.data_files is not None and isinstance(self.data_files, DataFilesPatternsDict):
+        if isinstance(self.data_files, DataFilesPatternsDict):
             base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path
             self.data_files = self.data_files.resolve(base_path, download_config)
 
diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py
index 02aa594ce42..622890bcccf 100644
--- a/src/datasets/data_files.py
+++ b/src/datasets/data_files.py
@@ -712,6 +712,8 @@ def filter_extensions(self, extensions: List[str]) -> "DataFilesDict":
 class DataFilesPatternsList(List[str]):
     """
     List of data files patterns (absolute local paths or URLs).
+    For each pattern there should also be a list of allowed extensions
+    to keep, or None to keep all the files for the pattern.
""" def __init__( diff --git a/tests/test_data_files.py b/tests/test_data_files.py index 01b5e4dd15e..468f5909f2f 100644 --- a/tests/test_data_files.py +++ b/tests/test_data_files.py @@ -12,6 +12,8 @@ from datasets.data_files import ( DataFilesDict, DataFilesList, + DataFilesPatternsDict, + DataFilesPatternsList, _get_data_files_patterns, _get_metadata_files_patterns, _is_inside_unrequested_special_dir, @@ -487,6 +489,37 @@ def test_DataFilesDict_from_patterns_locally_or_remote_hashing(text_file): assert Hasher.hash(data_files1) != Hasher.hash(data_files2) +def test_DataFilesPatternsList(text_file): + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [str(text_file)] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".txt"]]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [str(text_file)] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([str(text_file).replace(".txt", ".tx*")], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path="") + assert data_files == [str(text_file)] + assert isinstance(data_files, DataFilesList) + data_files_patterns = DataFilesPatternsList([Path(text_file).name], allowed_extensions=[None]) + data_files = data_files_patterns.resolve(base_path=str(Path(text_file).parent)) + assert data_files == [str(text_file)] + data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".zip"]]) + with pytest.raises(FileNotFoundError): + data_files_patterns.resolve(base_path="") + + +def test_DataFilesPatternsDict(text_file): + data_files_patterns_dict = DataFilesPatternsDict( + {"train": DataFilesPatternsList([str(text_file)], allowed_extensions=[None])} + ) + data_files_dict = data_files_patterns_dict.resolve(base_path="") + assert data_files_dict == {"train": [str(text_file)]} + assert isinstance(data_files_dict, DataFilesDict) + assert isinstance(data_files_dict["train"], DataFilesList) + + def mock_fs(file_paths: List[str]): """ Set up a mock filesystem for fsspec containing the provided files From 796a47e388a5c5711a95bd649648608c18219ac5 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Thu, 30 Nov 2023 19:47:58 +0100 Subject: [PATCH 08/14] fix win test --- tests/test_data_files.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_data_files.py b/tests/test_data_files.py index 468f5909f2f..9ac8945c931 100644 --- a/tests/test_data_files.py +++ b/tests/test_data_files.py @@ -492,19 +492,19 @@ def test_DataFilesDict_from_patterns_locally_or_remote_hashing(text_file): def test_DataFilesPatternsList(text_file): data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[None]) data_files = data_files_patterns.resolve(base_path="") - assert data_files == [str(text_file)] + assert data_files == [text_file.as_posix()] assert isinstance(data_files, DataFilesList) data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".txt"]]) data_files = data_files_patterns.resolve(base_path="") - assert data_files == [str(text_file)] + assert data_files == [text_file.as_posix()] assert isinstance(data_files, DataFilesList) data_files_patterns = DataFilesPatternsList([str(text_file).replace(".txt", ".tx*")], allowed_extensions=[None]) data_files = data_files_patterns.resolve(base_path="") 
- assert data_files == [str(text_file)] + assert data_files == [text_file.as_posix()] assert isinstance(data_files, DataFilesList) data_files_patterns = DataFilesPatternsList([Path(text_file).name], allowed_extensions=[None]) data_files = data_files_patterns.resolve(base_path=str(Path(text_file).parent)) - assert data_files == [str(text_file)] + assert data_files == [text_file.as_posix()] data_files_patterns = DataFilesPatternsList([str(text_file)], allowed_extensions=[[".zip"]]) with pytest.raises(FileNotFoundError): data_files_patterns.resolve(base_path="") @@ -515,7 +515,7 @@ def test_DataFilesPatternsDict(text_file): {"train": DataFilesPatternsList([str(text_file)], allowed_extensions=[None])} ) data_files_dict = data_files_patterns_dict.resolve(base_path="") - assert data_files_dict == {"train": [str(text_file)]} + assert data_files_dict == {"train": [text_file.as_posix()]} assert isinstance(data_files_dict, DataFilesDict) assert isinstance(data_files_dict["train"], DataFilesList) From a924e94d4fc5071afdec324f6dd9b81f707d262e Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Fri, 1 Dec 2023 18:39:59 +0100 Subject: [PATCH 09/14] fix tests --- src/datasets/load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 08a04a24661..06e4e296d56 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1256,7 +1256,7 @@ def get_module(self) -> DatasetModule: builder_configs, default_config_name = None, None builder_kwargs = { "hash": hash, - "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), + "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"), "repo_id": self.name, "dataset_name": camelcase_to_snakecase(Path(self.name).name), } @@ -1269,7 +1269,7 @@ def get_module(self) -> DatasetModule: try: # this file is deprecated and was created automatically in old versions of push_to_hub dataset_infos_path = cached_path( - hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision), + hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=revision), download_config=download_config, ) with open(dataset_infos_path, encoding="utf-8") as f: From c9ecfcaa3bf6ace3de327b983503660e5569e2fd Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 4 Dec 2023 12:05:14 +0100 Subject: [PATCH 10/14] fix tests again --- src/datasets/load.py | 1 - tests/test_load.py | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/load.py b/src/datasets/load.py index 87d63e24ce8..19d54834199 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1344,7 +1344,6 @@ def get_module(self) -> DatasetModule: module_path, metadata_configs, supports_metadata=False, - download_config=self.download_config, ) builder_kwargs = { "hash": hash, diff --git a/tests/test_load.py b/tests/test_load.py index 8f0c193a527..72f0b3e5ec5 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -798,6 +798,9 @@ def test_HubDatasetModuleFactoryWithParquetExport(self): assert module_factory_result.module_path == "datasets.packaged_modules.parquet.parquet" assert module_factory_result.builder_configs_parameters.builder_configs assert isinstance(module_factory_result.builder_configs_parameters.builder_configs[0], ParquetConfig) + module_factory_result.builder_configs_parameters.builder_configs[0]._resolve_data_files( + base_path="", download_config=self.download_config + ) assert module_factory_result.builder_configs_parameters.builder_configs[0].data_files == { 
"train": [ "hf://datasets/hf-internal-testing/dataset_with_script@da4ed81df5a1bcd916043c827b75994de8ef7eda/default/train/0000.parquet" From ddb488bd7888a788c625bb6bbe41fdb16bf67a17 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 4 Dec 2023 12:40:22 +0100 Subject: [PATCH 11/14] remove unused code --- src/datasets/data_files.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 622890bcccf..6d1a0e6aa95 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -340,15 +340,6 @@ def resolve_pattern( base_path = os.path.splitdrive(pattern)[0] + os.sep else: base_path = "" - if pattern.startswith(config.HF_ENDPOINT) and "/resolve/" in pattern: - pattern = "hf://" + pattern[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1) - pattern = ( - pattern.split("@", 1)[0] - + "@" - + pattern.split("@", 1)[1].split("/", 1)[0].replace("%2F", "/") - + "/" - + pattern.split("@", 1)[1].split("/", 1)[1] - ) pattern, storage_options = _prepare_path_and_storage_options(pattern, download_config=download_config) fs, _, _ = get_fs_token_paths(pattern, storage_options=storage_options) fs_base_path = base_path.split("::")[0].split("://")[-1] or fs.root_marker From 565c294fc12bc547730a023a610ed4f92313d8fb Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 11 Dec 2023 12:38:32 +0100 Subject: [PATCH 12/14] fix cache on config change --- src/datasets/builder.py | 14 +++++++++++++- src/datasets/load.py | 23 ----------------------- tests/test_builder.py | 27 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 24 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 15f4a0b7d63..369cd72d43b 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -218,6 +218,16 @@ def _resolve_data_files(self, base_path: str, download_config: DownloadConfig) - base_path = xjoin(base_path, self.data_dir) if self.data_dir else base_path self.data_files = self.data_files.resolve(base_path, download_config) + def _update_hash(self, hash: str) -> str: + """ + Used to update hash of packaged modules which is used for creating unique cache directories to reflect + different config parameters which are passed in metadata from readme. + """ + m = Hasher() + m.update(hash) + m.update(self) + return m.hexdigest() + class DatasetBuilder: """Abstract base class for all datasets. 
@@ -381,6 +391,8 @@ def __init__( custom_features=features, **config_kwargs, ) + if self.config.name in self.builder_configs: + self.hash = self.builder_configs[self.config.name]._update_hash(hash) # prepare info: DatasetInfo are a standardized dataclass across all datasets # Prefill datasetinfo @@ -601,7 +613,7 @@ def _create_builder_config( @classproperty @classmethod @memoize() - def builder_configs(cls): + def builder_configs(cls) -> Dict[str, BuilderConfig]: """Dictionary of pre-defined configurations for this builder class.""" configs = {config.name: config for config in cls.BUILDER_CONFIGS} if len(configs) != len(cls.BUILDER_CONFIGS): diff --git a/src/datasets/load.py b/src/datasets/load.py index 19d54834199..8270a4d23ff 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -53,7 +53,6 @@ from .download.streaming_download_manager import StreamingDownloadManager, xbasename, xglob, xjoin from .exceptions import DataFilesNotFoundError, DatasetNotFoundError from .features import Features -from .fingerprint import Hasher from .info import DatasetInfo, DatasetInfosDict from .iterable_dataset import IterableDataset from .metric import Metric @@ -591,21 +590,6 @@ def infer_module_for_data_files( return module_name, default_builder_kwargs -def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str: - """ - Used to update hash of packaged modules which is used for creating unique cache directories to reflect - different config parameters which are passed in metadata from readme. - """ - params_to_exclude = {"config_name", "version", "description"} - params_to_add_to_hash = { - param: value for param, value in sorted(config_parameters.items()) if param not in params_to_exclude - } - m = Hasher() - m.update(hash) - m.update(params_to_add_to_hash) - return m.hexdigest() - - def create_builder_configs_from_metadata_configs( module_path: str, metadata_configs: MetadataConfigs, @@ -2178,13 +2162,6 @@ def load_dataset_builder( dataset_name = builder_kwargs.pop("dataset_name", None) hash = builder_kwargs.pop("hash") info = dataset_module.dataset_infos.get(config_name) if dataset_module.dataset_infos else None - if ( - dataset_module.builder_configs_parameters.metadata_configs - and config_name in dataset_module.builder_configs_parameters.metadata_configs - ): - hash = update_hash_with_config_parameters( - hash, dataset_module.builder_configs_parameters.metadata_configs[config_name] - ) if path in _PACKAGED_DATASETS_MODULES and data_files is None: error_msg = f"Please specify the data files or data directory to load for the {path} dataset builder." 
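The two hunks above move per-config hash updating out of load.py and into BuilderConfig._update_hash, so the cache directory is derived from the packaged module hash plus the contents of the resolved config itself, rather than from a dict of selected README parameters. Below is a minimal sketch of that mechanism; ToyConfig is a hypothetical stand-in for BuilderConfig, not the real class, and "abc" is an illustrative module hash.

    # Minimal sketch of the cache-id mechanism, mirroring the _update_hash
    # added above; ToyConfig is a hypothetical stand-in for BuilderConfig.
    from dataclasses import dataclass
    from typing import Optional

    from datasets.fingerprint import Hasher


    @dataclass
    class ToyConfig:
        name: str = "default"
        data_dir: Optional[str] = None

        def _update_hash(self, hash: str) -> str:
            # Mix the module hash with the config contents so that two configs
            # with different parameters end up in different cache directories.
            m = Hasher()
            m.update(hash)
            m.update(self)
            return m.hexdigest()


    module_hash = "abc"  # illustrative hash of the packaged module code
    assert ToyConfig()._update_hash(module_hash) == ToyConfig()._update_hash(module_hash)
    assert ToyConfig()._update_hash(module_hash) != ToyConfig(data_dir="./data")._update_hash(module_hash)

PATCH 14 below refines _update_hash to hash the config field by field and skip values the hasher cannot handle.
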
diff --git a/tests/test_builder.py b/tests/test_builder.py index 54bae47ae08..feaa27ce8a6 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -23,6 +23,7 @@ from datasets.features import Features, Value from datasets.info import DatasetInfo, PostProcessedInfo from datasets.iterable_dataset import IterableDataset +from datasets.load import configure_builder_class from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo from datasets.streaming import xjoin from datasets.utils.file_utils import is_local_path @@ -830,6 +831,32 @@ def test_cache_dir_for_data_dir(self): other_builder = DummyBuilderWithManualDownload(cache_dir=tmp_dir, config_name="a", data_dir=tmp_dir) self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) + def test_cache_dir_for_configured_builder(self): + with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as data_dir: + builder_cls = configure_builder_class( + DummyBuilderWithManualDownload, + builder_configs=[BuilderConfig(data_dir=data_dir)], + default_config_name=None, + dataset_name="dummy", + ) + other_builder_cls = configure_builder_class( + DummyBuilderWithManualDownload, + builder_configs=[BuilderConfig(data_dir=data_dir)], + default_config_name=None, + dataset_name="dummy", + ) + builder = builder_cls(cache_dir=tmp_dir) + other_builder = other_builder_cls(cache_dir=tmp_dir) + self.assertEqual(builder.cache_dir, other_builder.cache_dir) + other_builder_cls = configure_builder_class( + DummyBuilderWithManualDownload, + builder_configs=[BuilderConfig(data_dir=tmp_dir)], + default_config_name=None, + dataset_name="dummy", + ) + other_builder = other_builder_cls(cache_dir=tmp_dir) + self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) + def test_arrow_based_download_and_prepare(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) From e01938d23b0ed7f0170d815dde4cf75eac34a784 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 11 Dec 2023 12:45:15 +0100 Subject: [PATCH 13/14] simpler --- src/datasets/builder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 369cd72d43b..118e589eac3 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -391,8 +391,7 @@ def __init__( custom_features=features, **config_kwargs, ) - if self.config.name in self.builder_configs: - self.hash = self.builder_configs[self.config.name]._update_hash(hash) + self.hash = self.config._update_hash(hash) # prepare info: DatasetInfo are a standardized dataclass across all datasets # Prefill datasetinfo From 09516db55ec5f6fd33f1eeaac973a02f3f2bf58d Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Mon, 11 Dec 2023 13:05:36 +0100 Subject: [PATCH 14/14] fix tests --- src/datasets/builder.py | 11 ++++++++--- tests/test_builder.py | 6 +++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 118e589eac3..9412dadfb2e 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -26,7 +26,7 @@ import time import urllib import warnings -from dataclasses import dataclass +from dataclasses import dataclass, fields from functools import partial from pathlib import Path from typing import Dict, Iterable, Mapping, Optional, Tuple, Union @@ -225,7 +225,11 @@ def _update_hash(self, hash: str) -> str: """ m = Hasher() m.update(hash) - m.update(self) + for field in sorted(field.name for field in fields(self)): + try: + m.update(getattr(self, field)) + except 
TypeError: + continue return m.hexdigest() @@ -391,7 +395,8 @@ def __init__( custom_features=features, **config_kwargs, ) - self.hash = self.config._update_hash(hash) + if self.hash is not None: + self.hash = self.config._update_hash(hash) # prepare info: DatasetInfo are a standardized dataclass across all datasets # Prefill datasetinfo diff --git a/tests/test_builder.py b/tests/test_builder.py index feaa27ce8a6..5e89235fc3a 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -845,8 +845,8 @@ def test_cache_dir_for_configured_builder(self): default_config_name=None, dataset_name="dummy", ) - builder = builder_cls(cache_dir=tmp_dir) - other_builder = other_builder_cls(cache_dir=tmp_dir) + builder = builder_cls(cache_dir=tmp_dir, hash="abc") + other_builder = other_builder_cls(cache_dir=tmp_dir, hash="abc") self.assertEqual(builder.cache_dir, other_builder.cache_dir) other_builder_cls = configure_builder_class( DummyBuilderWithManualDownload, @@ -854,7 +854,7 @@ def test_cache_dir_for_configured_builder(self): default_config_name=None, dataset_name="dummy", ) - other_builder = other_builder_cls(cache_dir=tmp_dir) + other_builder = other_builder_cls(cache_dir=tmp_dir, hash="abc") self.assertNotEqual(builder.cache_dir, other_builder.cache_dir)
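
Taken together, the series makes data files resolution lazy: builder configs now carry a DataFilesPatternsDict of unresolved patterns, and the patterns are only expanded into a concrete DataFilesDict when a config is instantiated, via BuilderConfig._resolve_data_files in _create_builder_config. A minimal end-to-end sketch of the API these patches add, assuming files matching "data/train-*" exist under the current directory (the paths are illustrative):

    # Sketch of the lazy data-files flow introduced by this series; assumes
    # files matching "data/train-*" exist under the current directory.
    from datasets import DownloadConfig
    from datasets.data_files import DataFilesDict, DataFilesPatternsDict

    # Building patterns is cheap: no filesystem or Hub call happens here.
    patterns = DataFilesPatternsDict.from_patterns({"train": ["data/train-*"]})

    # Resolution is deferred until the config is actually used; this is what
    # BuilderConfig._resolve_data_files does under the hood.
    data_files = patterns.resolve(base_path=".", download_config=DownloadConfig())
    assert isinstance(data_files, DataFilesDict)
    print(list(data_files["train"]))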