
Commit 3b66daa

Fix wrong type hints in data_files (#6910)

* Fix wrong type hints in data_files
* Add missing type hints in DataFilesList
* Inverse condition with negation
* Fix type hint of origin metadata
* Add missing arg to resolve_pattern docstring
* Minor fix docstrings
* Remove TypeAlias

1 parent 60d21ef commit 3b66daa

File tree

1 file changed (+28, -24 lines)

src/datasets/data_files.py

Lines changed: 28 additions & 24 deletions
@@ -22,6 +22,9 @@
 from .utils.py_utils import glob_pattern_to_regex, string_to_dict


+SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]
+
+
 SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)


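The three members of the new SingleOriginMetadata union correspond to the three ways _get_single_origin_metadata (further down in this diff) can return. A minimal sketch with invented values:

from typing import Tuple, Union

SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]

# Invented values, one per member of the union:
hf_file: SingleOriginMetadata = ("username/my_dataset", "main")  # (repo_id, revision) for Hub files
other_file: SingleOriginMetadata = ("1715000000.0",)             # a single ETag or mtime string
no_metadata: SingleOriginMetadata = ()                           # nothing could be determined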
@@ -361,6 +364,7 @@ def resolve_pattern(
         base_path (str): Base path to use when resolving relative paths.
         allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
             For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
+        download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.

     Returns:
         List[str]: List of paths or URLs to the local or remote files that match the patterns.
     """
@@ -516,17 +520,17 @@ def get_metadata_patterns(
 def _get_single_origin_metadata(
     data_file: str,
     download_config: Optional[DownloadConfig] = None,
-) -> Tuple[str]:
+) -> SingleOriginMetadata:
     data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
     fs, *_ = url_to_fs(data_file, **storage_options)
     if isinstance(fs, HfFileSystem):
         resolved_path = fs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     elif isinstance(fs, HTTPFileSystem) and data_file.startswith(config.HF_ENDPOINT):
         hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
         data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
         resolved_path = hffs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     info = fs.info(data_file)
     # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
     for key in ["ETag", "etag", "mtime"]:
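A usage sketch of this private helper (paths invented); with the corrected hint, a type checker now knows the result can have zero, one, or two elements rather than exactly one:

from datasets.data_files import _get_single_origin_metadata

# Invented paths, for illustration only
meta = _get_single_origin_metadata("hf://datasets/username/my_dataset@main/train.csv")
# e.g. ("username/my_dataset", "main") for a file on the Hub

meta = _get_single_origin_metadata("/tmp/train.csv")
# e.g. ("1715000000.0",) from the local mtime, or () if no metadata is found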
@@ -539,7 +543,7 @@ def _get_origin_metadata(
     data_files: List[str],
     download_config: Optional[DownloadConfig] = None,
     max_workers: Optional[int] = None,
-) -> Tuple[str]:
+) -> List[SingleOriginMetadata]:
     max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
     return thread_map(
         partial(_get_single_origin_metadata, download_config=download_config),
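thread_map applies the function to each file in a thread pool and collects one result per input, which is why List[SingleOriginMetadata] is the right hint here. A self-contained sketch of the same pattern, with a made-up helper in place of _get_single_origin_metadata:

from functools import partial

from tqdm.contrib.concurrent import thread_map

def origin_of(path: str, revision: str = "main") -> tuple:  # hypothetical stand-in
    return (path, revision)

files = ["a.csv", "b.csv"]  # invented inputs
metadata = thread_map(partial(origin_of, revision="main"), files, max_workers=2)
assert metadata == [("a.csv", "main"), ("b.csv", "main")]  # a list, one tuple per file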
@@ -555,11 +559,11 @@ def _get_origin_metadata(
 class DataFilesList(List[str]):
     """
     List of data files (absolute local paths or URLs).
-    It has two construction methods given the user's data files patterns :
+    It has two construction methods given the user's data files patterns:
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path

-    Moreover DataFilesList has an additional attribute ``origin_metadata``.
+    Moreover, DataFilesList has an additional attribute ``origin_metadata``.
     It can store:
     - the last modified time of local files
     - ETag of remote files
@@ -570,11 +574,11 @@ class DataFilesList(List[str]):
     This is useful for caching Dataset objects that are obtained from a list of data files.
     """

-    def __init__(self, data_files: List[str], origin_metadata: List[Tuple[str]]):
+    def __init__(self, data_files: List[str], origin_metadata: List[SingleOriginMetadata]) -> None:
         super().__init__(data_files)
         self.origin_metadata = origin_metadata

-    def __add__(self, other):
+    def __add__(self, other: "DataFilesList") -> "DataFilesList":
         return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)

     @classmethod
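A small sketch (paths and metadata invented) of what the annotated __init__ and __add__ now promise: adding two DataFilesList objects concatenates both the paths and their origin_metadata:

from datasets.data_files import DataFilesList

train = DataFilesList(["train_0.csv", "train_1.csv"], [("user/ds", "main"), ("user/ds", "main")])
extra = DataFilesList(["train_2.csv"], [("1715000000.0",)])

merged = train + extra  # __add__ returns a DataFilesList, not a plain list
assert list(merged) == ["train_0.csv", "train_1.csv", "train_2.csv"]
assert len(merged.origin_metadata) == 3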
@@ -646,9 +650,9 @@ class DataFilesDict(Dict[str, DataFilesList]):
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path

-    Moreover each list is a DataFilesList. It is possible to hash the dictionary
+    Moreover, each list is a DataFilesList. It is possible to hash the dictionary
     and get a different hash if and only if at least one file changed.
-    For more info, see ``DataFilesList``.
+    For more info, see [`DataFilesList`].

     This is useful for caching Dataset objects that are obtained from a list of data files.
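The hashing behavior this docstring describes can be sketched as below, assuming datasets' Hasher utility applies here; the dictionary contents are invented:

from datasets.data_files import DataFilesDict, DataFilesList
from datasets.fingerprint import Hasher

d1 = DataFilesDict({"train": DataFilesList(["a.csv"], [("etag-1",)])})
d2 = DataFilesDict({"train": DataFilesList(["a.csv"], [("etag-2",)])})  # same path, new ETag

assert Hasher.hash(d1) != Hasher.hash(d2)  # hash differs iff at least one file changed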
@@ -666,14 +670,14 @@ def from_local_or_remote(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_local_or_remote(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_local_or_remote(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out

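This rewrite (and the two identical ones in the next hunks) only inverts the conditional expression so the pass-through case leads and the negation disappears; behavior is unchanged. Reduced to a generic, runnable sketch with stand-in names:

class Resolved(list):  # stand-in for DataFilesList, for illustration
    pass

def build(patterns):  # stand-in for DataFilesList.from_local_or_remote
    return Resolved(patterns)

x = Resolved(["a.csv"])
# Before: value = build(x) if not isinstance(x, Resolved) else x
value = x if isinstance(x, Resolved) else build(x)  # after: intent is visible up front
assert value is x  # already-resolved inputs are reused, not rebuilt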
@@ -689,15 +693,15 @@ def from_hf_repo(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_hf_repo(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_hf_repo(
                     patterns_for_key,
                     dataset_info=dataset_info,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out

@@ -712,14 +716,14 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_patterns(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out

@@ -751,7 +755,7 @@ def __add__(self, other):
     @classmethod
     def from_patterns(
         cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
-    ) -> "DataFilesPatternsDict":
+    ) -> "DataFilesPatternsList":
         return cls(patterns, [allowed_extensions] * len(patterns))

     def resolve(
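A sketch of why the old hint was wrong: from_patterns is a classmethod of DataFilesPatternsList, so cls(...) builds the list class, never DataFilesPatternsDict (the pattern below is invented):

from datasets.data_files import DataFilesPatternsList

pl = DataFilesPatternsList.from_patterns(["data/*.csv"], allowed_extensions=[".csv"])
assert isinstance(pl, DataFilesPatternsList)  # not a DataFilesPatternsDict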
@@ -777,7 +781,7 @@ def resolve(
         origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
         return DataFilesList(data_files, origin_metadata)

-    def filter_extensions(self, extensions: List[str]) -> "DataFilesList":
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
         return DataFilesPatternsList(
             self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
         )
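A hedged sketch of the distinction the corrected hint encodes: filter_extensions stays in the lazy pattern world, and only resolve() materializes a DataFilesList (the patterns and extensions are invented):

from datasets.data_files import DataFilesPatternsList

patterns = DataFilesPatternsList(["data/*"], [[".csv"]])
filtered = patterns.filter_extensions([".json"])    # still unresolved patterns
assert isinstance(filtered, DataFilesPatternsList)  # hence the corrected return hint
# filtered.resolve(base_path=".") would yield a DataFilesList, per the hunk above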
@@ -795,12 +799,12 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesPatternsList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesPatternsList)
+                else DataFilesPatternsList.from_patterns(
                     patterns_for_key,
                     allowed_extensions=allowed_extensions,
                 )
-                if not isinstance(patterns_for_key, DataFilesPatternsList)
-                else patterns_for_key
             )
         return out
