Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@
logger.info("Disabling rarfile because USE_RAR is set to False")


# True when an import spec for the optional "zstandard" package can be found.
ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None


# Cache location
DEFAULT_XDG_CACHE_HOME = "~/.cache"
XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)
Expand Down
30 changes: 30 additions & 0 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import re
import shutil
import struct
import sys
import tarfile
import tempfile
Expand Down Expand Up @@ -307,6 +308,7 @@ def cached_path(
and not is_gzip(output_path)
and not is_xz(output_path)
and not is_rarfile(output_path)
and not ZstdExtractor.is_extractable(output_path)
):
return output_path

Expand Down Expand Up @@ -360,6 +362,9 @@ def cached_path(
rf.close()
else:
raise EnvironmentError("Please pip install rarfile")
elif ZstdExtractor.is_extractable(output_path):
os.rmdir(output_path_extracted)
ZstdExtractor.extract(output_path, output_path_extracted)
else:
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

Expand Down Expand Up @@ -724,6 +729,31 @@ def is_rarfile(path: str) -> bool:
return False


class ZstdExtractor:
    """Detect and decompress Zstandard-compressed files."""

    # Zstandard frame magic number, see
    # https://datatracker.ietf.org/doc/html/rfc8878
    # Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
    _MAGIC_NUMBER = struct.pack("<I", 0xFD2FB528)

    @staticmethod
    def is_extractable(path: str) -> bool:
        """Return whether the file at ``path`` starts with the Zstandard magic number.

        Args:
            path: Path of the file to inspect.

        Returns:
            ``True`` if the first four bytes equal the Zstandard frame magic
            number, ``False`` otherwise. Files shorter than four bytes and
            paths that cannot be opened or read (directory, missing file,
            permission denied, ...) yield ``False``.
        """
        # open() is inside the try: most OSErrors (e.g. the path is a directory)
        # are raised when opening, not when reading.
        try:
            with open(path, "rb") as f:
                magic_number = f.read(len(ZstdExtractor._MAGIC_NUMBER))
        except OSError:
            return False
        return magic_number == ZstdExtractor._MAGIC_NUMBER

    @staticmethod
    def extract(input_path: str, output_path: str) -> None:
        """Decompress the Zstandard file ``input_path`` into ``output_path``.

        Args:
            input_path: Path of the compressed file to read.
            output_path: Path of the file to write; it is opened in ``"wb"``
                mode, so existing content is overwritten.

        Raises:
            EnvironmentError: If ``config.ZSTANDARD_AVAILABLE`` is false.
        """
        if not config.ZSTANDARD_AVAILABLE:
            raise EnvironmentError("Please pip install zstandard")
        # Imported lazily so the module stays importable without zstandard.
        import zstandard as zstd

        dctx = zstd.ZstdDecompressor()
        with open(input_path, "rb") as ifh, open(output_path, "wb") as ofh:
            dctx.copy_stream(ifh, ofh)


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
import pytest
import zstandard as zstd

from datasets.arrow_dataset import Dataset
from datasets.features import ClassLabel, Features, Sequence, Value
Expand Down Expand Up @@ -87,6 +88,15 @@ def xz_file(tmp_path_factory):
return filename


@pytest.fixture(scope="session")
def zstd_path(tmp_path_factory):
    """Write FILE_CONTENT (UTF-8 encoded) to a zstd-compressed temp file and return its path."""
    target = tmp_path_factory.mktemp("data") / "file.zstd"
    payload = FILE_CONTENT.encode("utf-8")
    with zstd.open(target, "wb") as compressed:
        compressed.write(payload)
    return target


@pytest.fixture(scope="session")
def xml_file(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.xml"
Expand Down
23 changes: 19 additions & 4 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from datasets.utils.file_utils import (
DownloadConfig,
OfflineModeIsEnabled,
ZstdExtractor,
cached_path,
ftp_get,
ftp_head,
Expand Down Expand Up @@ -72,12 +73,26 @@ def gen_random_output():
self.assertGreater(np.abs(out1 - out3).sum(), 0)


def test_cached_path_extract(xz_file, tmp_path, text_file):
filename = xz_file
def test_zstd_extractor(zstd_path, tmp_path, text_file):
    """ZstdExtractor recognises the zstd fixture and extracts the text file's content."""
    assert ZstdExtractor.is_extractable(zstd_path)

    destination = str(tmp_path / "extracted.txt")
    ZstdExtractor.extract(zstd_path, destination)

    with open(destination) as extracted, open(text_file) as reference:
        actual_content = extracted.read()
        expected_content = reference.read()
    assert actual_content == expected_content


@pytest.mark.parametrize("compression_format", ["xz", "zstd"])
def test_cached_path_extract(compression_format, xz_file, zstd_path, tmp_path, text_file):
path = {"xz": xz_file, "zstd": zstd_path}
input_path = path[compression_format]
cache_dir = tmp_path / "cache"
download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True)
extracted_filename = cached_path(filename, download_config=download_config)
with open(extracted_filename) as f:
extracted_path = cached_path(input_path, download_config=download_config)
with open(extracted_path) as f:
extracted_file_content = f.read()
with open(text_file) as f:
expected_file_content = f.read()
Expand Down