Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
9edc2e7
Reorder execution flow in cached_path
albertvillanova Apr 30, 2021
77b0252
Extract function _extract from cached_path
albertvillanova Apr 30, 2021
66f6975
Extract method for each extract type
albertvillanova Apr 30, 2021
c1efe88
Create an Extractor class for each file type
albertvillanova Apr 30, 2021
652f47b
Rename extract method and input/output path params
albertvillanova Apr 30, 2021
1e3e6e2
Rename is_extractable method
albertvillanova Apr 30, 2021
56dcea7
Create generic Extractor.is_extractable
albertvillanova Apr 30, 2021
75e27ee
Create class attribute extractors
albertvillanova May 3, 2021
7b9359f
Move extract functionality to extract module
albertvillanova May 3, 2021
4f1bf02
Fix Extractor.is_extractable
albertvillanova May 3, 2021
a36663b
Create ExtractManager with all extract logic
albertvillanova May 3, 2021
104f0f7
Fix circular import
albertvillanova May 3, 2021
9f57f52
Fix typo
albertvillanova May 3, 2021
ad59f5d
Remove unused class
albertvillanova May 3, 2021
7ef66c9
Fix style
albertvillanova May 3, 2021
75be2e8
Merge remote-tracking branch 'upstream/master' into refactoring-3
albertvillanova Jul 5, 2021
e62ac10
Fix issues after merge upstream master
albertvillanova Jul 5, 2021
8bd53c4
Remove default os.makedirs and os.rmdir when not applicable
albertvillanova Jul 5, 2021
70c3344
Create parent dirs of output_path
albertvillanova Jul 5, 2021
f673970
Minor refactoring of ExtractManager
albertvillanova Jul 5, 2021
5792130
Optimize Extractor.extract by returning specific extractor
albertvillanova Jul 5, 2021
7126a1d
Fix extract
albertvillanova Jul 5, 2021
834ecc3
Test extract and add gzip to test_cached_path_extract
albertvillanova Jul 5, 2021
006dfe2
Address requested changes
albertvillanova Jul 7, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions src/datasets/utils/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import gzip
import lzma
import os
import shutil
import struct
import tarfile
from zipfile import ZipFile
from zipfile import is_zipfile as _is_zipfile

from datasets import config
from datasets.utils.filelock import FileLock


class ExtractManager:
    """Decide where (and whether) an input file gets extracted, delegating the work to ``Extractor``."""

    def __init__(self, cache_dir=None):
        # Extract below the caller's cache dir when one is given, otherwise use the configured default path.
        if cache_dir:
            self.extract_dir = os.path.join(cache_dir, config.EXTRACTED_DATASETS_DIR)
        else:
            self.extract_dir = config.EXTRACTED_DATASETS_PATH
        self.extractor = Extractor

    def _get_output_path(self, path):
        """Return the extraction target for ``path`` inside ``self.extract_dir``."""
        # Imported locally: file_utils itself imports this module.
        from datasets.utils.file_utils import hash_url_to_filename

        # The extracted name is derived by hashing the absolute form of the original path.
        hashed_name = hash_url_to_filename(os.path.abspath(path))
        return os.path.join(self.extract_dir, hashed_name)

    def _do_extract(self, output_path, force_extract):
        """Tell whether extraction must run: when forced, or when no file / non-empty dir exists yet."""
        if force_extract:
            return force_extract
        if os.path.isfile(output_path):
            return False
        # An existing but empty directory does not count as an extraction result.
        return not (os.path.isdir(output_path) and os.listdir(output_path))

    def extract(self, input_path, force_extract=False):
        """Extract ``input_path`` if a known extractor accepts it.

        Returns the extraction path, or ``input_path`` unchanged when the file is not extractable.
        """
        is_extractable, extractor = self.extractor.is_extractable(input_path, return_extractor=True)
        if is_extractable:
            output_path = self._get_output_path(input_path)
            if self._do_extract(output_path, force_extract):
                self.extractor.extract(input_path, output_path, extractor=extractor)
            return output_path
        return input_path


class TarExtractor:
    """Extractor for archives that ``tarfile`` recognizes."""

    @staticmethod
    def is_extractable(path):
        """Return True if ``tarfile.is_tarfile`` accepts ``path``."""
        return tarfile.is_tarfile(path)

    @staticmethod
    def extract(input_path, output_path):
        """Unpack all members of the tar archive ``input_path`` into the directory ``output_path``.

        The output directory is created if it does not exist.
        """
        os.makedirs(output_path, exist_ok=True)
        # The context manager closes the archive even when extraction raises.
        with tarfile.open(input_path) as tar_file:
            # NOTE(review): extractall on an untrusted archive may write outside output_path;
            # consider tarfile's extraction filters where the supported Python versions allow it.
            tar_file.extractall(output_path)


class GzipExtractor:
    """Extractor for single-file gzip streams."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True if ``path`` is a readable gzip file.

        Based on https://stackoverflow.com/a/60634210, with an extra magic-number check:
        reading from an empty file through ``gzip`` yields no data and no error, so an
        empty file would otherwise be reported as gzip.
        """
        with open(path, "rb") as raw_file:
            # Every gzip member starts with the two ID bytes 0x1f 0x8b.
            if raw_file.read(2) != b"\x1f\x8b":
                return False
        with gzip.open(path, "rb") as gzip_file:
            try:
                # Forces the header to be parsed; a malformed header raises OSError.
                gzip_file.read(1)
            except OSError:
                return False
        return True

    @staticmethod
    def extract(input_path, output_path):
        """Decompress the gzip file ``input_path`` into the file ``output_path``."""
        with gzip.open(input_path, "rb") as gzip_file:
            with open(output_path, "wb") as extracted_file:
                shutil.copyfileobj(gzip_file, extracted_file)


class ZipExtractor:
    """Extractor for zip archives."""

    @staticmethod
    def is_extractable(path):
        """Return True if ``zipfile.is_zipfile`` accepts ``path``."""
        return _is_zipfile(path)

    @staticmethod
    def extract(input_path, output_path):
        """Unpack all members of the zip archive ``input_path`` into the directory ``output_path``.

        The output directory is created if it does not exist.
        """
        os.makedirs(output_path, exist_ok=True)
        # The with-statement already closes the archive; no explicit close() is needed.
        with ZipFile(input_path, "r") as zip_file:
            zip_file.extractall(output_path)


class XzExtractor:
    """Extractor for single-file xz streams."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True if ``path`` starts with the xz header magic bytes.

        See https://tukaani.org/xz/xz-file-format-1.0.4.txt
        """
        xz_magic = b"\xfd7zXZ\x00"
        with open(path, "rb") as stream:
            try:
                leading_bytes = stream.read(len(xz_magic))
            except OSError:
                return False
        return leading_bytes == xz_magic

    @staticmethod
    def extract(input_path, output_path):
        """Decompress the xz file ``input_path`` into the file ``output_path``."""
        with lzma.open(input_path) as source, open(output_path, "wb") as target:
            shutil.copyfileobj(source, target)


class RarExtractor:
    """Extractor for rar archives; extraction needs the optional ``rarfile`` package."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True if ``path`` starts with a RAR (v4) or RAR5 signature.

        Signatures taken from https://github.com/markokr/rarfile/blob/master/rarfile.py
        """
        RAR_ID = b"Rar!\x1a\x07\x00"
        RAR5_ID = b"Rar!\x1a\x07\x01\x00"

        with open(path, "rb", 1024) as fd:
            # RAR5_ID is the longer signature, so this prefix covers both checks.
            buf = fd.read(len(RAR5_ID))
        return buf.startswith((RAR_ID, RAR5_ID))

    @staticmethod
    def extract(input_path, output_path):
        """Unpack the rar archive ``input_path`` into the directory ``output_path``.

        The output directory is created if it does not exist.

        Raises:
            EnvironmentError: if ``config.RARFILE_AVAILABLE`` is false.
        """
        if not config.RARFILE_AVAILABLE:
            raise EnvironmentError("Please pip install rarfile")
        import rarfile

        os.makedirs(output_path, exist_ok=True)
        # The context manager closes the archive even when extraction raises.
        with rarfile.RarFile(input_path) as rf:
            rf.extractall(output_path)


class ZstdExtractor:
    """Extractor for Zstandard streams; extraction needs the optional ``zstandard`` package."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True if ``path`` starts with the Zstandard frame magic number.

        https://datatracker.ietf.org/doc/html/rfc8878

        Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
        """
        expected_magic = struct.pack("<I", 0xFD2FB528)
        with open(path, "rb") as stream:
            try:
                leading_bytes = stream.read(4)
            except OSError:
                return False
        return leading_bytes == expected_magic

    @staticmethod
    def extract(input_path: str, output_path: str):
        """Decompress the Zstandard file ``input_path`` into the file ``output_path``.

        Raises:
            EnvironmentError: if ``config.ZSTANDARD_AVAILABLE`` is false.
        """
        if not config.ZSTANDARD_AVAILABLE:
            raise EnvironmentError("Please pip install zstandard")
        import zstandard as zstd

        decompressor = zstd.ZstdDecompressor()
        with open(input_path, "rb") as source:
            with open(output_path, "wb") as target:
                decompressor.copy_stream(source, target)


class Extractor:
    """Facade that picks the first matching format-specific extractor and runs it under a file lock."""

    # Order matters: the first extractor whose is_extractable() accepts the file wins.
    # Zip is checked after tar and gzip because such files may be wrongly detected as zip.
    # NOTE(review): the original note said to put zip "last", but xz, rar and zstd follow it here.
    extractors = [TarExtractor, GzipExtractor, ZipExtractor, XzExtractor, RarExtractor, ZstdExtractor]

    @classmethod
    def is_extractable(cls, path, return_extractor=False):
        """Tell whether any registered extractor accepts ``path``.

        Returns a bool, or a ``(bool, extractor_or_None)`` tuple when ``return_extractor`` is true.
        """
        for extractor in cls.extractors:
            if extractor.is_extractable(path):
                return (True, extractor) if return_extractor else True
        return (False, None) if return_extractor else False

    @classmethod
    def extract(cls, input_path, output_path, extractor=None):
        """Extract ``input_path`` to ``output_path``.

        Uses ``extractor`` when given, otherwise the first registered extractor that accepts
        the input. Any previous directory at ``output_path`` is removed first. Returns whatever
        the chosen extractor returns, or None when no extractor matches.
        """
        # Prevent parallel extractions
        lock_path = input_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path, ignore_errors=True)
            parent_dir = os.path.dirname(output_path)
            # A bare file name has no directory part, and os.makedirs("") would raise.
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            if extractor:
                return extractor.extract(input_path, output_path)
            for candidate in cls.extractors:
                if candidate.is_extractable(input_path):
                    return candidate.extract(input_path, output_path)
141 changes: 7 additions & 134 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
"""

import copy
import gzip
import json
import lzma
import os
import re
import shutil
import struct
import sys
import tarfile
import tempfile
import time
import urllib
Expand All @@ -24,7 +20,6 @@
from pathlib import Path
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import numpy as np
import posixpath
Expand All @@ -33,6 +28,7 @@

from .. import __version__, config
from . import logging
from .extract import ExtractManager
from .filelock import FileLock


Expand Down Expand Up @@ -300,75 +296,13 @@ def cached_path(
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

if download_config.extract_compressed_file and output_path is not None:

if (
not is_zipfile(output_path)
and not tarfile.is_tarfile(output_path)
and not is_gzip(output_path)
and not is_xz(output_path)
and not is_rarfile(output_path)
and not ZstdExtractor.is_extractable(output_path)
):
return output_path

# Path where we extract compressed archives
# We extract in the cache dir, and get the extracted path name by hashing the original path
abs_output_path = os.path.abspath(output_path)
output_path_extracted = (
os.path.join(
download_config.cache_dir, config.EXTRACTED_DATASETS_DIR, hash_url_to_filename(abs_output_path)
)
if download_config.cache_dir
else os.path.join(config.EXTRACTED_DATASETS_PATH, hash_url_to_filename(abs_output_path))
)
if output_path is None:
return output_path

if (
os.path.isdir(output_path_extracted)
and os.listdir(output_path_extracted)
and not download_config.force_extract
) or (os.path.isfile(output_path_extracted) and not download_config.force_extract):
return output_path_extracted

# Prevent parallel extractions
lock_path = output_path + ".lock"
with FileLock(lock_path):
shutil.rmtree(output_path_extracted, ignore_errors=True)
os.makedirs(output_path_extracted, exist_ok=True)
if tarfile.is_tarfile(output_path):
tar_file = tarfile.open(output_path)
tar_file.extractall(output_path_extracted)
tar_file.close()
elif is_gzip(output_path):
os.rmdir(output_path_extracted)
with gzip.open(output_path, "rb") as gzip_file:
with open(output_path_extracted, "wb") as extracted_file:
shutil.copyfileobj(gzip_file, extracted_file)
elif is_zipfile(output_path): # put zip file to the last, b/c it is possible wrongly detected as zip
with ZipFile(output_path, "r") as zip_file:
zip_file.extractall(output_path_extracted)
zip_file.close()
elif is_xz(output_path):
os.rmdir(output_path_extracted)
with lzma.open(output_path) as compressed_file:
with open(output_path_extracted, "wb") as extracted_file:
shutil.copyfileobj(compressed_file, extracted_file)
elif is_rarfile(output_path):
if config.RARFILE_AVAILABLE:
import rarfile

rf = rarfile.RarFile(output_path)
rf.extractall(output_path_extracted)
rf.close()
else:
raise EnvironmentError("Please pip install rarfile")
elif ZstdExtractor.is_extractable(output_path):
os.rmdir(output_path_extracted)
ZstdExtractor.extract(output_path, output_path_extracted)
else:
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

return output_path_extracted
if download_config.extract_compressed_file:
output_path = ExtractManager(cache_dir=download_config.cache_dir).extract(
output_path, force_extract=download_config.force_extract
)

return output_path

Expand Down Expand Up @@ -693,67 +627,6 @@ def _resumable_file_manager():
return cache_path


def is_gzip(path: str) -> bool:
    """Return True if ``path`` is a readable gzip file.

    Based on https://stackoverflow.com/a/60634210, with an extra magic-number check:
    reading from an empty file through ``gzip`` yields no data and no error, so an
    empty file would otherwise be reported as gzip.
    """
    with open(path, "rb") as raw_file:
        # Every gzip member starts with the two ID bytes 0x1f 0x8b.
        if raw_file.read(2) != b"\x1f\x8b":
            return False
    with gzip.open(path, "rb") as gzip_file:
        try:
            # Forces the header to be parsed; a malformed header raises OSError.
            gzip_file.read(1)
        except OSError:
            return False
    return True


def is_xz(path: str) -> bool:
    """Return True if ``path`` starts with the xz header magic bytes.

    See https://tukaani.org/xz/xz-file-format-1.0.4.txt
    """
    xz_magic = b"\xfd7zXZ\x00"
    with open(path, "rb") as stream:
        try:
            leading_bytes = stream.read(len(xz_magic))
        except OSError:
            return False
    return leading_bytes == xz_magic


def is_rarfile(path: str) -> bool:
    """Return True if ``path`` starts with a RAR (v4) or RAR5 signature.

    Signatures taken from https://github.com/markokr/rarfile/blob/master/rarfile.py
    """
    signatures = (b"Rar!\x1a\x07\x00", b"Rar!\x1a\x07\x01\x00")
    # Read just enough bytes to cover the longer (RAR5) signature.
    prefix_length = len(signatures[1])
    with open(path, "rb", 1024) as fd:
        head = fd.read(prefix_length)
    return head.startswith(signatures)


class ZstdExtractor:
    """Detect and decompress Zstandard files; decompression needs the optional ``zstandard`` package."""

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return True if the first four bytes of ``path`` are the Zstandard magic number.

        https://datatracker.ietf.org/doc/html/rfc8878

        Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
        """
        zstd_magic = struct.pack("<I", 0xFD2FB528)
        with open(path, "rb") as handle:
            try:
                header = handle.read(4)
            except OSError:
                return False
        return header == zstd_magic

    @staticmethod
    def extract(input_path: str, output_path: str):
        """Stream-decompress ``input_path`` into the file ``output_path``.

        Raises:
            EnvironmentError: if ``config.ZSTANDARD_AVAILABLE`` is false.
        """
        if not config.ZSTANDARD_AVAILABLE:
            raise EnvironmentError("Please pip install zstandard")
        import zstandard as zstd

        decompressor = zstd.ZstdDecompressor()
        with open(input_path, "rb") as compressed:
            with open(output_path, "wb") as extracted:
                decompressor.copy_stream(compressed, extracted)


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ def xz_file(tmp_path_factory):
return filename


@pytest.fixture(scope="session")
def gz_path(tmp_path_factory, text_path):
    """Session fixture: gzip-compress FILE_CONTENT (UTF-8 encoded) into a temp file and return its path as str."""
    import gzip

    target_dir = tmp_path_factory.mktemp("data")
    path = str(target_dir / "file.gz")
    payload = bytes(FILE_CONTENT, "utf-8")
    with gzip.open(path, "wb") as gz_file:
        gz_file.write(payload)
    return path


@pytest.fixture(scope="session")
def xml_file(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.xml"
Expand Down
Loading