diff --git a/src/datasets/utils/extract.py b/src/datasets/utils/extract.py index bf6f5963033..06c9e7c09f4 100644 --- a/src/datasets/utils/extract.py +++ b/src/datasets/utils/extract.py @@ -1,3 +1,4 @@ +import bz2 import gzip import lzma import os @@ -157,9 +158,29 @@ def extract(input_path: str, output_path: str): dctx.copy_stream(ifh, ofh) +class Bzip2Extractor: + @staticmethod + def is_extractable(path: str) -> bool: + with open(path, "rb") as f: + try: + header_magic_bytes = f.read(3) + except OSError: + return False + if header_magic_bytes == b"BZh": + return True + else: + return False + + @staticmethod + def extract(input_path, output_path): + with bz2.open(input_path, "rb") as compressed_file: + with open(output_path, "wb") as extracted_file: + shutil.copyfileobj(compressed_file, extracted_file) + + class Extractor: # Put zip file to the last, b/c it is possible wrongly detected as zip (I guess it means: as tar or gzip) - extractors = [TarExtractor, GzipExtractor, ZipExtractor, XzExtractor, RarExtractor, ZstdExtractor] + extractors = [TarExtractor, GzipExtractor, ZipExtractor, XzExtractor, RarExtractor, ZstdExtractor, Bzip2Extractor] @classmethod def is_extractable(cls, path, return_extractor=False): diff --git a/tests/test_extract.py b/tests/test_extract.py index 0967f98a670..8533084930e 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -19,9 +19,9 @@ def test_zstd_extractor(zstd_file, tmp_path, text_file): @require_zstandard -@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) -def test_extractor(compression_format, gz_file, xz_file, zstd_file, tmp_path, text_file): - input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file} +@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd", "bz2"]) +def test_extractor(compression_format, gz_file, xz_file, zstd_file, bz2_file, tmp_path, text_file): + input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file} input_path = str(input_paths[compression_format]) output_path = str(tmp_path / "extracted.txt") assert Extractor.is_extractable(input_path)