Skip to content

Commit 474d46b

Browse files
Support Zstandard compressed files (#2578)
* Test Zstandard extractor * Implement Zstandard extractor * Test cached_path extracts Zstandard * Implement cached_path extracts Zstandard * Minor refactoring * Move zstd_path fixture to test_file_utils
1 parent b15b476 commit 474d46b

File tree

3 files changed

+67
-4
lines changed

3 files changed

+67
-4
lines changed

src/datasets/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@
134134
logger.info("Disabling rarfile because USE_RAR is set to False")
135135

136136

137+
ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None
138+
139+
137140
# Cache location
138141
DEFAULT_XDG_CACHE_HOME = "~/.cache"
139142
XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME)

src/datasets/utils/file_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import re
1313
import shutil
14+
import struct
1415
import sys
1516
import tarfile
1617
import tempfile
@@ -307,6 +308,7 @@ def cached_path(
307308
and not is_gzip(output_path)
308309
and not is_xz(output_path)
309310
and not is_rarfile(output_path)
311+
and not ZstdExtractor.is_extractable(output_path)
310312
):
311313
return output_path
312314

@@ -360,6 +362,9 @@ def cached_path(
360362
rf.close()
361363
else:
362364
raise EnvironmentError("Please pip install rarfile")
365+
elif ZstdExtractor.is_extractable(output_path):
366+
os.rmdir(output_path_extracted)
367+
ZstdExtractor.extract(output_path, output_path_extracted)
363368
else:
364369
raise EnvironmentError("Archive format of {} could not be identified".format(output_path))
365370

@@ -724,6 +729,31 @@ def is_rarfile(path: str) -> bool:
724729
return False
725730

726731

732+
class ZstdExtractor:
733+
@staticmethod
734+
def is_extractable(path: str) -> bool:
735+
"""https://datatracker.ietf.org/doc/html/rfc8878
736+
737+
Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528.
738+
"""
739+
with open(path, "rb") as f:
740+
try:
741+
magic_number = f.read(4)
742+
except OSError:
743+
return False
744+
return True if magic_number == struct.pack("<I", 0xFD2FB528) else False
745+
746+
@staticmethod
747+
def extract(input_path: str, output_path: str):
748+
if not config.ZSTANDARD_AVAILABLE:
749+
raise EnvironmentError("Please pip install zstandard")
750+
import zstandard as zstd
751+
752+
dctx = zstd.ZstdDecompressor()
753+
with open(input_path, "rb") as ifh, open(output_path, "wb") as ofh:
754+
dctx.copy_stream(ifh, ofh)
755+
756+
727757
def add_start_docstrings(*docstr):
728758
def docstring_decorator(fn):
729759
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

tests/test_file_utils.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
import numpy as np
77
import pytest
8+
import zstandard as zstd
89

910
from datasets.utils.file_utils import (
1011
DownloadConfig,
1112
OfflineModeIsEnabled,
13+
ZstdExtractor,
1214
cached_path,
1315
ftp_get,
1416
ftp_head,
@@ -20,6 +22,20 @@
2022
from .utils import require_tf, require_torch
2123

2224

25+
FILE_CONTENT = """\
26+
Text data.
27+
Second line of data."""
28+
29+
30+
@pytest.fixture(scope="session")
31+
def zstd_path(tmp_path_factory):
32+
path = tmp_path_factory.mktemp("data") / "file.zstd"
33+
data = bytes(FILE_CONTENT, "utf-8")
34+
with zstd.open(path, "wb") as f:
35+
f.write(data)
36+
return path
37+
38+
2339
class TempSeedTest(TestCase):
2440
@require_tf
2541
def test_tensorflow(self):
@@ -72,12 +88,26 @@ def gen_random_output():
7288
self.assertGreater(np.abs(out1 - out3).sum(), 0)
7389

7490

75-
def test_cached_path_extract(xz_file, tmp_path, text_file):
76-
filename = xz_file
91+
def test_zstd_extractor(zstd_path, tmp_path, text_file):
92+
input_path = zstd_path
93+
assert ZstdExtractor.is_extractable(input_path)
94+
output_path = str(tmp_path / "extracted.txt")
95+
ZstdExtractor.extract(input_path, output_path)
96+
with open(output_path) as f:
97+
extracted_file_content = f.read()
98+
with open(text_file) as f:
99+
expected_file_content = f.read()
100+
assert extracted_file_content == expected_file_content
101+
102+
103+
@pytest.mark.parametrize("compression_format", ["xz", "zstd"])
104+
def test_cached_path_extract(compression_format, xz_file, zstd_path, tmp_path, text_file):
105+
path = {"xz": xz_file, "zstd": zstd_path}
106+
input_path = path[compression_format]
77107
cache_dir = tmp_path / "cache"
78108
download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True)
79-
extracted_filename = cached_path(filename, download_config=download_config)
80-
with open(extracted_filename) as f:
109+
extracted_path = cached_path(input_path, download_config=download_config)
110+
with open(extracted_path) as f:
81111
extracted_file_content = f.read()
82112
with open(text_file) as f:
83113
expected_file_content = f.read()

0 commit comments

Comments
 (0)