diff --git a/docs/source/dataset_streaming.rst b/docs/source/dataset_streaming.rst index 07973a0c808..aab8e009fbd 100644 --- a/docs/source/dataset_streaming.rst +++ b/docs/source/dataset_streaming.rst @@ -164,3 +164,108 @@ It is possible to get a ``torch.utils.data.IterableDataset`` from a :class:`data {'input_ids': tensor([[101, 11047, 10497, 7869, 2352...]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0...]]), 'attention_mask': tensor([[1, 1, 1, 1, 1...]])} For now, only the PyTorch format is supported but support for TensorFlow and others will be added soon. + + +How does dataset streaming work ? +-------------------------------------------------- + +The StreamingDownloadManager +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The standard (i.e. non-streaming) way of loading a dataset has two steps: + +1. download and extract the raw data files of the dataset by using the :class:`datasets.DownloadManager` +2. process the data files to generate the Arrow file used to load the :class:`datasets.Dataset` object. + +For example, in non-streaming mode a file is simply downloaded like this: + +.. code-block:: + + >>> from datasets import DownloadManager + >>> url = "https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt" + >>> filepath = DownloadManager().download(url) # the file is downloaded here + >>> print(filepath) + '/Users/user/.cache/huggingface/datasets/downloads/16b702620cad8d485bafea59b1d2ed69e796196e6f2c73f005dee935a413aa19.ab631f60c6cb31a079ecf1ad910005a7c009ef0f1e4905b69d489fb2bd162683' + >>> with open(filepath) as f: + ... print(f.read()) + +When you load a dataset in streaming mode, the download manager that is used instead is the :class:`datasets.StreamingDownloadManager`. +Instead of actually downloading and extracting all the data when you load the dataset, it is done lazily. +The file starts to be downloaded and extracted only when ``open`` is called. +This is made possible by extending ``open`` to support opening remote files via HTTP. +In each dataset script, ``open`` is replaced by our function ``xopen`` that extends ``open`` to be able to stream data from remote files. + +Here is a sample code that shows what is done under the hood: + +.. code-block:: + + >>> from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen + >>> url = "https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt" + >>> urlpath = StreamingDownloadManager().download(url) + >>> print(urlpath) + 'https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt' + >>> with xopen(urlpath) as f: + ... print(f.read()) # the file is actually downloaded here + +As you can see, since it's possible to open remote files via an URL, the streaming download manager just returns the URL instead of the path to the local downloaded file. + +Then the file is downloaded in a streaming fashion: it is downloaded progessively as you iterate over the data file. +This is made possible because it is based on ``fsspec``, a library that allows to open and iterate on remote files. +You can find more information about ``fsspec`` in `its documentation `_ + +Compressed files and archives +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You may have noticed that the streaming download manager returns the exact same URL that was given as input for a text file. +However if you use ``download_and_extract`` on a compressed file instead, then the output url will be a chained URL. +Chained URLs are used by ``fsspec`` to navigate in remote compressed archives. + +Some examples of chained URL are: + +.. code-block:: + + >>> from datasets.utils.streaming_download_manager import xopen + >>> chained_url = "zip://combined/train.json::https://adversarialqa.github.io/data/aqa_v1.0.zip" + >>> with xopen(chained_url) as f: + ... print(f.read()[:100]) + '{"data": [{"title": "Brain", "paragraphs": [{"context": "Another approach to brain function is to ex' + >>> chained_url2 = "gzip://mkqa.jsonl::https://github.com/apple/ml-mkqa/raw/master/dataset/mkqa.jsonl.gz" + >>> with xopen(chained_url2) as f: + ... print(f.readline()[:100]) + '{"query": "how long did it take the twin towers to be built", "answers": {"en": [{"type": "number_wi' + +We also extended some functions from ``os.path`` to work with chained URLs. +For example ``os.path.join`` is replaced by our function ``xjoin`` that extends ``os.path.join`` to work with chained URLs: + +.. code-block:: + + >>> from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen, xjoin + >>> url = "https://adversarialqa.github.io/data/aqa_v1.0.zip" + >>> archive_path = StreamingDownloadManager().download_and_extract(url) + >>> print(archive_path) + 'zip://::https://adversarialqa.github.io/data/aqa_v1.0.zip' + >>> filepath = xjoin(archive_path, "combined", "train.json") + >>> print(filepath) + 'zip://combined/train.json::https://adversarialqa.github.io/data/aqa_v1.0.zip' + >>> with xopen(filepath) as f: + ... print(f.read()[:100]) + '{"data": [{"title": "Brain", "paragraphs": [{"context": "Another approach to brain function is to ex' + +You can also take a look at the ``fsspec`` documentation about URL chaining `here `_ + +.. note:: + + Streaming data from TAR archives is currently highly inefficient and requires a lot of bandwidth. We are working on optimizing this to offer you the best performance, stay tuned ! + +Dataset script compatibility +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Now that you are aware of how dataset streaming works, you can make sure your dataset script work in streaming mode: + +1. make sure you use ``open`` to open the data files: it is extended to work with remote files +2. if you have to deal with archives like ZIP files, make sure you use ``os.path.join`` to navigate in the archive + +Currently a few python functions or classes are not supported for dataset streaming: + +- ``pathlib.Path`` and all its methods are not supported, please use ``os.path.join`` and string objects +- ``os.walk``, ``os.listdir``, ``glob.glob`` are not supported yet diff --git a/src/datasets/config.py b/src/datasets/config.py index 3a19c8206e5..6b29c6ac8ef 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -120,21 +120,10 @@ logger.info("Disabling Apache Beam because USE_BEAM is set to False") -USE_RAR = os.environ.get("USE_RAR", "AUTO").upper() -RARFILE_VERSION = "N/A" -RARFILE_AVAILABLE = False -if USE_RAR in ("1", "ON", "YES", "AUTO"): - try: - RARFILE_VERSION = version.parse(importlib_metadata.version("rarfile")) - RARFILE_AVAILABLE = True - logger.info("rarfile available.") - except importlib_metadata.PackageNotFoundError: - pass -else: - logger.info("Disabling rarfile because USE_RAR is set to False") - - +# Optional compression tools +RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None +LZ4_AVAILABLE = importlib.util.find_spec("lz4") is not None # Cache location diff --git a/src/datasets/filesystems/__init__.py b/src/datasets/filesystems/__init__.py index 1e925dd6d1a..a274a3cd1e0 100644 --- a/src/datasets/filesystems/__init__.py +++ b/src/datasets/filesystems/__init__.py @@ -1,4 +1,5 @@ import importlib +from typing import List import fsspec @@ -10,9 +11,17 @@ if _has_s3fs: from .s3filesystem import S3FileSystem # noqa: F401 +COMPRESSION_FILESYSTEMS: List[compression.BaseCompressedFileFileSystem] = [ + compression.Bz2FileSystem, + compression.GzipFileSystem, + compression.Lz4FileSystem, + compression.XzFileSystem, + compression.ZstdFileSystem, +] # Register custom filesystems -fsspec.register_implementation(compression.gzip.GZipFileSystem.protocol, compression.gzip.GZipFileSystem) +for fs_class in COMPRESSION_FILESYSTEMS: + fsspec.register_implementation(fs_class.protocol, fs_class) def extract_path_from_uri(dataset_path: str) -> str: diff --git a/src/datasets/filesystems/compression.py b/src/datasets/filesystems/compression.py new file mode 100644 index 00000000000..280d3215097 --- /dev/null +++ b/src/datasets/filesystems/compression.py @@ -0,0 +1,168 @@ +import os +from typing import Optional + +import fsspec +from fsspec.archive import AbstractArchiveFileSystem +from fsspec.utils import DEFAULT_BLOCK_SIZE + + +class BaseCompressedFileFileSystem(AbstractArchiveFileSystem): + """Read contents of compressed file as a filesystem with one file inside.""" + + root_marker = "" + protocol: str = ( + None # protocol passed in prefix to the url. ex: "gzip", for gzip://file.txt::http://foo.bar/file.txt.gz + ) + compression: str = None # compression type in fsspec. ex: "gzip" + extension: str = None # extension of the filename to strip. ex: "".gz" to get file.txt from file.txt.gz + + def __init__( + self, fo: str = "", target_protocol: Optional[str] = None, target_options: Optional[dict] = None, **kwargs + ): + """ + The compressed file system can be instantiated from any compressed file. + It reads the contents of compressed file as a filesystem with one file inside, as if it was an archive. + + The single file inside the filesystem is named after the compresssed file, + without the compression extension at the end of the filename. + + Args: + fo (:obj:``str``): Path to compressed file. Will fetch file using ``fsspec.open()`` + mode (:obj:``str``): Currently, only 'rb' accepted + target_protocol(:obj:``str``, optional): To override the FS protocol inferred from a URL. + target_options (:obj:``dict``, optional): Kwargs passed when instantiating the target FS. + """ + super().__init__(self, **kwargs) + # always open as "rb" since fsspec can then use the TextIOWrapper to make it work for "r" mode + self.file = fsspec.open( + fo, mode="rb", protocol=target_protocol, compression=self.compression, **(target_options or {}) + ) + self.info = self.file.fs.info(self.file.path) + self.compressed_name = os.path.basename(self.file.path.split("::")[0]) + self.uncompressed_name = self.compressed_name[: self.compressed_name.rindex(".")] + self.dir_cache = None + + @classmethod + def _strip_protocol(cls, path): + # compressed file paths are always relative to the archive root + return super()._strip_protocol(path).lstrip("/") + + def _get_dirs(self): + if self.dir_cache is None: + f = {**self.info, "name": self.uncompressed_name} + self.dir_cache = {f["name"]: f} + + def cat(self, path: str): + return self.file.open().read() + + def _open( + self, + path: str, + mode: str = "rb", + block_size=None, + autocommit=True, + cache_options=None, + **kwargs, + ): + path = self._strip_protocol(path) + if mode != "rb": + raise ValueError(f"Tried to read with mode {mode} on file {self.file.path} opened with mode 'rb'") + if path != self.uncompressed_name: + raise FileNotFoundError(f"Expected file {self.uncompressed_name} but got {path}") + return self.file.open() + + +class Bz2FileSystem(BaseCompressedFileFileSystem): + """Read contents of BZ2 file as a filesystem with one file inside.""" + + protocol = "bz2" + compression = "bz2" + extension = ".bz2" + + +class GzipFileSystem(BaseCompressedFileFileSystem): + """Read contents of GZIP file as a filesystem with one file inside.""" + + protocol = "gzip" + compression = "gzip" + extension = ".gz" + + +class Lz4FileSystem(BaseCompressedFileFileSystem): + """Read contents of LZ4 file as a filesystem with one file inside.""" + + protocol = "lz4" + compression = "lz4" + extension = ".lz4" + + +class XzFileSystem(BaseCompressedFileFileSystem): + """Read contents of .xz (LZMA) file as a filesystem with one file inside.""" + + protocol = "xz" + compression = "xz" + extension = ".xz" + + +class ZstdFileSystem(BaseCompressedFileFileSystem): + """ + Read contents of zstd file as a filesystem with one file inside. + + Note that reading in binary mode with fsspec isn't supported yet: + https://github.com/indygreg/python-zstandard/issues/136 + """ + + protocol = "zstd" + compression = "zstd" + extension = ".zst" + + def __init__( + self, + fo: str, + mode: str = "rb", + target_protocol: Optional[str] = None, + target_options: Optional[dict] = None, + block_size: int = DEFAULT_BLOCK_SIZE, + **kwargs, + ): + super().__init__( + fo=fo, + mode=mode, + target_protocol=target_protocol, + target_options=target_options, + block_size=block_size, + **kwargs, + ) + # We need to wrap the zstd decompressor to avoid this error in fsspec==2021.7.0 and zstandard==0.15.2: + # + # File "/Users/user/.virtualenvs/hf-datasets/lib/python3.7/site-packages/fsspec/core.py", line 145, in open + # out.close = close + # AttributeError: 'zstd.ZstdDecompressionReader' object attribute 'close' is read-only + # + # see https://github.com/intake/filesystem_spec/issues/725 + _enter = self.file.__enter__ + + class WrappedFile: + def __init__(self, file_): + self._file = file_ + + def __enter__(self): + self._file.__enter__() + return self + + def __exit__(self, *args, **kwargs): + self._file.__exit__(*args, **kwargs) + + def __iter__(self): + return iter(self._file) + + def __next__(self): + return next(self._file) + + def __getattr__(self, attr): + return getattr(self._file, attr) + + def fixed_enter(*args, **kwargs): + return WrappedFile(_enter(*args, **kwargs)) + + self.file.__enter__ = fixed_enter diff --git a/src/datasets/filesystems/compression/__init__.py b/src/datasets/filesystems/compression/__init__.py deleted file mode 100644 index 1ea8c6ff2c3..00000000000 --- a/src/datasets/filesystems/compression/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from . import gzip # noqa: F401 diff --git a/src/datasets/filesystems/compression/gzip.py b/src/datasets/filesystems/compression/gzip.py deleted file mode 100644 index 314d4133ad2..00000000000 --- a/src/datasets/filesystems/compression/gzip.py +++ /dev/null @@ -1,70 +0,0 @@ -import os -from typing import Optional - -import fsspec -from fsspec.archive import AbstractArchiveFileSystem -from fsspec.utils import DEFAULT_BLOCK_SIZE - - -class GZipFileSystem(AbstractArchiveFileSystem): - """Read contents of GZIP archive as a file-system with one file inside.""" - - root_marker = "" - protocol = "gzip" - - def __init__( - self, - fo: str = "", - mode: str = "rb", - target_protocol: Optional[str] = None, - target_options: Optional[dict] = None, - block_size: int = DEFAULT_BLOCK_SIZE, - **kwargs, - ): - """ - The GZipFileSystem can be instantiated from any gzip file. - It read the contents of GZip archive as a file-system with one file inside. - The single file inside the filesystem is named after the Gzip file, without ".gz" at the end. - - Args: - fo (:obj:``str``): Path to file containing GZIP. Will fetch file using ``fsspec.open()`` - mode (:obj:``str``): Currently, only 'rb' accepted - target_protocol(:obj:``str``, optional): To override the FS protocol inferred from a URL. - target_options (:obj:``dict``, optional): Kwargs passed when instantiating the target FS. - """ - super().__init__(self, **kwargs) - if mode != "rb": - raise ValueError("Only read from gzip files accepted") - self.gzip = fsspec.open(fo, mode=mode, protocol=target_protocol, compression="gzip", **(target_options or {})) - self.info = self.gzip.fs.info(self.gzip.path) - self.compressed_name = os.path.basename(self.gzip.path.split("::")[0]).rstrip(".gz") - self.uncompressed_name = self.compressed_name.rstrip(".gz") - self.block_size = block_size - self.dir_cache = None - - @classmethod - def _strip_protocol(cls, path): - # gzip file paths are always relative to the archive root - return super()._strip_protocol(path).lstrip("/") - - def _get_dirs(self): - if self.dir_cache is None: - f = {**self.info, "name": self.uncompressed_name} - self.dir_cache = {f["name"]: f} - - def cat(self, path: str): - return self.gzip.open().read() - - def _open( - self, - path: str, - mode: str = "rb", - block_size: Optional[int] = None, - autocommit: bool = True, - cache_options: Optional[dict] = None, - **kwargs, - ): - path = self._strip_protocol(path) - if path != self.uncompressed_name: - raise FileNotFoundError(f"Expected file {self.uncompressed_name} but got {path}") - return self.gzip.open() diff --git a/src/datasets/utils/streaming_download_manager.py b/src/datasets/utils/streaming_download_manager.py index 2768cba7997..1a684a647fc 100644 --- a/src/datasets/utils/streaming_download_manager.py +++ b/src/datasets/utils/streaming_download_manager.py @@ -8,6 +8,7 @@ from aiohttp.client_exceptions import ClientError from .. import config +from ..filesystems import COMPRESSION_FILESYSTEMS from .download_manager import DownloadConfig, map_nested from .file_utils import get_authentication_headers_for_url, is_local_path, is_relative_path, url_or_path_join from .logging import get_logger @@ -15,7 +16,17 @@ logger = get_logger(__name__) BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"] -COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "gz", "lz4", "xz", "zip", "zst"] + +COMPRESSION_EXTENSION_TO_PROTOCOL = { + # single file compression + **{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}, + # archive compression + "zip": "zip", + "tar": "tar", + "tgz": "tar", +} + +SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS} def xjoin(a, *p): @@ -34,19 +45,14 @@ def xjoin(a, *p): Example:: - >>> xjoin("https://host.com/archive.zip", "folder1/file.txt") + >>> xjoin("zip://folder1::https://host.com/archive.zip", "file.txt") zip://folder1/file.txt::https://host.com/archive.zip """ a, *b = a.split("::") if is_local_path(a): a = Path(a, *p).as_posix() else: - compression = fsspec.core.get_compression(a, "infer") - if compression in ["zip"]: - b = [a] + b - a = posixpath.join(f"{compression}://", *p) - else: - a = posixpath.join(a, *p) + a = posixpath.join(a, *p) return "::".join([a] + b) @@ -69,22 +75,9 @@ def read_with_retries(*args, **kwargs): return out file_obj.read = read_with_retries - return file_obj - -def _add_retries_to_fsspec_open_file(fsspec_open_file): - open_ = fsspec_open_file.open - def open_with_retries(): - file_obj = open_() - _add_retries_to_file_obj_read_method(file_obj) - return file_obj - - fsspec_open_file.open = open_with_retries - return fsspec_open_file - - -def xopen(file, mode="r", compression="infer", *args, **kwargs): +def xopen(file, mode="r", *args, **kwargs): """ This function extends the builtin `open` function to support remote files using fsspec. @@ -93,9 +86,9 @@ def xopen(file, mode="r", compression="infer", *args, **kwargs): """ if fsspec.get_fs_token_paths(file)[0].protocol == "https": kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None)) - fsspec_open_file = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs) - fsspec_open_file = _add_retries_to_fsspec_open_file(fsspec_open_file) - return fsspec_open_file + file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open() + _add_retries_to_file_obj_read_method(file_obj) + return file_obj class StreamingDownloadManager(object): @@ -126,30 +119,42 @@ def download(self, url_or_urls): url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True) return url_or_urls - def _download(self, url_or_filename): - if is_relative_path(url_or_filename): + def _download(self, urlpath: str) -> str: + if is_relative_path(urlpath): # append the relative path to the base_path - url_or_filename = url_or_path_join(self._base_path, url_or_filename) - return url_or_filename + urlpath = url_or_path_join(self._base_path, urlpath) + return urlpath def extract(self, path_or_paths): urlpaths = map_nested(self._extract, path_or_paths, map_tuple=True) return urlpaths - def _extract(self, urlpath): + def _extract(self, urlpath: str) -> str: protocol = self._get_extraction_protocol(urlpath) if protocol is None: # no extraction return urlpath + elif protocol in SINGLE_FILE_COMPRESSION_PROTOCOLS: + # there is one single file which is the uncompressed file + inner_file = os.path.basename(urlpath.split("::")[0]) + inner_file = inner_file[: inner_file.rindex(".")] + # check for tar.gz, tar.bz2 etc. + if inner_file.endswith(".tar"): + return f"tar://::{urlpath}" + else: + return f"{protocol}://{inner_file}::{urlpath}" else: - return f"{protocol}://*::{urlpath}" + return f"{protocol}://::{urlpath}" - def _get_extraction_protocol(self, urlpath) -> Optional[str]: + def _get_extraction_protocol(self, urlpath: str) -> Optional[str]: path = urlpath.split("::")[0] - if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS: + extension = path.split(".")[-1] + if extension in BASE_KNOWN_EXTENSIONS: return None - elif path.endswith(".tar"): - return "tar" + elif path.endswith(".tar.gz"): + pass + elif extension in COMPRESSION_EXTENSION_TO_PROTOCOL: + return COMPRESSION_EXTENSION_TO_PROTOCOL[extension] raise NotImplementedError(f"Extraction protocol for file at {urlpath} is not implemented yet") def download_and_extract(self, url_or_urls): diff --git a/tests/conftest.py b/tests/conftest.py index 052d039b91c..3221c2b3f09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ import pyarrow.parquet as pq import pytest +from datasets import config from datasets.arrow_dataset import Dataset from datasets.features import ClassLabel, Features, Sequence, Value @@ -80,7 +81,7 @@ def text_file(tmp_path_factory): @pytest.fixture(scope="session") def xz_file(tmp_path_factory): - filename = tmp_path_factory.mktemp("data") / "file.xz" + filename = tmp_path_factory.mktemp("data") / "file.txt.xz" data = bytes(FILE_CONTENT, "utf-8") with lzma.open(filename, "wb") as f: f.write(data) @@ -88,16 +89,51 @@ def xz_file(tmp_path_factory): @pytest.fixture(scope="session") -def gz_path(tmp_path_factory, text_path): +def gz_file(tmp_path_factory): import gzip - path = str(tmp_path_factory.mktemp("data") / "file.gz") + path = str(tmp_path_factory.mktemp("data") / "file.txt.gz") data = bytes(FILE_CONTENT, "utf-8") with gzip.open(path, "wb") as f: f.write(data) return path +@pytest.fixture(scope="session") +def bz2_file(tmp_path_factory): + import bz2 + + path = tmp_path_factory.mktemp("data") / "file.txt.bz2" + data = bytes(FILE_CONTENT, "utf-8") + with bz2.open(path, "wb") as f: + f.write(data) + return path + + +@pytest.fixture(scope="session") +def zstd_file(tmp_path_factory): + if config.ZSTANDARD_AVAILABLE: + import zstandard as zstd + + path = tmp_path_factory.mktemp("data") / "file.txt.zst" + data = bytes(FILE_CONTENT, "utf-8") + with zstd.open(path, "wb") as f: + f.write(data) + return path + + +@pytest.fixture(scope="session") +def lz4_file(tmp_path_factory): + if config.LZ4_AVAILABLE: + import lz4.frame + + path = tmp_path_factory.mktemp("data") / "file.txt.lz4" + data = bytes(FILE_CONTENT, "utf-8") + with lz4.frame.open(path, "wb") as f: + f.write(data) + return path + + @pytest.fixture(scope="session") def xml_file(tmp_path_factory): filename = tmp_path_factory.mktemp("data") / "file.xml" diff --git a/tests/test_extract.py b/tests/test_extract.py index 909ddf83895..0967f98a670 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -1,25 +1,13 @@ import pytest -import zstandard as zstd from datasets.utils.extract import Extractor, ZstdExtractor +from .utils import require_zstandard -FILE_CONTENT = """\ - Text data. - Second line of data.""" - -@pytest.fixture(scope="session") -def zstd_path(tmp_path_factory): - path = tmp_path_factory.mktemp("data") / "file.zstd" - data = bytes(FILE_CONTENT, "utf-8") - with zstd.open(path, "wb") as f: - f.write(data) - return path - - -def test_zstd_extractor(zstd_path, tmp_path, text_file): - input_path = zstd_path +@require_zstandard +def test_zstd_extractor(zstd_file, tmp_path, text_file): + input_path = zstd_file assert ZstdExtractor.is_extractable(input_path) output_path = str(tmp_path / "extracted.txt") ZstdExtractor.extract(input_path, output_path) @@ -30,21 +18,16 @@ def test_zstd_extractor(zstd_path, tmp_path, text_file): assert extracted_file_content == expected_file_content -@pytest.mark.parametrize( - "compression_format, expected_text_path_name", [("gzip", "text_path"), ("xz", "text_file"), ("zstd", "text_file")] -) -def test_extractor( - compression_format, expected_text_path_name, text_gz_path, xz_file, zstd_path, tmp_path, text_file, text_path -): - input_paths = {"gzip": text_gz_path, "xz": xz_file, "zstd": zstd_path} +@require_zstandard +@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) +def test_extractor(compression_format, gz_file, xz_file, zstd_file, tmp_path, text_file): + input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file} input_path = str(input_paths[compression_format]) output_path = str(tmp_path / "extracted.txt") assert Extractor.is_extractable(input_path) Extractor.extract(input_path, output_path) with open(output_path) as f: extracted_file_content = f.read() - expected_text_paths = {"text_file": text_file, "text_path": text_path} - expected_text_path = str(expected_text_paths[expected_text_path_name]) - with open(expected_text_path) as f: + with open(text_file) as f: expected_file_content = f.read() assert extracted_file_content == expected_file_content diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index 283b8b5a3ed..d623e64e8bb 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -88,8 +88,8 @@ def gen_random_output(): @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) -def test_cached_path_extract(compression_format, gz_path, xz_file, zstd_path, tmp_path, text_file): - input_paths = {"gzip": gz_path, "xz": xz_file, "zstd": zstd_path} +def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file): + input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path} input_path = str(input_paths[compression_format]) cache_dir = tmp_path / "cache" download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True) diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py index 50c9e6fbe13..a71da75f516 100644 --- a/tests/test_filesystem.py +++ b/tests/test_filesystem.py @@ -5,8 +5,9 @@ import pytest from moto import mock_s3 -from datasets.filesystems import S3FileSystem, extract_path_from_uri, is_remote_filesystem -from datasets.filesystems.compression.gzip import GZipFileSystem +from datasets.filesystems import COMPRESSION_FILESYSTEMS, S3FileSystem, extract_path_from_uri, is_remote_filesystem + +from .utils import require_lz4, require_zstandard @pytest.fixture(scope="function") @@ -53,10 +54,16 @@ def test_is_remote_filesystem(): assert is_remote is False -def test_gzip_filesystem(text_gz_path, text_path): - fs = fsspec.filesystem("gzip", fo=text_gz_path) - assert isinstance(fs, GZipFileSystem) - expected_filename = os.path.basename(text_gz_path).rstrip(".gz") +@require_zstandard +@require_lz4 +@pytest.mark.parametrize("compression_fs_class", COMPRESSION_FILESYSTEMS) +def test_compression_filesystems(compression_fs_class, gz_file, bz2_file, lz4_file, zstd_file, xz_file, text_file): + input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file} + input_path = str(input_paths[compression_fs_class.protocol]) + fs = fsspec.filesystem(compression_fs_class.protocol, fo=input_path) + assert isinstance(fs, compression_fs_class) + expected_filename = os.path.basename(input_path) + expected_filename = expected_filename[: expected_filename.rindex(".")] assert fs.ls("/") == [expected_filename] - with fs.open(expected_filename, "r", encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file: + with fs.open(expected_filename, "r", encoding="utf-8") as f, open(text_file, encoding="utf-8") as expected_file: assert f.read() == expected_file.read() diff --git a/tests/test_load.py b/tests/test_load.py index 701ed70be87..c61dd494cb2 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -254,6 +254,9 @@ def test_load_dataset_streaming_gz_json(jsonl_gz_path): def test_load_dataset_streaming_compressed_files(path): repo_id = "albertvillanova/datasets-tests-compression" data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{path}" + if data_files[-3:] in ("zip", "tar"): # we need to glob "*" inside archives + data_files = data_files[-3:] + "://*::" + data_files + return # TODO(QL, albert): support re-add support for ZIP and TAR archives streaming ds = load_dataset("json", split="train", data_files=data_files, streaming=True) assert isinstance(ds, IterableDataset) ds_item = next(iter(ds)) diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index 2e46e712aa4..12d7e667850 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -1,6 +1,10 @@ +import os + import pytest -from .utils import require_streaming +from datasets.filesystems import COMPRESSION_FILESYSTEMS + +from .utils import require_lz4, require_streaming, require_zstandard TEST_URL = "https://huggingface.co/datasets/lhoestq/test/raw/main/some_text.txt" @@ -11,7 +15,12 @@ @pytest.mark.parametrize( "input_path, paths_to_join, expected_path", [ - ("https://host.com/archive.zip", ("file.txt",), "zip://file.txt::https://host.com/archive.zip"), + ("https://host.com/archive.zip", ("file.txt",), "https://host.com/archive.zip/file.txt"), + ( + "zip://::https://host.com/archive.zip", + ("file.txt",), + "zip://file.txt::https://host.com/archive.zip", + ), ( "zip://folder::https://host.com/archive.zip", ("file.txt",), @@ -72,14 +81,16 @@ def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath): @require_streaming -def test_streaming_dl_manager_extract(text_gz_path): +def test_streaming_dl_manager_extract(text_gz_path, text_path): from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen dl_manager = StreamingDownloadManager() output_path = dl_manager.extract(text_gz_path) - assert output_path == text_gz_path - fsspec_open_file = xopen(output_path) - assert fsspec_open_file.compression == "gzip" + path = os.path.basename(text_gz_path).rstrip(".gz") + assert output_path == f"gzip://{path}::{text_gz_path}" + fsspec_open_file = xopen(output_path, encoding="utf-8") + with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file: + assert f.read() == expected_file.read() @require_streaming @@ -88,9 +99,9 @@ def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path, dl_manager = StreamingDownloadManager() output_path = dl_manager.download_and_extract(text_gz_path) - assert output_path == text_gz_path + path = os.path.basename(text_gz_path).rstrip(".gz") + assert output_path == f"gzip://{path}::{text_gz_path}" fsspec_open_file = xopen(output_path, encoding="utf-8") - assert output_path == text_gz_path with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file: assert f.read() == expected_file.read() @@ -107,3 +118,24 @@ def test_streaming_dl_manager_download_and_extract_with_join(input_path, filenam extracted_path = dl_manager.download_and_extract(input_path) output_path = xjoin(extracted_path, filename) assert output_path == expected_path + + +@require_streaming +@require_zstandard +@require_lz4 +@pytest.mark.parametrize("compression_fs_class", COMPRESSION_FILESYSTEMS) +def test_streaming_dl_manager_extract_all_supported_single_file_compression_types( + compression_fs_class, gz_file, xz_file, zstd_file, bz2_file, lz4_file, text_file +): + from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen + + input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file} + input_path = str(input_paths[compression_fs_class.protocol]) + dl_manager = StreamingDownloadManager() + output_path = dl_manager.extract(input_path) + path = os.path.basename(input_path) + path = path[: path.rindex(".")] + assert output_path == f"{compression_fs_class.protocol}://{path}::{input_path}" + fsspec_open_file = xopen(output_path, encoding="utf-8") + with fsspec_open_file as f, open(text_file, encoding="utf-8") as expected_file: + assert f.read() == expected_file.read() diff --git a/tests/utils.py b/tests/utils.py index 010f6a2f93c..ee26201e986 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -137,6 +137,30 @@ def require_jax(test_case): return test_case +def require_zstandard(test_case): + """ + Decorator marking a test that requires zstandard. + + These tests are skipped when zstandard isn't installed. + + """ + if not config.ZSTANDARD_AVAILABLE: + test_case = unittest.skip("test requires zstandard")(test_case) + return test_case + + +def require_lz4(test_case): + """ + Decorator marking a test that requires lz4. + + These tests are skipped when lz4 isn't installed. + + """ + if not config.LZ4_AVAILABLE: + test_case = unittest.skip("test requires lz4")(test_case) + return test_case + + def require_transformers(test_case): """ Decorator marking a test that requires transformers.