@@ -22,6 +22,9 @@
 from .utils.py_utils import glob_pattern_to_regex, string_to_dict
 
 
+SingleOriginMetadata = Union[Tuple[str, str], Tuple[str], Tuple[()]]
+
+
 SANITIZED_DEFAULT_SPLIT = str(Split.TRAIN)
 
 
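Note: `SingleOriginMetadata` names the three tuple shapes that origin metadata can take; a minimal sketch with illustrative values:

    meta: SingleOriginMetadata = ("user/repo", "abc123")  # (repo_id, revision) for a Hub file
    meta = ("5d41402abc4b2a76b9719d911017c592",)  # (etag,) or (mtime,): a single fingerprint
    meta = ()  # no fingerprint available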
@@ -361,6 +364,7 @@ def resolve_pattern(
         base_path (str): Base path to use when resolving relative paths.
         allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
             For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
+        download_config ([`DownloadConfig`], *optional*): Specific download configuration parameters.
     Returns:
         List[str]: List of paths or URLs to the local or remote files that match the patterns.
     """
@@ -516,17 +520,17 @@ def get_metadata_patterns(
 def _get_single_origin_metadata(
     data_file: str,
     download_config: Optional[DownloadConfig] = None,
-) -> Tuple[str]:
+) -> SingleOriginMetadata:
     data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
     fs, *_ = url_to_fs(data_file, **storage_options)
     if isinstance(fs, HfFileSystem):
         resolved_path = fs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     elif isinstance(fs, HTTPFileSystem) and data_file.startswith(config.HF_ENDPOINT):
         hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=download_config.token)
         data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
         resolved_path = hffs.resolve_path(data_file)
-        return (resolved_path.repo_id, resolved_path.revision)
+        return resolved_path.repo_id, resolved_path.revision
     info = fs.info(data_file)
     # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime
     for key in ["ETag", "etag", "mtime"]:
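The two `return` changes only drop redundant parentheses; a tuple is returned either way. A hedged usage sketch (paths and values illustrative):

    _get_single_origin_metadata("hf://datasets/user/repo@main/train.csv")  # -> ("user/repo", "main")
    _get_single_origin_metadata("/tmp/train.csv")  # -> e.g. ("1700000000.0",) via the mtime fallback below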
@@ -539,7 +543,7 @@ def _get_origin_metadata(
     data_files: List[str],
     download_config: Optional[DownloadConfig] = None,
     max_workers: Optional[int] = None,
-) -> Tuple[str]:
+) -> List[SingleOriginMetadata]:
     max_workers = max_workers if max_workers is not None else config.HF_DATASETS_MULTITHREADING_MAX_WORKERS
     return thread_map(
         partial(_get_single_origin_metadata, download_config=download_config),
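The new `List[SingleOriginMetadata]` annotation matches what `thread_map` returns: one tuple per input file, equivalent to this serial sketch:

    results = [_get_single_origin_metadata(f, download_config=download_config) for f in data_files]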
@@ -555,11 +559,11 @@ def _get_origin_metadata(
 class DataFilesList(List[str]):
     """
     List of data files (absolute local paths or URLs).
-    It has two construction methods given the user's data files patterns :
+    It has two construction methods given the user's data files patterns:
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path
 
-    Moreover DataFilesList has an additional attribute ``origin_metadata``.
+    Moreover, DataFilesList has an additional attribute ``origin_metadata``.
     It can store:
     - the last modified time of local files
     - ETag of remote files
@@ -570,11 +574,11 @@ class DataFilesList(List[str]):
     This is useful for caching Dataset objects that are obtained from a list of data files.
     """
 
-    def __init__(self, data_files: List[str], origin_metadata: List[Tuple[str]]):
+    def __init__(self, data_files: List[str], origin_metadata: List[SingleOriginMetadata]) -> None:
         super().__init__(data_files)
         self.origin_metadata = origin_metadata
 
-    def __add__(self, other):
+    def __add__(self, other: "DataFilesList") -> "DataFilesList":
         return DataFilesList([*self, *other], self.origin_metadata + other.origin_metadata)
 
     @classmethod
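With the typed `__add__`, concatenation keeps files and metadata aligned; a small sketch with illustrative values:

    a = DataFilesList(["a.csv"], [("user/repo", "rev1")])
    b = DataFilesList(["b.csv"], [("user/repo", "rev2")])
    c = a + b  # list: ["a.csv", "b.csv"]; c.origin_metadata: [("user/repo", "rev1"), ("user/repo", "rev2")]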
@@ -646,9 +650,9 @@ class DataFilesDict(Dict[str, DataFilesList]):
     - ``from_hf_repo``: resolve patterns inside a dataset repository
     - ``from_local_or_remote``: resolve patterns from a local path
 
-    Moreover each list is a DataFilesList. It is possible to hash the dictionary
+    Moreover, each list is a DataFilesList. It is possible to hash the dictionary
     and get a different hash if and only if at least one file changed.
-    For more info, see ``DataFilesList``.
+    For more info, see [`DataFilesList`].
 
     This is useful for caching Dataset objects that are obtained from a list of data files.
 
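For the hashing behavior the docstring describes, a hedged sketch assuming datasets' `Hasher` entry point (not part of this diff):

    from datasets.fingerprint import Hasher
    # the hash changes iff at least one file's origin metadata (revision/ETag/mtime) changed
    fingerprint = Hasher.hash(data_files_dict)  # data_files_dict: a DataFilesDict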
@@ -666,14 +670,14 @@ def from_local_or_remote(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_local_or_remote(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_local_or_remote(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
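The reordered conditional is behavior-preserving; it reads pass-through-first and lets callers supply an already-resolved list. Sketch (names illustrative):

    DataFilesDict.from_local_or_remote({"train": ["data/*.csv"]})  # patterns get resolved
    DataFilesDict.from_local_or_remote({"train": DataFilesList(["data/a.csv"], [()])})  # reused as-is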
@@ -689,15 +693,15 @@ def from_hf_repo(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_hf_repo(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_hf_repo(
                     patterns_for_key,
                     dataset_info=dataset_info,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
@@ -712,14 +716,14 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesList)
+                else DataFilesList.from_patterns(
                     patterns_for_key,
                     base_path=base_path,
                     allowed_extensions=allowed_extensions,
                     download_config=download_config,
                 )
-                if not isinstance(patterns_for_key, DataFilesList)
-                else patterns_for_key
             )
         return out
 
@@ -751,7 +755,7 @@ def __add__(self, other):
     @classmethod
     def from_patterns(
         cls, patterns: List[str], allowed_extensions: Optional[List[str]] = None
-    ) -> "DataFilesPatternsDict":
+    ) -> "DataFilesPatternsList":
         return cls(patterns, [allowed_extensions] * len(patterns))
 
     def resolve(
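The corrected annotation matches what `cls(...)` builds here; e.g.:

    pl = DataFilesPatternsList.from_patterns(["*.csv"])  # a DataFilesPatternsList, not a DataFilesPatternsDict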
@@ -777,7 +781,7 @@ def resolve(
         origin_metadata = _get_origin_metadata(data_files, download_config=download_config)
         return DataFilesList(data_files, origin_metadata)
 
-    def filter_extensions(self, extensions: List[str]) -> "DataFilesList":
+    def filter_extensions(self, extensions: List[str]) -> "DataFilesPatternsList":
         return DataFilesPatternsList(
             self, [allowed_extensions + extensions for allowed_extensions in self.allowed_extensions]
         )
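Likewise, `filter_extensions` returns a new patterns list rather than a resolved `DataFilesList`, so it can be chained before resolution; a sketch assuming `resolve` takes a base path, per the surrounding code:

    pl.filter_extensions([".csv"]).resolve(base_path=".")  # -> DataFilesList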
@@ -795,12 +799,12 @@ def from_patterns(
         out = cls()
         for key, patterns_for_key in patterns.items():
             out[key] = (
-                DataFilesPatternsList.from_patterns(
+                patterns_for_key
+                if isinstance(patterns_for_key, DataFilesPatternsList)
+                else DataFilesPatternsList.from_patterns(
                     patterns_for_key,
                     allowed_extensions=allowed_extensions,
                 )
-                if not isinstance(patterns_for_key, DataFilesPatternsList)
-                else patterns_for_key
             )
         return out
 