Skip to content

Commit c2a06b5

Browse files
Test extractors for all compression formats (#4689)
* Add more compressed files for tests * Import zipfile at module level * Add more compression formats to test_extractor * Test base extractors for all compressions * Rename filename to path
1 parent dd8c636 commit c2a06b5

File tree

2 files changed

+120
-68
lines changed

2 files changed

+120
-68
lines changed

tests/conftest.py

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import tarfile
66
import textwrap
7+
import zipfile
78

89
import pyarrow as pa
910
import pyarrow.parquet as pq
@@ -95,12 +96,14 @@ def text_file(tmp_path_factory):
9596

9697

9798
@pytest.fixture(scope="session")
98-
def xz_file(tmp_path_factory):
99-
filename = tmp_path_factory.mktemp("data") / "file.txt.xz"
99+
def bz2_file(tmp_path_factory):
100+
import bz2
101+
102+
path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
100103
data = bytes(FILE_CONTENT, "utf-8")
101-
with lzma.open(filename, "wb") as f:
104+
with bz2.open(path, "wb") as f:
102105
f.write(data)
103-
return filename
106+
return path
104107

105108

106109
@pytest.fixture(scope="session")
@@ -114,29 +117,6 @@ def gz_file(tmp_path_factory):
114117
return path
115118

116119

117-
@pytest.fixture(scope="session")
118-
def bz2_file(tmp_path_factory):
119-
import bz2
120-
121-
path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
122-
data = bytes(FILE_CONTENT, "utf-8")
123-
with bz2.open(path, "wb") as f:
124-
f.write(data)
125-
return path
126-
127-
128-
@pytest.fixture(scope="session")
129-
def zstd_file(tmp_path_factory):
130-
if config.ZSTANDARD_AVAILABLE:
131-
import zstandard as zstd
132-
133-
path = tmp_path_factory.mktemp("data") / "file.txt.zst"
134-
data = bytes(FILE_CONTENT, "utf-8")
135-
with zstd.open(path, "wb") as f:
136-
f.write(data)
137-
return path
138-
139-
140120
@pytest.fixture(scope="session")
141121
def lz4_file(tmp_path_factory):
142122
if config.LZ4_AVAILABLE:
@@ -160,6 +140,43 @@ def seven_zip_file(tmp_path_factory, text_file):
160140
return path
161141

162142

143+
@pytest.fixture(scope="session")
144+
def tar_file(tmp_path_factory, text_file):
145+
path = tmp_path_factory.mktemp("data") / "file.txt.tar"
146+
with tarfile.TarFile(path, "w") as f:
147+
f.add(text_file, arcname=os.path.basename(text_file))
148+
return path
149+
150+
151+
@pytest.fixture(scope="session")
152+
def xz_file(tmp_path_factory):
153+
path = tmp_path_factory.mktemp("data") / "file.txt.xz"
154+
data = bytes(FILE_CONTENT, "utf-8")
155+
with lzma.open(path, "wb") as f:
156+
f.write(data)
157+
return path
158+
159+
160+
@pytest.fixture(scope="session")
161+
def zip_file(tmp_path_factory, text_file):
162+
path = tmp_path_factory.mktemp("data") / "file.txt.zip"
163+
with zipfile.ZipFile(path, "w") as f:
164+
f.write(text_file, arcname=os.path.basename(text_file))
165+
return path
166+
167+
168+
@pytest.fixture(scope="session")
169+
def zstd_file(tmp_path_factory):
170+
if config.ZSTANDARD_AVAILABLE:
171+
import zstandard as zstd
172+
173+
path = tmp_path_factory.mktemp("data") / "file.txt.zst"
174+
data = bytes(FILE_CONTENT, "utf-8")
175+
with zstd.open(path, "wb") as f:
176+
f.write(data)
177+
return path
178+
179+
163180
@pytest.fixture(scope="session")
164181
def xml_file(tmp_path_factory):
165182
filename = tmp_path_factory.mktemp("data") / "file.xml"
@@ -276,8 +293,6 @@ def bz2_csv_path(csv_path, tmp_path_factory):
276293

277294
@pytest.fixture(scope="session")
278295
def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
279-
import zipfile
280-
281296
path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
282297
with zipfile.ZipFile(path, "w") as f:
283298
f.write(csv_path, arcname=os.path.basename(csv_path))
@@ -287,8 +302,6 @@ def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
287302

288303
@pytest.fixture(scope="session")
289304
def zip_csv_with_dir_path(csv_path, csv2_path, tmp_path_factory):
290-
import zipfile
291-
292305
path = tmp_path_factory.mktemp("data") / "dataset_with_dir.csv.zip"
293306
with zipfile.ZipFile(path, "w") as f:
294307
f.write(csv_path, arcname=os.path.join("main_dir", os.path.basename(csv_path)))
@@ -392,8 +405,6 @@ def jsonl_gz_path(tmp_path_factory, jsonl_path):
392405

393406
@pytest.fixture(scope="session")
394407
def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
395-
import zipfile
396-
397408
path = tmp_path_factory.mktemp("data") / "dataset.jsonl.zip"
398409
with zipfile.ZipFile(path, "w") as f:
399410
f.write(jsonl_path, arcname=os.path.basename(jsonl_path))
@@ -403,8 +414,6 @@ def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
403414

404415
@pytest.fixture(scope="session")
405416
def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory):
406-
import zipfile
407-
408417
path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip"
409418
with zipfile.ZipFile(path, "w") as f:
410419
f.write(jsonl_path, arcname=os.path.join("main_dir", os.path.basename(jsonl_path)))
@@ -451,8 +460,6 @@ def text2_path(tmp_path_factory):
451460

452461
@pytest.fixture(scope="session")
453462
def zip_text_path(text_path, text2_path, tmp_path_factory):
454-
import zipfile
455-
456463
path = tmp_path_factory.mktemp("data") / "dataset.text.zip"
457464
with zipfile.ZipFile(path, "w") as f:
458465
f.write(text_path, arcname=os.path.basename(text_path))
@@ -462,8 +469,6 @@ def zip_text_path(text_path, text2_path, tmp_path_factory):
462469

463470
@pytest.fixture(scope="session")
464471
def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
465-
import zipfile
466-
467472
path = tmp_path_factory.mktemp("data") / "dataset_with_dir.text.zip"
468473
with zipfile.ZipFile(path, "w") as f:
469474
f.write(text_path, arcname=os.path.join("main_dir", os.path.basename(text_path)))
@@ -487,8 +492,6 @@ def image_file():
487492

488493
@pytest.fixture(scope="session")
489494
def zip_image_path(image_file, tmp_path_factory):
490-
import zipfile
491-
492495
path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
493496
with zipfile.ZipFile(path, "w") as f:
494497
f.write(image_file, arcname=os.path.basename(image_file))

tests/test_extract.py

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,93 @@
11
import pytest
22

3-
from datasets.utils.extract import Extractor, SevenZipExtractor, ZstdExtractor
3+
from datasets.utils.extract import (
4+
Bzip2Extractor,
5+
Extractor,
6+
GzipExtractor,
7+
SevenZipExtractor,
8+
TarExtractor,
9+
XzExtractor,
10+
ZipExtractor,
11+
ZstdExtractor,
12+
)
413

514
from .utils import require_py7zr, require_zstandard
615

716

8-
@require_py7zr
9-
def test_seven_zip_extractor(seven_zip_file, tmp_path, text_file):
10-
input_path = seven_zip_file
11-
assert SevenZipExtractor.is_extractable(input_path)
12-
output_path = tmp_path / "extracted"
13-
SevenZipExtractor.extract(input_path, output_path)
14-
assert output_path.is_dir()
15-
for file_path in output_path.iterdir():
16-
assert file_path.name == text_file.name
17-
extracted_file_content = file_path.read_text(encoding="utf-8")
17+
@pytest.mark.parametrize(
18+
"compression_format, is_archive",
19+
[("7z", True), ("bz2", False), ("gzip", False), ("tar", True), ("xz", False), ("zip", True), ("zstd", False)],
20+
)
21+
def test_base_extractors(
22+
compression_format,
23+
is_archive,
24+
bz2_file,
25+
gz_file,
26+
seven_zip_file,
27+
tar_file,
28+
xz_file,
29+
zip_file,
30+
zstd_file,
31+
tmp_path,
32+
text_file,
33+
):
34+
input_paths_and_base_extractors = {
35+
"7z": (seven_zip_file, SevenZipExtractor),
36+
"bz2": (bz2_file, Bzip2Extractor),
37+
"gzip": (gz_file, GzipExtractor),
38+
"tar": (tar_file, TarExtractor),
39+
"xz": (xz_file, XzExtractor),
40+
"zip": (zip_file, ZipExtractor),
41+
"zstd": (zstd_file, ZstdExtractor),
42+
}
43+
input_path, base_extractor = input_paths_and_base_extractors[compression_format]
44+
if input_path is None:
45+
reason = f"for '{compression_format}' compression_format, "
46+
if compression_format == "7z":
47+
reason += require_py7zr.kwargs["reason"]
48+
elif compression_format == "zstd":
49+
reason += require_zstandard.kwargs["reason"]
50+
pytest.skip(reason)
51+
assert base_extractor.is_extractable(input_path)
52+
output_path = tmp_path / ("extracted" if is_archive else "extracted.txt")
53+
base_extractor.extract(input_path, output_path)
54+
if is_archive:
55+
assert output_path.is_dir()
56+
for file_path in output_path.iterdir():
57+
assert file_path.name == text_file.name
58+
extracted_file_content = file_path.read_text(encoding="utf-8")
59+
else:
60+
extracted_file_content = output_path.read_text(encoding="utf-8")
1861
expected_file_content = text_file.read_text(encoding="utf-8")
1962
assert extracted_file_content == expected_file_content
2063

2164

22-
@require_zstandard
23-
def test_zstd_extractor(zstd_file, tmp_path, text_file):
24-
input_path = zstd_file
25-
assert ZstdExtractor.is_extractable(input_path)
26-
output_path = str(tmp_path / "extracted.txt")
27-
ZstdExtractor.extract(input_path, output_path)
28-
with open(output_path) as f:
29-
extracted_file_content = f.read()
30-
with open(text_file) as f:
31-
expected_file_content = f.read()
32-
assert extracted_file_content == expected_file_content
33-
34-
3565
@pytest.mark.parametrize(
36-
"compression_format, is_archive", [("gzip", False), ("xz", False), ("zstd", False), ("bz2", False), ("7z", True)]
66+
"compression_format, is_archive",
67+
[("7z", True), ("bz2", False), ("gzip", False), ("tar", True), ("xz", False), ("zip", True), ("zstd", False)],
3768
)
3869
def test_extractor(
39-
compression_format, is_archive, gz_file, xz_file, zstd_file, bz2_file, seven_zip_file, tmp_path, text_file
70+
compression_format,
71+
is_archive,
72+
bz2_file,
73+
gz_file,
74+
seven_zip_file,
75+
tar_file,
76+
xz_file,
77+
zip_file,
78+
zstd_file,
79+
tmp_path,
80+
text_file,
4081
):
41-
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "7z": seven_zip_file}
82+
input_paths = {
83+
"7z": seven_zip_file,
84+
"bz2": bz2_file,
85+
"gzip": gz_file,
86+
"tar": tar_file,
87+
"xz": xz_file,
88+
"zip": zip_file,
89+
"zstd": zstd_file,
90+
}
4291
input_path = input_paths[compression_format]
4392
if input_path is None:
4493
reason = f"for '{compression_format}' compression_format, "

0 commit comments

Comments
 (0)