Skip to content

Commit 7353a41

Browse files
Create ExtractManager (#2295)
* Reorder execution flow in cached_path * Extract function _extract from cached_path * Extract method for each extract type * Create an Extractor class for each file type * Rename extract method and input/output path params * Rename is_extractable method * Create generic Extractor.is_extractable * Create class attribute extractors * Move extract functionality to extract module * Fix Extractor.is_extractable * Create ExtractManager with all extract logic * Fix circular import * Fix typo * Remove unused class * Fix style * Fix issues after merge upstream master * Remove default os.makedirs and os.rmdir when not applicable * Create parent dirs of output_path * Minor refactoring of ExtractManager * Optimize Extractor.extract by returning specific extractor * Fix extract * Test extract and add gzip to test_cached_path_extract * Address requested changes
1 parent d7a7223 commit 7353a41

5 files changed

Lines changed: 254 additions & 151 deletions

File tree

src/datasets/utils/extract.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
import gzip
2+
import lzma
3+
import os
4+
import shutil
5+
import struct
6+
import tarfile
7+
from zipfile import ZipFile
8+
from zipfile import is_zipfile as _is_zipfile
9+
10+
from datasets import config
11+
from datasets.utils.filelock import FileLock
12+
13+
14+
class ExtractManager:
    """Decide whether a file needs extracting and where the result is stored."""

    def __init__(self, cache_dir=None):
        """Select the directory that receives extracted content.

        Args:
            cache_dir: Optional cache root. When truthy, extraction happens in
                its ``config.EXTRACTED_DATASETS_DIR`` sub-directory; otherwise
                ``config.EXTRACTED_DATASETS_PATH`` is used.
        """
        if cache_dir:
            self.extract_dir = os.path.join(cache_dir, config.EXTRACTED_DATASETS_DIR)
        else:
            self.extract_dir = config.EXTRACTED_DATASETS_PATH
        self.extractor = Extractor

    def _get_output_path(self, path):
        """Return the location inside ``self.extract_dir`` used for ``path``.

        The entry name is obtained by hashing the absolute form of ``path``.
        """
        # Imported here rather than at module level (file_utils imports this module).
        from datasets.utils.file_utils import hash_url_to_filename

        hashed_name = hash_url_to_filename(os.path.abspath(path))
        return os.path.join(self.extract_dir, hashed_name)

    def _do_extract(self, output_path, force_extract):
        """Tell whether extraction has to run for ``output_path``.

        Extraction is needed when it is forced, or when ``output_path`` is
        neither an existing file nor a directory with at least one entry.
        """
        if force_extract:
            return force_extract
        if os.path.isfile(output_path):
            return False
        already_populated = os.path.isdir(output_path) and os.listdir(output_path)
        return not already_populated

    def extract(self, input_path, force_extract=False):
        """Extract ``input_path`` when it is a supported archive.

        Args:
            input_path: Path of the file to inspect.
            force_extract: Re-run the extraction even if output already exists.

        Returns:
            ``input_path`` unchanged when no extractor recognizes it, otherwise
            the path holding the extracted content.
        """
        extractable, matched_extractor = self.extractor.is_extractable(input_path, return_extractor=True)
        if not extractable:
            return input_path
        target_path = self._get_output_path(input_path)
        if self._do_extract(target_path, force_extract):
            self.extractor.extract(input_path, target_path, extractor=matched_extractor)
        return target_path
42+
43+
44+
class TarExtractor:
    """Detect and unpack tar archives using the standard ``tarfile`` module."""

    @staticmethod
    def is_extractable(path):
        """Return True when ``tarfile`` recognizes the file at ``path`` as a tar archive."""
        return tarfile.is_tarfile(path)

    @staticmethod
    def extract(input_path, output_path):
        """Unpack all members of the archive into the directory ``output_path``.

        Args:
            input_path: Path of the tar archive.
            output_path: Destination directory; created when missing.
        """
        os.makedirs(output_path, exist_ok=True)
        # The context manager closes the archive even when extraction raises.
        # NOTE(review): extractall() trusts member names stored in the archive;
        # consider validating members if archives may come from untrusted sources.
        with tarfile.open(input_path) as tar_file:
            tar_file.extractall(output_path)
55+
56+
57+
class GzipExtractor:
    """Detect and decompress single-file gzip streams."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True when the file at ``path`` is a readable gzip stream.

        Probing approach adapted from https://stackoverflow.com/a/60634210.
        """
        # Check the two-byte gzip magic number first: reading through
        # gzip.open() alone does not raise on an empty file (it just returns
        # b""), which would make a 0-byte file look like a gzip archive.
        with open(path, "rb") as raw_file:
            if raw_file.read(2) != b"\x1f\x8b":
                return False
        # Let the gzip module validate the rest of the header.
        with gzip.open(path, "rb") as fh:
            try:
                fh.read(1)
            except OSError:
                return False
        return True

    @staticmethod
    def extract(input_path, output_path):
        """Decompress ``input_path`` into the file ``output_path``.

        Args:
            input_path: Path of the gzip file.
            output_path: Path of the decompressed file to write.
        """
        with gzip.open(input_path, "rb") as gzip_file:
            with open(output_path, "wb") as extracted_file:
                shutil.copyfileobj(gzip_file, extracted_file)
73+
74+
75+
class ZipExtractor:
    """Detect and unpack zip archives using the standard ``zipfile`` module."""

    @staticmethod
    def is_extractable(path):
        """Return True when ``zipfile.is_zipfile`` recognizes the file at ``path``."""
        return _is_zipfile(path)

    @staticmethod
    def extract(input_path, output_path):
        """Unpack all members of the archive into the directory ``output_path``.

        Args:
            input_path: Path of the zip archive.
            output_path: Destination directory; created when missing.
        """
        os.makedirs(output_path, exist_ok=True)
        # The ``with`` block closes the archive; no explicit close() is needed.
        with ZipFile(input_path, "r") as zip_file:
            zip_file.extractall(output_path)
86+
87+
88+
class XzExtractor:
    """Detect and decompress single-file xz streams."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True when the file at ``path`` starts with the xz magic bytes.

        Format reference: https://tukaani.org/xz/xz-file-format-1.0.4.txt
        """
        with open(path, "rb") as f:
            try:
                header_magic_bytes = f.read(6)
            except OSError:
                return False
        # The first six bytes of an xz stream are a fixed signature.
        return header_magic_bytes == b"\xfd7zXZ\x00"

    @staticmethod
    def extract(input_path, output_path):
        """Decompress ``input_path`` into the file ``output_path``.

        Args:
            input_path: Path of the xz file.
            output_path: Path of the decompressed file to write.
        """
        with lzma.open(input_path) as compressed_file:
            with open(output_path, "wb") as extracted_file:
                shutil.copyfileobj(compressed_file, extracted_file)
107+
108+
109+
class RarExtractor:
    """Detect rar archives by signature and unpack them with ``rarfile``."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True when the file at ``path`` starts with a RAR signature.

        Signatures taken from https://github.com/markokr/rarfile/blob/master/rarfile.py
        """
        RAR_ID = b"Rar!\x1a\x07\x00"
        RAR5_ID = b"Rar!\x1a\x07\x01\x00"

        # Read as many bytes as the longer signature so either one can match.
        with open(path, "rb", 1024) as fd:
            buf = fd.read(len(RAR5_ID))
        return buf.startswith((RAR_ID, RAR5_ID))

    @staticmethod
    def extract(input_path, output_path):
        """Unpack all members of the archive into the directory ``output_path``.

        Args:
            input_path: Path of the rar archive.
            output_path: Destination directory; created when missing.

        Raises:
            EnvironmentError: If ``config.RARFILE_AVAILABLE`` is false.
        """
        if not config.RARFILE_AVAILABLE:
            raise EnvironmentError("Please pip install rarfile")
        # Optional dependency: imported only once availability is confirmed.
        import rarfile

        os.makedirs(output_path, exist_ok=True)
        # The context manager closes the archive even when extraction raises.
        with rarfile.RarFile(input_path) as rf:
            rf.extractall(output_path)
133+
134+
135+
class ZstdExtractor:
    """Detect Zstandard frames by magic number and decompress them with ``zstandard``."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True when the file at ``path`` starts with the Zstandard magic number.

        See https://datatracker.ietf.org/doc/html/rfc8878

        Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
        """
        with open(path, "rb") as f:
            try:
                magic_number = f.read(4)
            except OSError:
                return False
        # Compare against the magic number packed as little-endian uint32.
        return magic_number == struct.pack("<I", 0xFD2FB528)

    @staticmethod
    def extract(input_path: str, output_path: str):
        """Decompress ``input_path`` into the file ``output_path``.

        Args:
            input_path: Path of the Zstandard-compressed file.
            output_path: Path of the decompressed file to write.

        Raises:
            EnvironmentError: If ``config.ZSTANDARD_AVAILABLE`` is false.
        """
        if not config.ZSTANDARD_AVAILABLE:
            raise EnvironmentError("Please pip install zstandard")
        # Optional dependency: imported only once availability is confirmed.
        import zstandard as zstd

        dctx = zstd.ZstdDecompressor()
        with open(input_path, "rb") as ifh, open(output_path, "wb") as ofh:
            dctx.copy_stream(ifh, ofh)
158+
159+
160+
class Extractor:
    """Dispatch detection and extraction to the first format-specific extractor that matches."""

    # Detection stops at the first match, so order matters. Zip detection is
    # deliberately placed after tar and gzip: per the original note, files of
    # those formats can otherwise be wrongly detected as zip.
    extractors = [TarExtractor, GzipExtractor, ZipExtractor, XzExtractor, RarExtractor, ZstdExtractor]

    @classmethod
    def is_extractable(cls, path, return_extractor=False):
        """Check whether any registered extractor recognizes ``path``.

        Args:
            path: Path of the file to inspect.
            return_extractor: When true, also return the matching extractor.

        Returns:
            A bool, or a ``(bool, extractor_or_None)`` tuple when
            ``return_extractor`` is true.
        """
        for extractor in cls.extractors:
            if extractor.is_extractable(path):
                return True if not return_extractor else (True, extractor)
        return False if not return_extractor else (False, None)

    @classmethod
    def extract(cls, input_path, output_path, extractor=None):
        """Extract ``input_path`` to ``output_path`` while holding a file lock.

        Any existing directory at ``output_path`` is removed first.

        Args:
            input_path: Path of the archive or compressed file.
            output_path: Destination file or directory.
            extractor: Extractor to use; when omitted, the registered
                extractors are probed in order and the first match is used.

        Returns:
            Whatever the selected extractor's ``extract`` returns; ``None``
            when no extractor matches.
        """
        # Prevent parallel extractions
        lock_path = input_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path, ignore_errors=True)
            # A bare file name has an empty dirname, and os.makedirs("") raises.
            parent_dir = os.path.dirname(output_path)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            if extractor:
                return extractor.extract(input_path, output_path)
            for candidate in cls.extractors:
                if candidate.is_extractable(input_path):
                    return candidate.extract(input_path, output_path)

src/datasets/utils/file_utils.py

Lines changed: 7 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@
55
"""
66

77
import copy
8-
import gzip
98
import json
10-
import lzma
119
import os
1210
import re
1311
import shutil
14-
import struct
1512
import sys
16-
import tarfile
1713
import tempfile
1814
import time
1915
import urllib
@@ -24,7 +20,6 @@
2420
from pathlib import Path
2521
from typing import Dict, Optional, Union
2622
from urllib.parse import urlparse
27-
from zipfile import ZipFile, is_zipfile
2823

2924
import numpy as np
3025
import posixpath
@@ -33,6 +28,7 @@
3328

3429
from .. import __version__, config
3530
from . import logging
31+
from .extract import ExtractManager
3632
from .filelock import FileLock
3733

3834

@@ -300,75 +296,13 @@ def cached_path(
300296
# Something unknown
301297
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
302298

303-
if download_config.extract_compressed_file and output_path is not None:
304-
305-
if (
306-
not is_zipfile(output_path)
307-
and not tarfile.is_tarfile(output_path)
308-
and not is_gzip(output_path)
309-
and not is_xz(output_path)
310-
and not is_rarfile(output_path)
311-
and not ZstdExtractor.is_extractable(output_path)
312-
):
313-
return output_path
314-
315-
# Path where we extract compressed archives
316-
# We extract in the cache dir, and get the extracted path name by hashing the original path
317-
abs_output_path = os.path.abspath(output_path)
318-
output_path_extracted = (
319-
os.path.join(
320-
download_config.cache_dir, config.EXTRACTED_DATASETS_DIR, hash_url_to_filename(abs_output_path)
321-
)
322-
if download_config.cache_dir
323-
else os.path.join(config.EXTRACTED_DATASETS_PATH, hash_url_to_filename(abs_output_path))
324-
)
299+
if output_path is None:
300+
return output_path
325301

326-
if (
327-
os.path.isdir(output_path_extracted)
328-
and os.listdir(output_path_extracted)
329-
and not download_config.force_extract
330-
) or (os.path.isfile(output_path_extracted) and not download_config.force_extract):
331-
return output_path_extracted
332-
333-
# Prevent parallel extractions
334-
lock_path = output_path + ".lock"
335-
with FileLock(lock_path):
336-
shutil.rmtree(output_path_extracted, ignore_errors=True)
337-
os.makedirs(output_path_extracted, exist_ok=True)
338-
if tarfile.is_tarfile(output_path):
339-
tar_file = tarfile.open(output_path)
340-
tar_file.extractall(output_path_extracted)
341-
tar_file.close()
342-
elif is_gzip(output_path):
343-
os.rmdir(output_path_extracted)
344-
with gzip.open(output_path, "rb") as gzip_file:
345-
with open(output_path_extracted, "wb") as extracted_file:
346-
shutil.copyfileobj(gzip_file, extracted_file)
347-
elif is_zipfile(output_path): # put zip file to the last, b/c it is possible wrongly detected as zip
348-
with ZipFile(output_path, "r") as zip_file:
349-
zip_file.extractall(output_path_extracted)
350-
zip_file.close()
351-
elif is_xz(output_path):
352-
os.rmdir(output_path_extracted)
353-
with lzma.open(output_path) as compressed_file:
354-
with open(output_path_extracted, "wb") as extracted_file:
355-
shutil.copyfileobj(compressed_file, extracted_file)
356-
elif is_rarfile(output_path):
357-
if config.RARFILE_AVAILABLE:
358-
import rarfile
359-
360-
rf = rarfile.RarFile(output_path)
361-
rf.extractall(output_path_extracted)
362-
rf.close()
363-
else:
364-
raise EnvironmentError("Please pip install rarfile")
365-
elif ZstdExtractor.is_extractable(output_path):
366-
os.rmdir(output_path_extracted)
367-
ZstdExtractor.extract(output_path, output_path_extracted)
368-
else:
369-
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
370-
371-
return output_path_extracted
302+
if download_config.extract_compressed_file:
303+
output_path = ExtractManager(cache_dir=download_config.cache_dir).extract(
304+
output_path, force_extract=download_config.force_extract
305+
)
372306

373307
return output_path
374308

@@ -693,67 +627,6 @@ def _resumable_file_manager():
693627
return cache_path
694628

695629

696-
def is_gzip(path: str) -> bool:
697-
"""from https://stackoverflow.com/a/60634210"""
698-
with gzip.open(path, "r") as fh:
699-
try:
700-
fh.read(1)
701-
return True
702-
except OSError:
703-
return False
704-
705-
706-
def is_xz(path: str) -> bool:
707-
"""https://tukaani.org/xz/xz-file-format-1.0.4.txt"""
708-
with open(path, "rb") as f:
709-
try:
710-
header_magic_bytes = f.read(6)
711-
except OSError:
712-
return False
713-
if header_magic_bytes == b"\xfd7zXZ\x00":
714-
return True
715-
else:
716-
return False
717-
718-
719-
def is_rarfile(path: str) -> bool:
720-
"""https://github.com/markokr/rarfile/blob/master/rarfile.py"""
721-
RAR_ID = b"Rar!\x1a\x07\x00"
722-
RAR5_ID = b"Rar!\x1a\x07\x01\x00"
723-
724-
with open(path, "rb", 1024) as fd:
725-
buf = fd.read(len(RAR5_ID))
726-
if buf.startswith(RAR_ID) or buf.startswith(RAR5_ID):
727-
return True
728-
else:
729-
return False
730-
731-
732-
class ZstdExtractor:
733-
@staticmethod
734-
def is_extractable(path: str) -> bool:
735-
"""https://datatracker.ietf.org/doc/html/rfc8878
736-
737-
Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
738-
"""
739-
with open(path, "rb") as f:
740-
try:
741-
magic_number = f.read(4)
742-
except OSError:
743-
return False
744-
return True if magic_number == struct.pack("<I", 0xFD2FB528) else False
745-
746-
@staticmethod
747-
def extract(input_path: str, output_path: str):
748-
if not config.ZSTANDARD_AVAILABLE:
749-
raise EnvironmentError("Please pip install zstandard")
750-
import zstandard as zstd
751-
752-
dctx = zstd.ZstdDecompressor()
753-
with open(input_path, "rb") as ifh, open(output_path, "wb") as ofh:
754-
dctx.copy_stream(ifh, ofh)
755-
756-
757630
def add_start_docstrings(*docstr):
758631
def docstring_decorator(fn):
759632
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,17 @@ def xz_file(tmp_path_factory):
8787
return filename
8888

8989

90+
@pytest.fixture(scope="session")
def gz_path(tmp_path_factory, text_path):
    """Session fixture: path (as ``str``) of a gzip file containing FILE_CONTENT as UTF-8."""
    import gzip

    target = str(tmp_path_factory.mktemp("data") / "file.gz")
    payload = bytes(FILE_CONTENT, "utf-8")
    with gzip.open(target, "wb") as gz_file:
        gz_file.write(payload)
    return target
99+
100+
90101
@pytest.fixture(scope="session")
91102
def xml_file(tmp_path_factory):
92103
filename = tmp_path_factory.mktemp("data") / "file.xml"

0 commit comments

Comments
 (0)