Skip to content

Commit b891219

Browse files
committed
test streaming gz, lz4, bz2, xz and zst
1 parent 8c866ed commit b891219

6 files changed

Lines changed: 100 additions & 45 deletions

File tree

src/datasets/config.py

Lines changed: 3 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -120,21 +120,10 @@
120120
logger.info("Disabling Apache Beam because USE_BEAM is set to False")
121121

122122

123-
USE_RAR = os.environ.get("USE_RAR", "AUTO").upper()
124-
RARFILE_VERSION = "N/A"
125-
RARFILE_AVAILABLE = False
126-
if USE_RAR in ("1", "ON", "YES", "AUTO"):
127-
try:
128-
RARFILE_VERSION = version.parse(importlib_metadata.version("rarfile"))
129-
RARFILE_AVAILABLE = True
130-
logger.info("rarfile available.")
131-
except importlib_metadata.PackageNotFoundError:
132-
pass
133-
else:
134-
logger.info("Disabling rarfile because USE_RAR is set to False")
135-
136-
123+
# Optional compression tools
124+
RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
137125
ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None
126+
LZ4_AVAILABLE = importlib.util.find_spec("lz4") is not None
138127

139128

140129
# Cache location

src/datasets/filesystems/compression.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def __init__(
3939
)
4040
self.info = self.file.fs.info(self.file.path)
4141
self.compressed_name = os.path.basename(self.file.path.split("::")[0])
42-
self.uncompressed_name = self.compressed_name.rstrip(self.extension)
42+
self.uncompressed_name = self.compressed_name[: self.compressed_name.rindex(".")]
4343
self.dir_cache = None
4444

4545
@classmethod

tests/conftest.py

Lines changed: 39 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import pyarrow.parquet as pq
88
import pytest
99

10+
from datasets import config
1011
from datasets.arrow_dataset import Dataset
1112
from datasets.features import ClassLabel, Features, Sequence, Value
1213

@@ -80,24 +81,59 @@ def text_file(tmp_path_factory):
8081

8182
@pytest.fixture(scope="session")
8283
def xz_file(tmp_path_factory):
83-
filename = tmp_path_factory.mktemp("data") / "file.xz"
84+
filename = tmp_path_factory.mktemp("data") / "file.txt.xz"
8485
data = bytes(FILE_CONTENT, "utf-8")
8586
with lzma.open(filename, "wb") as f:
8687
f.write(data)
8788
return filename
8889

8990

9091
@pytest.fixture(scope="session")
91-
def gz_path(tmp_path_factory, text_path):
92+
def gz_file(tmp_path_factory):
9293
import gzip
9394

94-
path = str(tmp_path_factory.mktemp("data") / "file.gz")
95+
path = str(tmp_path_factory.mktemp("data") / "file.txt.gz")
9596
data = bytes(FILE_CONTENT, "utf-8")
9697
with gzip.open(path, "wb") as f:
9798
f.write(data)
9899
return path
99100

100101

102+
@pytest.fixture(scope="session")
103+
def bz2_file(tmp_path_factory):
104+
import bz2
105+
106+
path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
107+
data = bytes(FILE_CONTENT, "utf-8")
108+
with bz2.open(path, "wb") as f:
109+
f.write(data)
110+
return path
111+
112+
113+
@pytest.fixture(scope="session")
114+
def zstd_file(tmp_path_factory):
115+
if config.ZSTANDARD_AVAILABLE:
116+
import zstandard as zstd
117+
118+
path = tmp_path_factory.mktemp("data") / "file.txt.zst"
119+
data = bytes(FILE_CONTENT, "utf-8")
120+
with zstd.open(path, "wb") as f:
121+
f.write(data)
122+
return path
123+
124+
125+
@pytest.fixture(scope="session")
126+
def lz4_file(tmp_path_factory):
127+
if config.LZ4_AVAILABLE:
128+
import lz4.frame
129+
130+
path = tmp_path_factory.mktemp("data") / "file.txt.lz4"
131+
data = bytes(FILE_CONTENT, "utf-8")
132+
with lz4.frame.open(path, "wb") as f:
133+
f.write(data)
134+
return path
135+
136+
101137
@pytest.fixture(scope="session")
102138
def xml_file(tmp_path_factory):
103139
filename = tmp_path_factory.mktemp("data") / "file.xml"

tests/test_extract.py

Lines changed: 9 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,13 @@
11
import pytest
2-
import zstandard as zstd
32

43
from datasets.utils.extract import Extractor, ZstdExtractor
54

5+
from .utils import require_zstandard
66

7-
FILE_CONTENT = """\
8-
Text data.
9-
Second line of data."""
107

11-
12-
@pytest.fixture(scope="session")
13-
def zstd_path(tmp_path_factory):
14-
path = tmp_path_factory.mktemp("data") / "file.zstd"
15-
data = bytes(FILE_CONTENT, "utf-8")
16-
with zstd.open(path, "wb") as f:
17-
f.write(data)
18-
return path
19-
20-
21-
def test_zstd_extractor(zstd_path, tmp_path, text_file):
22-
input_path = zstd_path
8+
@require_zstandard
9+
def test_zstd_extractor(zstd_file, tmp_path, text_file):
10+
input_path = zstd_file
2311
assert ZstdExtractor.is_extractable(input_path)
2412
output_path = str(tmp_path / "extracted.txt")
2513
ZstdExtractor.extract(input_path, output_path)
@@ -30,21 +18,16 @@ def test_zstd_extractor(zstd_path, tmp_path, text_file):
3018
assert extracted_file_content == expected_file_content
3119

3220

33-
@pytest.mark.parametrize(
34-
"compression_format, expected_text_path_name", [("gzip", "text_path"), ("xz", "text_file"), ("zstd", "text_file")]
35-
)
36-
def test_extractor(
37-
compression_format, expected_text_path_name, text_gz_path, xz_file, zstd_path, tmp_path, text_file, text_path
38-
):
39-
input_paths = {"gzip": text_gz_path, "xz": xz_file, "zstd": zstd_path}
21+
@require_zstandard
22+
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
23+
def test_extractor(compression_format, gz_file, xz_file, zstd_file, tmp_path, text_file):
24+
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file}
4025
input_path = str(input_paths[compression_format])
4126
output_path = str(tmp_path / "extracted.txt")
4227
assert Extractor.is_extractable(input_path)
4328
Extractor.extract(input_path, output_path)
4429
with open(output_path) as f:
4530
extracted_file_content = f.read()
46-
expected_text_paths = {"text_file": text_file, "text_path": text_path}
47-
expected_text_path = str(expected_text_paths[expected_text_path_name])
48-
with open(expected_text_path) as f:
31+
with open(text_file) as f:
4932
expected_file_content = f.read()
5033
assert extracted_file_content == expected_file_content

tests/test_streaming_download_manager.py

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,9 @@
22

33
import pytest
44

5-
from .utils import require_streaming
5+
from datasets.filesystems import COMPRESSION_FILESYSTEMS
6+
7+
from .utils import require_lz4, require_streaming, require_zstandard
68

79

810
TEST_URL = "https://huggingface.co/datasets/lhoestq/test/raw/main/some_text.txt"
@@ -116,3 +118,24 @@ def test_streaming_dl_manager_download_and_extract_with_join(input_path, filenam
116118
extracted_path = dl_manager.download_and_extract(input_path)
117119
output_path = xjoin(extracted_path, filename)
118120
assert output_path == expected_path
121+
122+
123+
@require_streaming
124+
@require_zstandard
125+
@require_lz4
126+
@pytest.mark.parametrize("compression_fs_class", COMPRESSION_FILESYSTEMS)
127+
def test_streaming_dl_manager_extract_all_supported_single_file_compression_types(
128+
compression_fs_class, gz_file, xz_file, zstd_file, bz2_file, lz4_file, text_file
129+
):
130+
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen
131+
132+
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file}
133+
input_path = str(input_paths[compression_fs_class.protocol])
134+
dl_manager = StreamingDownloadManager()
135+
output_path = dl_manager.extract(input_path)
136+
path = os.path.basename(input_path)
137+
path = path[: path.rindex(".")]
138+
assert output_path == f"{compression_fs_class.protocol}://{path}::{input_path}"
139+
fsspec_open_file = xopen(output_path, encoding="utf-8")
140+
with fsspec_open_file as f, open(text_file, encoding="utf-8") as expected_file:
141+
assert f.read() == expected_file.read()

tests/utils.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,30 @@ def require_jax(test_case):
137137
return test_case
138138

139139

140+
def require_zstandard(test_case):
141+
"""
142+
Decorator marking a test that requires zstandard.
143+
144+
These tests are skipped when zstandard isn't installed.
145+
146+
"""
147+
if not config.ZSTANDARD_AVAILABLE:
148+
test_case = unittest.skip("test requires zstandard")(test_case)
149+
return test_case
150+
151+
152+
def require_lz4(test_case):
153+
"""
154+
Decorator marking a test that requires lz4.
155+
156+
These tests are skipped when lz4 isn't installed.
157+
158+
"""
159+
if not config.LZ4_AVAILABLE:
160+
test_case = unittest.skip("test requires lz4")(test_case)
161+
return test_case
162+
163+
140164
def require_transformers(test_case):
141165
"""
142166
Decorator marking a test that requires transformers.

0 commit comments

Comments
 (0)