Skip to content

Commit 4fb3ed0

Browse files
Refactor base extractors (#4690)
* Implement BaseExtractor * Refactor base extractors * Refactor zipfile import * Improve performance of test_extractor * Allow passing magic number to is_extractable * Read magic number only once * Refactor Extractor to use extractor_format * Update test_extractor * Make ExtractManager use extractor_format * Refactor class hierarchy
1 parent c2a06b5 commit 4fb3ed0

File tree

2 files changed

+136
-111
lines changed

2 files changed

+136
-111
lines changed

src/datasets/utils/extract.py

Lines changed: 133 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import lzma
44
import os
55
import shutil
6-
import struct
76
import tarfile
8-
from zipfile import ZipFile
9-
from zipfile import is_zipfile as _is_zipfile
7+
import warnings
8+
import zipfile
9+
from abc import ABC, abstractmethod
1010

1111
from .. import config
1212
from .filelock import FileLock
@@ -33,96 +33,104 @@ def _do_extract(self, output_path, force_extract):
3333
)
3434

3535
def extract(self, input_path, force_extract=False):
36-
is_extractable, extractor = self.extractor.is_extractable(input_path, return_extractor=True)
37-
if not is_extractable:
36+
extractor_format = self.extractor.infer_extractor_format(input_path)
37+
if not extractor_format:
3838
return input_path
3939
output_path = self._get_output_path(input_path)
4040
if self._do_extract(output_path, force_extract):
41-
self.extractor.extract(input_path, output_path, extractor=extractor)
41+
self.extractor.extract(input_path, output_path, extractor_format)
4242
return output_path
4343

4444

45-
class TarExtractor:
45+
class BaseExtractor(ABC):
46+
@classmethod
47+
@abstractmethod
48+
def is_extractable(cls, path: str, **kwargs) -> bool:
49+
...
50+
51+
@staticmethod
52+
@abstractmethod
53+
def extract(input_path: str, output_path: str) -> None:
54+
...
55+
56+
57+
class MagicNumberBaseExtractor(BaseExtractor, ABC):
58+
magic_number = b""
59+
4660
@staticmethod
47-
def is_extractable(path):
61+
def read_magic_number(path: str, magic_number_length: int):
62+
with open(path, "rb") as f:
63+
return f.read(magic_number_length)
64+
65+
@classmethod
66+
def is_extractable(cls, path: str, magic_number: bytes = b"") -> bool:
67+
if not magic_number:
68+
try:
69+
magic_number = cls.read_magic_number(path, len(cls.magic_number))
70+
except OSError:
71+
return False
72+
return magic_number.startswith(cls.magic_number)
73+
74+
75+
class TarExtractor(BaseExtractor):
76+
@classmethod
77+
def is_extractable(cls, path: str, **kwargs) -> bool:
4878
return tarfile.is_tarfile(path)
4979

5080
@staticmethod
51-
def extract(input_path, output_path):
81+
def extract(input_path: str, output_path: str) -> None:
5282
os.makedirs(output_path, exist_ok=True)
5383
tar_file = tarfile.open(input_path)
5484
tar_file.extractall(output_path)
5585
tar_file.close()
5686

5787

58-
class GzipExtractor:
59-
@staticmethod
60-
def is_extractable(path: str) -> bool:
61-
"""from https://stackoverflow.com/a/60634210"""
62-
with gzip.open(path, "r") as fh:
63-
try:
64-
fh.read(1)
65-
return True
66-
except OSError:
67-
return False
88+
class GzipExtractor(MagicNumberBaseExtractor):
89+
magic_number = b"\x1F\x8B"
6890

6991
@staticmethod
70-
def extract(input_path, output_path):
92+
def extract(input_path: str, output_path: str) -> None:
7193
with gzip.open(input_path, "rb") as gzip_file:
7294
with open(output_path, "wb") as extracted_file:
7395
shutil.copyfileobj(gzip_file, extracted_file)
7496

7597

76-
class ZipExtractor:
77-
@staticmethod
78-
def is_extractable(path):
79-
return _is_zipfile(path)
98+
class ZipExtractor(BaseExtractor):
99+
@classmethod
100+
def is_extractable(cls, path: str, **kwargs) -> bool:
101+
return zipfile.is_zipfile(path)
80102

81103
@staticmethod
82-
def extract(input_path, output_path):
104+
def extract(input_path: str, output_path: str) -> None:
83105
os.makedirs(output_path, exist_ok=True)
84-
with ZipFile(input_path, "r") as zip_file:
106+
with zipfile.ZipFile(input_path, "r") as zip_file:
85107
zip_file.extractall(output_path)
86108
zip_file.close()
87109

88110

89-
class XzExtractor:
90-
@staticmethod
91-
def is_extractable(path: str) -> bool:
92-
"""https://tukaani.org/xz/xz-file-format-1.0.4.txt"""
93-
with open(path, "rb") as f:
94-
try:
95-
header_magic_bytes = f.read(6)
96-
except OSError:
97-
return False
98-
if header_magic_bytes == b"\xfd7zXZ\x00":
99-
return True
100-
else:
101-
return False
111+
class XzExtractor(MagicNumberBaseExtractor):
112+
magic_number = b"\xFD\x37\x7A\x58\x5A\x00"
102113

103114
@staticmethod
104-
def extract(input_path, output_path):
115+
def extract(input_path: str, output_path: str) -> None:
105116
with lzma.open(input_path) as compressed_file:
106117
with open(output_path, "wb") as extracted_file:
107118
shutil.copyfileobj(compressed_file, extracted_file)
108119

109120

110-
class RarExtractor:
111-
@staticmethod
112-
def is_extractable(path: str) -> bool:
113-
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
114-
RAR_ID = b"Rar!\x1a\x07\x00"
115-
RAR5_ID = b"Rar!\x1a\x07\x01\x00"
121+
class RarExtractor(BaseExtractor):
122+
RAR_ID = b"Rar!\x1a\x07\x00"
123+
RAR5_ID = b"Rar!\x1a\x07\x01\x00"
116124

117-
with open(path, "rb", 1024) as fd:
118-
buf = fd.read(len(RAR5_ID))
119-
if buf.startswith(RAR_ID) or buf.startswith(RAR5_ID):
120-
return True
121-
else:
122-
return False
125+
@classmethod
126+
def is_extractable(cls, path: str, **kwargs) -> bool:
127+
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
128+
with open(path, "rb") as f:
129+
magic_number = f.read(len(cls.RAR5_ID))
130+
return magic_number == cls.RAR5_ID or magic_number.startswith(cls.RAR_ID)
123131

124132
@staticmethod
125-
def extract(input_path, output_path):
133+
def extract(input_path: str, output_path: str) -> None:
126134
if not config.RARFILE_AVAILABLE:
127135
raise OSError("Please pip install rarfile")
128136
import rarfile
@@ -133,22 +141,11 @@ def extract(input_path, output_path):
133141
rf.close()
134142

135143

136-
class ZstdExtractor:
137-
@staticmethod
138-
def is_extractable(path: str) -> bool:
139-
"""https://datatracker.ietf.org/doc/html/rfc8878
140-
141-
Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
142-
"""
143-
with open(path, "rb") as f:
144-
try:
145-
magic_number = f.read(4)
146-
except OSError:
147-
return False
148-
return True if magic_number == struct.pack("<I", 0xFD2FB528) else False
144+
class ZstdExtractor(MagicNumberBaseExtractor):
145+
magic_number = b"\x28\xb5\x2F\xFD"
149146

150147
@staticmethod
151-
def extract(input_path: str, output_path: str):
148+
def extract(input_path: str, output_path: str) -> None:
152149
if not config.ZSTANDARD_AVAILABLE:
153150
raise OSError("Please pip install zstandard")
154151
import zstandard as zstd
@@ -158,40 +155,21 @@ def extract(input_path: str, output_path: str):
158155
dctx.copy_stream(ifh, ofh)
159156

160157

161-
class Bzip2Extractor:
162-
@staticmethod
163-
def is_extractable(path: str) -> bool:
164-
with open(path, "rb") as f:
165-
try:
166-
header_magic_bytes = f.read(3)
167-
except OSError:
168-
return False
169-
if header_magic_bytes == b"BZh":
170-
return True
171-
else:
172-
return False
158+
class Bzip2Extractor(MagicNumberBaseExtractor):
159+
magic_number = b"\x42\x5A\x68"
173160

174161
@staticmethod
175-
def extract(input_path, output_path):
162+
def extract(input_path: str, output_path: str) -> None:
176163
with bz2.open(input_path, "rb") as compressed_file:
177164
with open(output_path, "wb") as extracted_file:
178165
shutil.copyfileobj(compressed_file, extracted_file)
179166

180167

181-
class SevenZipExtractor:
168+
class SevenZipExtractor(MagicNumberBaseExtractor):
182169
magic_number = b"\x37\x7A\xBC\xAF\x27\x1C"
183170

184-
@classmethod
185-
def is_extractable(cls, path):
186-
with open(path, "rb") as f:
187-
try:
188-
magic_number = f.read(len(cls.magic_number))
189-
except OSError:
190-
return False
191-
return True if magic_number == cls.magic_number else False
192-
193171
@staticmethod
194-
def extract(input_path: str, output_path: str):
172+
def extract(input_path: str, output_path: str) -> None:
195173
if not config.PY7ZR_AVAILABLE:
196174
raise OSError("Please pip install py7zr")
197175
import py7zr
@@ -203,33 +181,79 @@ def extract(input_path: str, output_path: str):
203181

204182
class Extractor:
205183
# Put zip file to the last, b/c it is possible wrongly detected as zip (I guess it means: as tar or gzip)
206-
extractors = [
207-
TarExtractor,
208-
GzipExtractor,
209-
ZipExtractor,
210-
XzExtractor,
211-
RarExtractor,
212-
ZstdExtractor,
213-
Bzip2Extractor,
214-
SevenZipExtractor,
215-
]
184+
extractors = {
185+
"tar": TarExtractor,
186+
"gzip": GzipExtractor,
187+
"zip": ZipExtractor,
188+
"xz": XzExtractor,
189+
"rar": RarExtractor,
190+
"zstd": ZstdExtractor,
191+
"bz2": Bzip2Extractor,
192+
"7z": SevenZipExtractor,
193+
}
194+
195+
@classmethod
196+
def _get_magic_number_max_length(cls):
197+
magic_number_max_length = 0
198+
for extractor in cls.extractors.values():
199+
if hasattr(extractor, "magic_number"):
200+
magic_number_length = len(extractor.magic_number)
201+
magic_number_max_length = (
202+
magic_number_length if magic_number_length > magic_number_max_length else magic_number_max_length
203+
)
204+
return magic_number_max_length
205+
206+
@staticmethod
207+
def _read_magic_number(path: str, magic_number_length: int):
208+
try:
209+
return MagicNumberBaseExtractor.read_magic_number(path, magic_number_length=magic_number_length)
210+
except OSError:
211+
return b""
216212

217213
@classmethod
218214
def is_extractable(cls, path, return_extractor=False):
219-
for extractor in cls.extractors:
220-
if extractor.is_extractable(path):
221-
return True if not return_extractor else (True, extractor)
215+
warnings.warn(
216+
"Method 'is_extractable' was deprecated in version 2.4.0 and will be removed in 3.0.0. "
217+
"Use 'infer_extractor_format' instead.",
218+
category=FutureWarning,
219+
)
220+
extractor_format = cls.infer_extractor_format(path)
221+
if extractor_format:
222+
return True if not return_extractor else (True, cls.extractors[extractor_format])
222223
return False if not return_extractor else (False, None)
223224

224225
@classmethod
225-
def extract(cls, input_path, output_path, extractor=None):
226+
def infer_extractor_format(cls, path):
227+
magic_number_max_length = cls._get_magic_number_max_length()
228+
magic_number = cls._read_magic_number(path, magic_number_max_length)
229+
for extractor_format, extractor in cls.extractors.items():
230+
if extractor.is_extractable(path, magic_number=magic_number):
231+
return extractor_format
232+
233+
@classmethod
234+
def extract(cls, input_path, output_path, extractor_format=None, extractor="deprecated"):
226235
# Prevent parallel extractions
227236
lock_path = input_path + ".lock"
228237
with FileLock(lock_path):
229238
shutil.rmtree(output_path, ignore_errors=True)
230239
os.makedirs(os.path.dirname(output_path), exist_ok=True)
231-
if extractor:
240+
if extractor_format or extractor != "deprecated":
241+
if extractor != "deprecated" or not isinstance(extractor_format, str): # passed as positional arg
242+
warnings.warn(
243+
"Parameter 'extractor' was deprecated in version 2.4.0 and will be removed in 3.0.0. "
244+
"Use 'extractor_format' instead.",
245+
category=FutureWarning,
246+
)
247+
extractor = extractor if extractor != "deprecated" else extractor_format
248+
else:
249+
extractor = cls.extractors[extractor_format]
232250
return extractor.extract(input_path, output_path)
233-
for extractor in cls.extractors:
234-
if extractor.is_extractable(input_path):
235-
return extractor.extract(input_path, output_path)
251+
else:
252+
warnings.warn(
253+
"Parameter 'extractor_format' was made required in version 2.4.0 and not passing it will raise an "
254+
"exception in 3.0.0.",
255+
category=FutureWarning,
256+
)
257+
for extractor in cls.extractors.values():
258+
if extractor.is_extractable(input_path):
259+
return extractor.extract(input_path, output_path)

tests/test_extract.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,10 @@ def test_extractor(
9797
reason += require_zstandard.kwargs["reason"]
9898
pytest.skip(reason)
9999
input_path = str(input_path)
100-
assert Extractor.is_extractable(input_path)
100+
extractor_format = Extractor.infer_extractor_format(input_path)
101+
assert extractor_format is not None
101102
output_path = tmp_path / ("extracted" if is_archive else "extracted.txt")
102-
Extractor.extract(input_path, output_path)
103+
Extractor.extract(input_path, output_path, extractor_format)
103104
if is_archive:
104105
assert output_path.is_dir()
105106
for file_path in output_path.iterdir():

0 commit comments

Comments
 (0)