diff --git a/src/datasets/config.py b/src/datasets/config.py index 5a4dde161c2..5894b2d1832 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -134,6 +134,9 @@ logger.info("Disabling rarfile because USE_RAR is set to False") +ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None + + # Cache location DEFAULT_XDG_CACHE_HOME = "~/.cache" XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 63eb2b6a328..1855bfb459e 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -11,6 +11,7 @@ import os import re import shutil +import struct import sys import tarfile import tempfile @@ -307,6 +308,7 @@ def cached_path( and not is_gzip(output_path) and not is_xz(output_path) and not is_rarfile(output_path) + and not ZstdExtractor.is_extractable(output_path) ): return output_path @@ -360,6 +362,9 @@ def cached_path( rf.close() else: raise EnvironmentError("Please pip install rarfile") + elif ZstdExtractor.is_extractable(output_path): + os.rmdir(output_path_extracted) + ZstdExtractor.extract(output_path, output_path_extracted) else: raise EnvironmentError("Archive format of {} could not be identified".format(output_path)) @@ -724,6 +729,31 @@ def is_rarfile(path: str) -> bool: return False +class ZstdExtractor: + @staticmethod + def is_extractable(path: str) -> bool: + """https://datatracker.ietf.org/doc/html/rfc8878 + + Magic_Number: 4 bytes, little-endian format. Value: 0xFD2FB528. + """ + with open(path, "rb") as f: + try: + magic_number = f.read(4) + except OSError: + return False + return True if magic_number == struct.pack("