Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 44 additions & 41 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import tarfile
import textwrap
import zipfile

import pyarrow as pa
import pyarrow.parquet as pq
Expand Down Expand Up @@ -95,12 +96,14 @@ def text_file(tmp_path_factory):


@pytest.fixture(scope="session")
def xz_file(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.txt.xz"
def bz2_file(tmp_path_factory):
import bz2

path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
data = bytes(FILE_CONTENT, "utf-8")
with lzma.open(filename, "wb") as f:
with bz2.open(path, "wb") as f:
f.write(data)
return filename
return path


@pytest.fixture(scope="session")
Expand All @@ -114,29 +117,6 @@ def gz_file(tmp_path_factory):
return path


@pytest.fixture(scope="session")
def bz2_file(tmp_path_factory):
    """Create a bzip2-compressed copy of ``FILE_CONTENT`` and return its path.

    The file is written once per test session as ``file.txt.bz2`` inside a
    fresh temporary ``data`` directory.
    """
    import bz2

    compressed_path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
    payload = bytes(FILE_CONTENT, "utf-8")
    with bz2.open(compressed_path, "wb") as stream:
        stream.write(payload)
    return compressed_path


@pytest.fixture(scope="session")
def zstd_file(tmp_path_factory):
    """Create a Zstandard-compressed copy of ``FILE_CONTENT`` and return its path.

    Returns ``None`` when ``config.ZSTANDARD_AVAILABLE`` is false, so the
    ``zstandard`` package is only imported when it is available.
    """
    if not config.ZSTANDARD_AVAILABLE:
        return None

    import zstandard as zstd

    compressed_path = tmp_path_factory.mktemp("data") / "file.txt.zst"
    payload = bytes(FILE_CONTENT, "utf-8")
    with zstd.open(compressed_path, "wb") as stream:
        stream.write(payload)
    return compressed_path


@pytest.fixture(scope="session")
def lz4_file(tmp_path_factory):
if config.LZ4_AVAILABLE:
Expand All @@ -160,6 +140,43 @@ def seven_zip_file(tmp_path_factory, text_file):
return path


@pytest.fixture(scope="session")
def tar_file(tmp_path_factory, text_file):
    """Create an uncompressed tar archive holding ``text_file`` and return its path.

    The member is stored under its base name only, so the archive has no
    directory structure.
    """
    archive_path = tmp_path_factory.mktemp("data") / "file.txt.tar"
    with tarfile.TarFile(archive_path, "w") as archive:
        member_name = os.path.basename(text_file)
        archive.add(text_file, arcname=member_name)
    return archive_path


@pytest.fixture(scope="session")
def xz_file(tmp_path_factory):
    """Create an xz-compressed copy of ``FILE_CONTENT`` and return its path.

    The file is written once per test session as ``file.txt.xz`` inside a
    fresh temporary ``data`` directory.
    """
    compressed_path = tmp_path_factory.mktemp("data") / "file.txt.xz"
    payload = bytes(FILE_CONTENT, "utf-8")
    with lzma.open(compressed_path, "wb") as stream:
        stream.write(payload)
    return compressed_path


@pytest.fixture(scope="session")
def zip_file(tmp_path_factory, text_file):
    """Create a zip archive holding ``text_file`` and return its path.

    The member is stored under its base name only, so the archive has no
    directory structure.
    """
    archive_path = tmp_path_factory.mktemp("data") / "file.txt.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        member_name = os.path.basename(text_file)
        archive.write(text_file, arcname=member_name)
    return archive_path


@pytest.fixture(scope="session")
def zstd_file(tmp_path_factory):
    """Create a Zstandard-compressed copy of ``FILE_CONTENT`` and return its path.

    Returns ``None`` when ``config.ZSTANDARD_AVAILABLE`` is false, so the
    ``zstandard`` package is only imported when it is available.
    """
    if not config.ZSTANDARD_AVAILABLE:
        return None

    import zstandard as zstd

    compressed_path = tmp_path_factory.mktemp("data") / "file.txt.zst"
    payload = bytes(FILE_CONTENT, "utf-8")
    with zstd.open(compressed_path, "wb") as stream:
        stream.write(payload)
    return compressed_path


@pytest.fixture(scope="session")
def xml_file(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.xml"
Expand Down Expand Up @@ -276,8 +293,6 @@ def bz2_csv_path(csv_path, tmp_path_factory):

@pytest.fixture(scope="session")
def zip_csv_path(csv_path, csv2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset.csv.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(csv_path, arcname=os.path.basename(csv_path))
Expand All @@ -287,8 +302,6 @@ def zip_csv_path(csv_path, csv2_path, tmp_path_factory):

@pytest.fixture(scope="session")
def zip_csv_with_dir_path(csv_path, csv2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset_with_dir.csv.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(csv_path, arcname=os.path.join("main_dir", os.path.basename(csv_path)))
Expand Down Expand Up @@ -392,8 +405,6 @@ def jsonl_gz_path(tmp_path_factory, jsonl_path):

@pytest.fixture(scope="session")
def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset.jsonl.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(jsonl_path, arcname=os.path.basename(jsonl_path))
Expand All @@ -403,8 +414,6 @@ def zip_jsonl_path(jsonl_path, jsonl2_path, tmp_path_factory):

@pytest.fixture(scope="session")
def zip_jsonl_with_dir_path(jsonl_path, jsonl2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset_with_dir.jsonl.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(jsonl_path, arcname=os.path.join("main_dir", os.path.basename(jsonl_path)))
Expand Down Expand Up @@ -451,8 +460,6 @@ def text2_path(tmp_path_factory):

@pytest.fixture(scope="session")
def zip_text_path(text_path, text2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset.text.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(text_path, arcname=os.path.basename(text_path))
Expand All @@ -462,8 +469,6 @@ def zip_text_path(text_path, text2_path, tmp_path_factory):

@pytest.fixture(scope="session")
def zip_text_with_dir_path(text_path, text2_path, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset_with_dir.text.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(text_path, arcname=os.path.join("main_dir", os.path.basename(text_path)))
Expand All @@ -487,8 +492,6 @@ def image_file():

@pytest.fixture(scope="session")
def zip_image_path(image_file, tmp_path_factory):
import zipfile

path = tmp_path_factory.mktemp("data") / "dataset.img.zip"
with zipfile.ZipFile(path, "w") as f:
f.write(image_file, arcname=os.path.basename(image_file))
Expand Down
103 changes: 76 additions & 27 deletions tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,93 @@
import pytest

from datasets.utils.extract import Extractor, SevenZipExtractor, ZstdExtractor
from datasets.utils.extract import (
Bzip2Extractor,
Extractor,
GzipExtractor,
SevenZipExtractor,
TarExtractor,
XzExtractor,
ZipExtractor,
ZstdExtractor,
)

from .utils import require_py7zr, require_zstandard


@require_py7zr
def test_seven_zip_extractor(seven_zip_file, tmp_path, text_file):
input_path = seven_zip_file
assert SevenZipExtractor.is_extractable(input_path)
output_path = tmp_path / "extracted"
SevenZipExtractor.extract(input_path, output_path)
assert output_path.is_dir()
for file_path in output_path.iterdir():
assert file_path.name == text_file.name
extracted_file_content = file_path.read_text(encoding="utf-8")
@pytest.mark.parametrize(
    "compression_format, is_archive",
    [("7z", True), ("bz2", False), ("gzip", False), ("tar", True), ("xz", False), ("zip", True), ("zstd", False)],
)
def test_base_extractors(
    compression_format,
    is_archive,
    bz2_file,
    gz_file,
    seven_zip_file,
    tar_file,
    xz_file,
    zip_file,
    zstd_file,
    tmp_path,
    text_file,
):
    """Check each format-specific extractor against its matching fixture file.

    For every ``compression_format`` the corresponding extractor class must
    report the fixture as extractable and reproduce the content of
    ``text_file``. Archive formats (``is_archive`` true) extract into a
    directory; single-file formats extract to one output file.
    """
    input_paths_and_base_extractors = {
        "7z": (seven_zip_file, SevenZipExtractor),
        "bz2": (bz2_file, Bzip2Extractor),
        "gzip": (gz_file, GzipExtractor),
        "tar": (tar_file, TarExtractor),
        "xz": (xz_file, XzExtractor),
        "zip": (zip_file, ZipExtractor),
        "zstd": (zstd_file, ZstdExtractor),
    }
    input_path, base_extractor = input_paths_and_base_extractors[compression_format]
    if input_path is None:
        # A None fixture value is treated as "nothing to test for this format":
        # skip, reusing the reason text of the matching `require_*` marker.
        reason = f"for '{compression_format}' compression_format, "
        if compression_format == "7z":
            reason += require_py7zr.kwargs["reason"]
        elif compression_format == "zstd":
            reason += require_zstandard.kwargs["reason"]
        pytest.skip(reason)
    assert base_extractor.is_extractable(input_path)
    output_path = tmp_path / ("extracted" if is_archive else "extracted.txt")
    base_extractor.extract(input_path, output_path)
    if is_archive:
        assert output_path.is_dir()
        extracted_files = list(output_path.iterdir())
        # Without this check an empty extraction would leave
        # `extracted_file_content` unbound and fail with UnboundLocalError
        # instead of a readable assertion.
        assert extracted_files, f"no files were extracted from {input_path}"
        for file_path in extracted_files:
            assert file_path.name == text_file.name
            extracted_file_content = file_path.read_text(encoding="utf-8")
    else:
        extracted_file_content = output_path.read_text(encoding="utf-8")
    expected_file_content = text_file.read_text(encoding="utf-8")
    assert extracted_file_content == expected_file_content


@require_zstandard
def test_zstd_extractor(zstd_file, tmp_path, text_file):
    """Check that ``ZstdExtractor`` round-trips the zstd fixture to the original text.

    The fixture must be reported as extractable, and the extracted file must
    have the same content as ``text_file``.
    """
    input_path = zstd_file
    assert ZstdExtractor.is_extractable(input_path)
    output_path = str(tmp_path / "extracted.txt")
    ZstdExtractor.extract(input_path, output_path)
    # Decode explicitly as UTF-8 so the comparison does not depend on the
    # platform's default locale encoding.
    with open(output_path, encoding="utf-8") as f:
        extracted_file_content = f.read()
    with open(text_file, encoding="utf-8") as f:
        expected_file_content = f.read()
    assert extracted_file_content == expected_file_content


@pytest.mark.parametrize(
"compression_format, is_archive", [("gzip", False), ("xz", False), ("zstd", False), ("bz2", False), ("7z", True)]
"compression_format, is_archive",
[("7z", True), ("bz2", False), ("gzip", False), ("tar", True), ("xz", False), ("zip", True), ("zstd", False)],
)
def test_extractor(
compression_format, is_archive, gz_file, xz_file, zstd_file, bz2_file, seven_zip_file, tmp_path, text_file
compression_format,
is_archive,
bz2_file,
gz_file,
seven_zip_file,
tar_file,
xz_file,
zip_file,
zstd_file,
tmp_path,
text_file,
):
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "7z": seven_zip_file}
input_paths = {
"7z": seven_zip_file,
"bz2": bz2_file,
"gzip": gz_file,
"tar": tar_file,
"xz": xz_file,
"zip": zip_file,
"zstd": zstd_file,
}
input_path = input_paths[compression_format]
if input_path is None:
reason = f"for '{compression_format}' compression_format, "
Expand Down