diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 75fee776e5a..793c6ed8115 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -22,6 +22,9 @@ from .utils.py_utils import glob_pattern_to_regex, string_to_dict +SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]] + + SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN) @@ -361,6 +364,7 @@ def resolve_pattern( base_path (str): Base path to use when resolving relative paths. allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions). For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"] + download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters. Returns: List[str]: List of paths or URLs to the local or remote files that match the patterns. """ @@ -516,17 +520,17 @@ def get_metadata_patterns( def _get_single_origin_metadata( data_file: str, download_config: Optional[DownloadConfig] = None, -) -> Tuple[str]: +) -> SingleOriginMetadata: data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config) fs, *_ = url_to_fs(data_file, **storage_options) if isinstance(fs, HfFileSystem): resolved_path = fs.resolve_path(data_file) - return (resolved_path.repo_id, resolved_path.revision) + return resolved_path.repo_id, resolved_path.revision elif isinstance(fs, HTTPFileSystem) and data_file.startswith(config.HF_ENDPOINT): hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token) data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1) resolved_path = hffs.resolve_path(data_file) - return (resolved_path.repo_id, resolved_path.revision) + return resolved_path.repo_id, resolved_path.revision info = fs.info(data_file) # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime for key in ["ETag", "etag", "mtime"]: @@ -539,7 +543,7 @@ def _get_origin_metadata( data_files: List[str], download_config: Optional[DownloadConfig] = None, max_workers: Optional[int] = None, -) -> Tuple[str]: +) -> List[SingleOriginMetadata]: max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS return thread_map( partial(_get_single_origin_metadata, download_config=download_config), @@ -555,11 +559,11 @@ def _get_origin_metadata( class DataFilesList(List[str]): """ List of data files (absolute local paths or URLs). - It has two construction methods given the user's data files patterns : + It has two construction methods given the user's data files patterns: - ``from_hf_repo``: resolve patterns inside a dataset repository - ``from_local_or_remote``: resolve patterns from a local path - Moreover DataFilesList has an additional attribute ``origin_metadata``. + Moreover, DataFilesList has an additional attribute ``origin_metadata``. It can store: - the last modified time of local files - ETag of remote files @@ -570,11 +574,11 @@ class DataFilesList(List[str]): This is useful for caching Dataset objects that are obtained from a list of data files. """ - def __init__(self, data_files: List[str], origin_metadata: List[Tuple[str]]): + def __init__(self, data_files: List[str], origin_metadata: List[SingleOriginMetadata]) -> None: super().__init__(data_files) self.origin_metadata = origin_metadata - def __add__(self, other): + def __add__(self, other: "DataFilesList") -> "DataFilesList": return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata) @classmethod @@ -646,9 +650,9 @@ class DataFilesDict(Dict[str, DataFilesList]): - ``from_hf_repo``: resolve patterns inside a dataset repository - ``from_local_or_remote``: resolve patterns from a local path - Moreover each list is a DataFilesList. It is possible to hash the dictionary + Moreover, each list is a DataFilesList. It is possible to hash the dictionary and get a different hash if and only if at least one file changed. - For more info, see ``DataFilesList``. + For more info, see [`DataFilesList`]. This is useful for caching Dataset objects that are obtained from a list of data files. @@ -666,14 +670,14 @@ def from_local_or_remote( out = cls() for key, patterns_for_key in patterns.items(): out[key] = ( - DataFilesList.from_local_or_remote( + patterns_for_key + if isinstance(patterns_for_key, DataFilesList) + else DataFilesList.from_local_or_remote( patterns_for_key, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config, ) - if not isinstance(patterns_for_key, DataFilesList) - else patterns_for_key ) return out @@ -689,15 +693,15 @@ def from_hf_repo( out = cls() for key, patterns_for_key in patterns.items(): out[key] = ( - DataFilesList.from_hf_repo( + patterns_for_key + if isinstance(patterns_for_key, DataFilesList) + else DataFilesList.from_hf_repo( patterns_for_key, dataset_info=dataset_info, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config, ) - if not isinstance(patterns_for_key, DataFilesList) - else patterns_for_key ) return out @@ -712,14 +716,14 @@ def from_patterns( out = cls() for key, patterns_for_key in patterns.items(): out[key] = ( - DataFilesList.from_patterns( + patterns_for_key + if isinstance(patterns_for_key, DataFilesList) + else DataFilesList.from_patterns( patterns_for_key, base_path=base_path, allowed_extensions=allowed_extensions, download_config=download_config, ) - if not isinstance(patterns_for_key, DataFilesList) - else patterns_for_key ) return out @@ -751,7 +755,7 @@ def __add__(self, other): @classmethod def from_patterns( cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None - ) -> "DataFilesPatternsDict": + ) -> "DataFilesPatternsList": return cls(patterns, [allowed_extensions] * len(patterns)) def resolve( @@ -777,7 +781,7 @@ def resolve( origin_metadata = _get_origin_metadata(data_files, download_config=download_config) return DataFilesList(data_files, origin_metadata) - def filter_extensions(self, extensions: List[str]) -> "DataFilesList": + def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList": return DataFilesPatternsList( self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions] ) @@ -795,12 +799,12 @@ def from_patterns( out = cls() for key, patterns_for_key in patterns.items(): out[key] = ( - DataFilesPatternsList.from_patterns( + patterns_for_key + if isinstance(patterns_for_key, DataFilesPatternsList) + else DataFilesPatternsList.from_patterns( patterns_for_key, allowed_extensions=allowed_extensions, ) - if not isinstance(patterns_for_key, DataFilesPatternsList) - else patterns_for_key ) return out