Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@
logger.info("Disabling rarfile because USE_RAR is set to False")


ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None


# Cache location
DEFAULT_XDG_CACHE_HOME = "~/.cache"
XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)
Expand Down
30 changes: 30 additions & 0 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import re
import shutil
import struct
import sys
import tarfile
import tempfile
Expand Down Expand Up @@ -307,6 +308,7 @@ def cached_path(
and not is_gzip(output_path)
and not is_xz(output_path)
and not is_rarfile(output_path)
and not ZstdExtractor.is_extractable(output_path)
):
return output_path

Expand Down Expand Up @@ -360,6 +362,9 @@ def cached_path(
rf.close()
else:
raise EnvironmentError("Please pip install rarfile")
elif ZstdExtractor.is_extractable(output_path):
os.rmdir(output_path_extracted)
ZstdExtractor.extract(output_path, output_path_extracted)
else:
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

Expand Down Expand Up @@ -724,6 +729,31 @@ def is_rarfile(path: str) -> bool:
return False


class ZstdExtractor:
@staticmethod
def is_extractable(path: str) -> bool:
"""https://datatracker.ietf.org/doc/html/rfc8878

Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
"""
with open(path, "rb") as f:
try:
magic_number = f.read(4)
except OSError:
return False
return True if magic_number == struct.pack("<I", 0xFD2FB528) else False

@staticmethod
def extract(input_path: str, output_path: str):
if not config.ZSTANDARD_AVAILABLE:
raise EnvironmentError("Please pip install zstandard")
import zstandard as zstd

dctx = zstd.ZstdDecompressor()
with open(input_path, "rb") as ifh, open(output_path, "wb") as ofh:
dctx.copy_stream(ifh, ofh)


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
Expand Down
38 changes: 34 additions & 4 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@

import numpy as np
import pytest
import zstandard as zstd

from datasets.utils.file_utils import (
DownloadConfig,
OfflineModeIsEnabled,
ZstdExtractor,
cached_path,
ftp_get,
ftp_head,
Expand All @@ -20,6 +22,20 @@
from .utils import require_tf, require_torch


FILE_CONTENT = """\
Text data.
Second line of data."""


@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
path = tmp_path_factory.mktemp("data") / "file.zstd"
data = bytes(FILE_CONTENT, "utf-8")
with zstd.open(path, "wb") as f:
f.write(data)
return path


class TempSeedTest(TestCase):
@require_tf
def test_tensorflow(self):
Expand Down Expand Up @@ -72,12 +88,26 @@ def gen_random_output():
self.assertGreater(np.abs(out1 - out3).sum(), 0)


def test_cached_path_extract(xz_file, tmp_path, text_file):
filename = xz_file
def test_zstd_extractor(zstd_path, tmp_path, text_file):
input_path = zstd_path
assert ZstdExtractor.is_extractable(input_path)
output_path = str(tmp_path / "extracted.txt")
ZstdExtractor.extract(input_path, output_path)
with open(output_path) as f:
extracted_file_content = f.read()
with open(text_file) as f:
expected_file_content = f.read()
assert extracted_file_content == expected_file_content


@pytest.mark.parametrize("compression_format", ["xz", "zstd"])
def test_cached_path_extract(compression_format, xz_file, zstd_path, tmp_path, text_file):
path = {"xz": xz_file, "zstd": zstd_path}
input_path = path[compression_format]
cache_dir = tmp_path / "cache"
download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True)
extracted_filename = cached_path(filename, download_config=download_config)
with open(extracted_filename) as f:
extracted_path = cached_path(input_path, download_config=download_config)
with open(extracted_path) as f:
extracted_file_content = f.read()
with open(text_file) as f:
expected_file_content = f.read()
Expand Down