77import warnings
88import zipfile
99from abc import ABC , abstractmethod
10+ from typing import TYPE_CHECKING , Optional , Union
1011
1112from .. import config
1213from .filelock import FileLock
1314
1415
16+ if TYPE_CHECKING :
17+ import pathlib
18+
19+
1520class ExtractManager :
16- def __init__ (self , cache_dir = None ):
21+ def __init__ (self , cache_dir : Optional [ str ] = None ):
1722 self .extract_dir = (
1823 os .path .join (cache_dir , config .EXTRACTED_DATASETS_DIR ) if cache_dir else config .EXTRACTED_DATASETS_PATH
1924 )
2025 self .extractor = Extractor
2126
22- def _get_output_path (self , path ) :
27+ def _get_output_path (self , path : str ) -> str :
2328 from .file_utils import hash_url_to_filename
2429
2530 # Path where we extract compressed archives
2631 # We extract in the cache dir, and get the extracted path name by hashing the original path"
2732 abs_path = os .path .abspath (path )
2833 return os .path .join (self .extract_dir , hash_url_to_filename (abs_path ))
2934
30- def _do_extract (self , output_path , force_extract ) :
35+ def _do_extract (self , output_path : str , force_extract : bool ) -> bool :
3136 return force_extract or (
3237 not os .path .isfile (output_path ) and not (os .path .isdir (output_path ) and os .listdir (output_path ))
3338 )
3439
35- def extract (self , input_path , force_extract = False ):
40+ def extract (self , input_path : str , force_extract : bool = False ) -> str :
3641 extractor_format = self .extractor .infer_extractor_format (input_path )
3742 if not extractor_format :
3843 return input_path
@@ -45,25 +50,25 @@ def extract(self, input_path, force_extract=False):
4550class BaseExtractor (ABC ):
4651 @classmethod
4752 @abstractmethod
48- def is_extractable (cls , path : str , ** kwargs ) -> bool :
53+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ] , ** kwargs ) -> bool :
4954 ...
5055
5156 @staticmethod
5257 @abstractmethod
53- def extract (input_path : str , output_path : str ) -> None :
58+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
5459 ...
5560
5661
5762class MagicNumberBaseExtractor (BaseExtractor , ABC ):
5863 magic_number = b""
5964
6065 @staticmethod
61- def read_magic_number (path : str , magic_number_length : int ):
66+ def read_magic_number (path : Union [ "pathlib.Path" , str ] , magic_number_length : int ):
6267 with open (path , "rb" ) as f :
6368 return f .read (magic_number_length )
6469
6570 @classmethod
66- def is_extractable (cls , path : str , magic_number : bytes = b"" ) -> bool :
71+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ] , magic_number : bytes = b"" ) -> bool :
6772 if not magic_number :
6873 try :
6974 magic_number = cls .read_magic_number (path , len (cls .magic_number ))
@@ -74,11 +79,11 @@ def is_extractable(cls, path: str, magic_number: bytes = b"") -> bool:
7479
7580class TarExtractor (BaseExtractor ):
7681 @classmethod
77- def is_extractable (cls , path : str , ** kwargs ) -> bool :
82+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ] , ** kwargs ) -> bool :
7883 return tarfile .is_tarfile (path )
7984
8085 @staticmethod
81- def extract (input_path : str , output_path : str ) -> None :
86+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
8287 os .makedirs (output_path , exist_ok = True )
8388 tar_file = tarfile .open (input_path )
8489 tar_file .extractall (output_path )
@@ -89,19 +94,19 @@ class GzipExtractor(MagicNumberBaseExtractor):
8994 magic_number = b"\x1F \x8B "
9095
9196 @staticmethod
92- def extract (input_path : str , output_path : str ) -> None :
97+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
9398 with gzip .open (input_path , "rb" ) as gzip_file :
9499 with open (output_path , "wb" ) as extracted_file :
95100 shutil .copyfileobj (gzip_file , extracted_file )
96101
97102
98103class ZipExtractor (BaseExtractor ):
99104 @classmethod
100- def is_extractable (cls , path : str , ** kwargs ) -> bool :
105+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ] , ** kwargs ) -> bool :
101106 return zipfile .is_zipfile (path )
102107
103108 @staticmethod
104- def extract (input_path : str , output_path : str ) -> None :
109+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
105110 os .makedirs (output_path , exist_ok = True )
106111 with zipfile .ZipFile (input_path , "r" ) as zip_file :
107112 zip_file .extractall (output_path )
@@ -112,7 +117,7 @@ class XzExtractor(MagicNumberBaseExtractor):
112117 magic_number = b"\xFD \x37 \x7A \x58 \x5A \x00 "
113118
114119 @staticmethod
115- def extract (input_path : str , output_path : str ) -> None :
120+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
116121 with lzma .open (input_path ) as compressed_file :
117122 with open (output_path , "wb" ) as extracted_file :
118123 shutil .copyfileobj (compressed_file , extracted_file )
@@ -123,14 +128,14 @@ class RarExtractor(BaseExtractor):
123128 RAR5_ID = b"Rar!\x1a \x07 \x01 \x00 "
124129
125130 @classmethod
126- def is_extractable (cls , path : str , ** kwargs ) -> bool :
131+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ] , ** kwargs ) -> bool :
127132 """https://github.com/markokr/rarfile/blob/master/rarfile.py"""
128133 with open (path , "rb" ) as f :
129134 magic_number = f .read (len (cls .RAR5_ID ))
130135 return magic_number == cls .RAR5_ID or magic_number .startswith (cls .RAR_ID )
131136
132137 @staticmethod
133- def extract (input_path : str , output_path : str ) -> None :
138+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
134139 if not config .RARFILE_AVAILABLE :
135140 raise OSError ("Please pip install rarfile" )
136141 import rarfile
@@ -145,7 +150,7 @@ class ZstdExtractor(MagicNumberBaseExtractor):
145150 magic_number = b"\x28 \xb5 \x2F \xFD "
146151
147152 @staticmethod
148- def extract (input_path : str , output_path : str ) -> None :
153+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
149154 if not config .ZSTANDARD_AVAILABLE :
150155 raise OSError ("Please pip install zstandard" )
151156 import zstandard as zstd
@@ -159,7 +164,7 @@ class Bzip2Extractor(MagicNumberBaseExtractor):
159164 magic_number = b"\x42 \x5A \x68 "
160165
161166 @staticmethod
162- def extract (input_path : str , output_path : str ) -> None :
167+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
163168 with bz2 .open (input_path , "rb" ) as compressed_file :
164169 with open (output_path , "wb" ) as extracted_file :
165170 shutil .copyfileobj (compressed_file , extracted_file )
@@ -169,7 +174,7 @@ class SevenZipExtractor(MagicNumberBaseExtractor):
169174 magic_number = b"\x37 \x7A \xBC \xAF \x27 \x1C "
170175
171176 @staticmethod
172- def extract (input_path : str , output_path : str ) -> None :
177+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
173178 if not config .PY7ZR_AVAILABLE :
174179 raise OSError ("Please pip install py7zr" )
175180 import py7zr
@@ -183,7 +188,7 @@ class Lz4Extractor(MagicNumberBaseExtractor):
183188 magic_number = b"\x04 \x22 \x4D \x18 "
184189
185190 @staticmethod
186- def extract (input_path : str , output_path : str ) -> None :
191+ def extract (input_path : Union [ "pathlib.Path" , str ] , output_path : Union [ "pathlib.Path" , str ] ) -> None :
187192 if not config .LZ4_AVAILABLE :
188193 raise OSError ("Please pip install lz4" )
189194 import lz4 .frame
@@ -219,14 +224,14 @@ def _get_magic_number_max_length(cls):
219224 return magic_number_max_length
220225
221226 @staticmethod
222- def _read_magic_number (path : str , magic_number_length : int ):
227+ def _read_magic_number (path : Union [ "pathlib.Path" , str ] , magic_number_length : int ):
223228 try :
224229 return MagicNumberBaseExtractor .read_magic_number (path , magic_number_length = magic_number_length )
225230 except OSError :
226231 return b""
227232
228233 @classmethod
229- def is_extractable (cls , path , return_extractor = False ):
234+ def is_extractable (cls , path : Union [ "pathlib.Path" , str ], return_extractor : bool = False ) -> bool :
230235 warnings .warn (
231236 "Method 'is_extractable' was deprecated in version 2.4.0 and will be removed in 3.0.0. "
232237 "Use 'infer_extractor_format' instead." ,
@@ -238,17 +243,23 @@ def is_extractable(cls, path, return_extractor=False):
238243 return False if not return_extractor else (False , None )
239244
240245 @classmethod
241- def infer_extractor_format (cls , path ) :
246+ def infer_extractor_format (cls , path : Union [ "pathlib.Path" , str ]) -> str :
242247 magic_number_max_length = cls ._get_magic_number_max_length ()
243248 magic_number = cls ._read_magic_number (path , magic_number_max_length )
244249 for extractor_format , extractor in cls .extractors .items ():
245250 if extractor .is_extractable (path , magic_number = magic_number ):
246251 return extractor_format
247252
248253 @classmethod
249- def extract (cls , input_path , output_path , extractor_format = None , extractor = "deprecated" ):
254+ def extract (
255+ cls ,
256+ input_path : Union ["pathlib.Path" , str ],
257+ output_path : Union ["pathlib.Path" , str ],
258+ extractor_format : Optional [str ] = None ,
259+ extractor : Optional [BaseExtractor ] = "deprecated" ,
260+ ) -> None :
250261 # Prevent parallel extractions
251- lock_path = input_path + ".lock"
262+ lock_path = str ( input_path ) + ".lock"
252263 with FileLock (lock_path ):
253264 shutil .rmtree (output_path , ignore_errors = True )
254265 os .makedirs (os .path .dirname (output_path ), exist_ok = True )
0 commit comments