diff --git a/docs/source/dataset_streaming.rst b/docs/source/dataset_streaming.rst
index 07973a0c808..aab8e009fbd 100644
--- a/docs/source/dataset_streaming.rst
+++ b/docs/source/dataset_streaming.rst
@@ -164,3 +164,143 @@ It is possible to get a ``torch.utils.data.IterableDataset`` from a :class:`data
{'input_ids': tensor([[101, 11047, 10497, 7869, 2352...]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0...]]), 'attention_mask': tensor([[1, 1, 1, 1, 1...]])}
For now, only the PyTorch format is supported but support for TensorFlow and others will be added soon.
+
+
+How does dataset streaming work?
+--------------------------------------------------
+
+The StreamingDownloadManager
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The standard (i.e. non-streaming) way of loading a dataset has two steps:
+
+1. download and extract the raw data files of the dataset by using the :class:`datasets.DownloadManager`
+2. process the data files to generate the Arrow file used to load the :class:`datasets.Dataset` object.
+
+For example, in non-streaming mode a file is simply downloaded like this:
+
+.. code-block::
+
+ >>> from datasets import DownloadManager
+ >>> url = "https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt"
+ >>> filepath = DownloadManager().download(url) # the file is downloaded here
+ >>> print(filepath)
+ '/Users/user/.cache/huggingface/datasets/downloads/16b702620cad8d485bafea59b1d2ed69e796196e6f2c73f005dee935a413aa19.ab631f60c6cb31a079ecf1ad910005a7c009ef0f1e4905b69d489fb2bd162683'
+ >>> with open(filepath) as f:
+ ... print(f.read())
+
+When you load a dataset in streaming mode, the :class:`datasets.StreamingDownloadManager` is used instead.
+Rather than actually downloading and extracting all the data when you load the dataset, this is done lazily:
+a file only starts to be downloaded and extracted when ``open`` is called on it.
+This is made possible by extending ``open`` to support opening remote files via HTTP.
+In each dataset script, ``open`` is replaced by our function ``xopen``, which extends ``open`` to stream data from remote files.
+
+Here is some sample code that shows what happens under the hood:
+
+.. code-block::
+
+ >>> from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen
+ >>> url = "https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt"
+ >>> urlpath = StreamingDownloadManager().download(url)
+ >>> print(urlpath)
+ 'https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt'
+ >>> with xopen(urlpath) as f:
+ ... print(f.read()) # the file is actually downloaded here
+
+As you can see, since remote files can be opened directly via a URL, the streaming download manager simply returns the URL instead of a path to a locally downloaded file.
+
+The file is then downloaded in a streaming fashion: it is fetched progressively as you iterate over the data file.
+This is made possible by ``fsspec``, a library that makes it possible to open and iterate over remote files.
+You can find more information about ``fsspec`` in `its documentation <https://filesystem-spec.readthedocs.io/en/latest/>`_.
+
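+For example, a remote text file can be read line by line without ever being fully saved on disk. Here is a small sketch using the same test file as above:
+
+.. code-block::
+
+    >>> from datasets.utils.streaming_download_manager import xopen
+    >>> url = "https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt"
+    >>> with xopen(url) as f:
+    ...     for line in f:  # each line is fetched over HTTP as you iterate
+    ...         print(line.rstrip())
+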
+Compressed files and archives
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You may have noticed that for a text file, the streaming download manager returns the exact same URL that was given as input.
+However, if you use ``download_and_extract`` on a compressed file instead, the output URL will be a chained URL.
+Chained URLs are used by ``fsspec`` to navigate inside remote compressed files and archives.
+
+Here are some examples of chained URLs:
+
+.. code-block::
+
+ >>> from datasets.utils.streaming_download_manager import xopen
+ >>> chained_url = "zip://combined/train.json::https://adversarialqa.github.io/data/aqa_v1.0.zip"
+ >>> with xopen(chained_url) as f:
+ ... print(f.read()[:100])
+ '{"data": [{"title": "Brain", "paragraphs": [{"context": "Another approach to brain function is to ex'
+ >>> chained_url2 = "gzip://mkqa.jsonl::https://github.com/apple/ml-mkqa/raw/master/dataset/mkqa.jsonl.gz"
+ >>> with xopen(chained_url2) as f:
+ ... print(f.readline()[:100])
+ '{"query": "how long did it take the twin towers to be built", "answers": {"en": [{"type": "number_wi'
+
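+For instance, the second chained URL above is exactly what ``download_and_extract`` returns when given the URL of the gzip file:
+
+.. code-block::
+
+    >>> from datasets.utils.streaming_download_manager import StreamingDownloadManager
+    >>> url = "https://github.com/apple/ml-mkqa/raw/master/dataset/mkqa.jsonl.gz"
+    >>> urlpath = StreamingDownloadManager().download_and_extract(url)
+    >>> print(urlpath)
+    'gzip://mkqa.jsonl::https://github.com/apple/ml-mkqa/raw/master/dataset/mkqa.jsonl.gz'
+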
+We also extended some functions from ``os.path`` to work with chained URLs.
+For example, ``os.path.join`` is replaced by our function ``xjoin``, which extends ``os.path.join`` to work with chained URLs:
+
+.. code-block::
+
+ >>> from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen, xjoin
+ >>> url = "https://adversarialqa.github.io/data/aqa_v1.0.zip"
+ >>> archive_path = StreamingDownloadManager().download_and_extract(url)
+ >>> print(archive_path)
+ 'zip://::https://adversarialqa.github.io/data/aqa_v1.0.zip'
+ >>> filepath = xjoin(archive_path, "combined", "train.json")
+ >>> print(filepath)
+ 'zip://combined/train.json::https://adversarialqa.github.io/data/aqa_v1.0.zip'
+ >>> with xopen(filepath) as f:
+ ... print(f.read()[:100])
+ '{"data": [{"title": "Brain", "paragraphs": [{"context": "Another approach to brain function is to ex'
+
+You can also take a look at the ``fsspec`` documentation about URL chaining `here <https://filesystem-spec.readthedocs.io/en/latest/features.html#url-chaining>`_.
+
+.. note::
+
+   Streaming data from TAR archives is currently highly inefficient and requires a lot of bandwidth. We are working on optimizing this to offer you the best performance. Stay tuned!
+
+Dataset script compatibility
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Now that you are aware of how dataset streaming works, you can make sure your dataset scripts work in streaming mode (see the sketch at the end of this section):
+
+1. make sure you use ``open`` to open the data files: it is extended to work with remote files
+2. if you have to deal with archives like ZIP files, make sure you use ``os.path.join`` to navigate in the archive
+
+Currently, a few Python functions and classes are not supported for dataset streaming:
+
+- ``pathlib.Path`` and all its methods are not supported; please use ``os.path.join`` and string objects instead
+- ``os.walk``, ``os.listdir``, ``glob.glob`` are not supported yet
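+
+As an illustration, here is a minimal sketch of a ``_generate_examples`` method that works in both regular and streaming mode (the archive layout and the ``text`` field are hypothetical):
+
+.. code-block::
+
+    import os
+
+    def _generate_examples(self, archive_path):
+        # os.path.join is replaced by xjoin in streaming mode, so this also works
+        # when archive_path is a chained URL like "zip://::https://host.com/archive.zip"
+        filepath = os.path.join(archive_path, "data", "train.jsonl")
+        # open is replaced by xopen in streaming mode: the file is streamed from the remote host
+        with open(filepath, encoding="utf-8") as f:
+            for id_, line in enumerate(f):
+                yield id_, {"text": line.strip()}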
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 3a19c8206e5..6b29c6ac8ef 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -120,21 +120,10 @@
logger.info("Disabling Apache Beam because USE_BEAM is set to False")
-USE_RAR = os.environ.get("USE_RAR", "AUTO").upper()
-RARFILE_VERSION = "N/A"
-RARFILE_AVAILABLE = False
-if USE_RAR in ("1", "ON", "YES", "AUTO"):
- try:
- RARFILE_VERSION = version.parse(importlib_metadata.version("rarfile"))
- RARFILE_AVAILABLE = True
- logger.info("rarfile available.")
- except importlib_metadata.PackageNotFoundError:
- pass
-else:
- logger.info("Disabling rarfile because USE_RAR is set to False")
-
-
+# Optional compression tools
+RARFILE_AVAILABLE = importlib.util.find_spec("rarfile") is not None
ZSTANDARD_AVAILABLE = importlib.util.find_spec("zstandard") is not None
+LZ4_AVAILABLE = importlib.util.find_spec("lz4") is not None
# Cache location
diff --git a/src/datasets/filesystems/__init__.py b/src/datasets/filesystems/__init__.py
index 1e925dd6d1a..a274a3cd1e0 100644
--- a/src/datasets/filesystems/__init__.py
+++ b/src/datasets/filesystems/__init__.py
@@ -1,4 +1,5 @@
import importlib
+from typing import List, Type
import fsspec
@@ -10,9 +11,17 @@
if _has_s3fs:
from .s3filesystem import S3FileSystem # noqa: F401
+COMPRESSION_FILESYSTEMS: List[Type[compression.BaseCompressedFileFileSystem]] = [
+ compression.Bz2FileSystem,
+ compression.GzipFileSystem,
+ compression.Lz4FileSystem,
+ compression.XzFileSystem,
+ compression.ZstdFileSystem,
+]
# Register custom filesystems
-fsspec.register_implementation(compression.gzip.GZipFileSystem.protocol, compression.gzip.GZipFileSystem)
+for fs_class in COMPRESSION_FILESYSTEMS:
+ fsspec.register_implementation(fs_class.protocol, fs_class)
def extract_path_from_uri(dataset_path: str) -> str:
diff --git a/src/datasets/filesystems/compression.py b/src/datasets/filesystems/compression.py
new file mode 100644
index 00000000000..280d3215097
--- /dev/null
+++ b/src/datasets/filesystems/compression.py
@@ -0,0 +1,175 @@
+import os
+from typing import Optional
+
+import fsspec
+from fsspec.archive import AbstractArchiveFileSystem
+from fsspec.utils import DEFAULT_BLOCK_SIZE
+
+
+class BaseCompressedFileFileSystem(AbstractArchiveFileSystem):
+ """Read contents of compressed file as a filesystem with one file inside."""
+
+ root_marker = ""
+    protocol: str = (
+        None  # protocol used as the URL prefix. ex: "gzip" for gzip://file.txt::http://foo.bar/file.txt.gz
+    )
+ compression: str = None # compression type in fsspec. ex: "gzip"
+    extension: str = None  # extension of the filename to strip. ex: ".gz" to get file.txt from file.txt.gz
+
+ def __init__(
+ self, fo: str = "", target_protocol: Optional[str] = None, target_options: Optional[dict] = None, **kwargs
+ ):
+ """
+ The compressed file system can be instantiated from any compressed file.
+        It reads the contents of the compressed file as a filesystem with one file inside, as if it were an archive.
+
+        The single file inside the filesystem is named after the compressed file,
+        without the compression extension at the end of the filename.
+
+ Args:
+ fo (:obj:``str``): Path to compressed file. Will fetch file using ``fsspec.open()``
+ mode (:obj:``str``): Currently, only 'rb' accepted
+            target_protocol (:obj:``str``, optional): To override the FS protocol inferred from a URL.
+ target_options (:obj:``dict``, optional): Kwargs passed when instantiating the target FS.
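+
+        Example (a sketch; ``file.txt.gz`` stands for any local or remote gzip file)::
+
+            >>> import fsspec
+            >>> fs = fsspec.filesystem("gzip", fo="file.txt.gz")
+            >>> fs.ls("/")
+            ['file.txt']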
+ """
+ super().__init__(self, **kwargs)
+ # always open as "rb" since fsspec can then use the TextIOWrapper to make it work for "r" mode
+ self.file = fsspec.open(
+ fo, mode="rb", protocol=target_protocol, compression=self.compression, **(target_options or {})
+ )
+ self.info = self.file.fs.info(self.file.path)
+ self.compressed_name = os.path.basename(self.file.path.split("::")[0])
+ self.uncompressed_name = self.compressed_name[: self.compressed_name.rindex(".")]
+ self.dir_cache = None
+
+ @classmethod
+ def _strip_protocol(cls, path):
+ # compressed file paths are always relative to the archive root
+ return super()._strip_protocol(path).lstrip("/")
+
+ def _get_dirs(self):
+ if self.dir_cache is None:
+ f = {**self.info, "name": self.uncompressed_name}
+ self.dir_cache = {f["name"]: f}
+
+ def cat(self, path: str):
+ return self.file.open().read()
+
+ def _open(
+ self,
+ path: str,
+ mode: str = "rb",
+ block_size=None,
+ autocommit=True,
+ cache_options=None,
+ **kwargs,
+ ):
+ path = self._strip_protocol(path)
+ if mode != "rb":
+ raise ValueError(f"Tried to read with mode {mode} on file {self.file.path} opened with mode 'rb'")
+ if path != self.uncompressed_name:
+ raise FileNotFoundError(f"Expected file {self.uncompressed_name} but got {path}")
+ return self.file.open()
+
+
+class Bz2FileSystem(BaseCompressedFileFileSystem):
+ """Read contents of BZ2 file as a filesystem with one file inside."""
+
+ protocol = "bz2"
+ compression = "bz2"
+ extension = ".bz2"
+
+
+class GzipFileSystem(BaseCompressedFileFileSystem):
+ """Read contents of GZIP file as a filesystem with one file inside."""
+
+ protocol = "gzip"
+ compression = "gzip"
+ extension = ".gz"
+
+
+class Lz4FileSystem(BaseCompressedFileFileSystem):
+ """Read contents of LZ4 file as a filesystem with one file inside."""
+
+ protocol = "lz4"
+ compression = "lz4"
+ extension = ".lz4"
+
+
+class XzFileSystem(BaseCompressedFileFileSystem):
+ """Read contents of .xz (LZMA) file as a filesystem with one file inside."""
+
+ protocol = "xz"
+ compression = "xz"
+ extension = ".xz"
+
+
+class ZstdFileSystem(BaseCompressedFileFileSystem):
+ """
+ Read contents of zstd file as a filesystem with one file inside.
+
+ Note that reading in binary mode with fsspec isn't supported yet:
+ https://github.com/indygreg/python-zstandard/issues/136
+ """
+
+ protocol = "zstd"
+ compression = "zstd"
+ extension = ".zst"
+
+ def __init__(
+ self,
+ fo: str,
+ mode: str = "rb",
+ target_protocol: Optional[str] = None,
+ target_options: Optional[dict] = None,
+ block_size: int = DEFAULT_BLOCK_SIZE,
+ **kwargs,
+ ):
+ super().__init__(
+ fo=fo,
+ mode=mode,
+ target_protocol=target_protocol,
+ target_options=target_options,
+ block_size=block_size,
+ **kwargs,
+ )
+ # We need to wrap the zstd decompressor to avoid this error in fsspec==2021.7.0 and zstandard==0.15.2:
+ #
+ # File "/Users/user/.virtualenvs/hf-datasets/lib/python3.7/site-packages/fsspec/core.py", line 145, in open
+ # out.close = close
+ # AttributeError: 'zstd.ZstdDecompressionReader' object attribute 'close' is read-only
+ #
+ # see https://github.com/intake/filesystem_spec/issues/725
+ _enter = self.file.__enter__
+
+ class WrappedFile:
+ def __init__(self, file_):
+ self._file = file_
+
+ def __enter__(self):
+ self._file.__enter__()
+ return self
+
+ def __exit__(self, *args, **kwargs):
+ self._file.__exit__(*args, **kwargs)
+
+ def __iter__(self):
+ return iter(self._file)
+
+ def __next__(self):
+ return next(self._file)
+
+ def __getattr__(self, attr):
+ return getattr(self._file, attr)
+
+ def fixed_enter(*args, **kwargs):
+ return WrappedFile(_enter(*args, **kwargs))
+
+ self.file.__enter__ = fixed_enter
diff --git a/src/datasets/filesystems/compression/__init__.py b/src/datasets/filesystems/compression/__init__.py
deleted file mode 100644
index 1ea8c6ff2c3..00000000000
--- a/src/datasets/filesystems/compression/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import gzip # noqa: F401
diff --git a/src/datasets/filesystems/compression/gzip.py b/src/datasets/filesystems/compression/gzip.py
deleted file mode 100644
index 314d4133ad2..00000000000
--- a/src/datasets/filesystems/compression/gzip.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import os
-from typing import Optional
-
-import fsspec
-from fsspec.archive import AbstractArchiveFileSystem
-from fsspec.utils import DEFAULT_BLOCK_SIZE
-
-
-class GZipFileSystem(AbstractArchiveFileSystem):
- """Read contents of GZIP archive as a file-system with one file inside."""
-
- root_marker = ""
- protocol = "gzip"
-
- def __init__(
- self,
- fo: str = "",
- mode: str = "rb",
- target_protocol: Optional[str] = None,
- target_options: Optional[dict] = None,
- block_size: int = DEFAULT_BLOCK_SIZE,
- **kwargs,
- ):
- """
- The GZipFileSystem can be instantiated from any gzip file.
- It read the contents of GZip archive as a file-system with one file inside.
- The single file inside the filesystem is named after the Gzip file, without ".gz" at the end.
-
- Args:
- fo (:obj:``str``): Path to file containing GZIP. Will fetch file using ``fsspec.open()``
- mode (:obj:``str``): Currently, only 'rb' accepted
- target_protocol(:obj:``str``, optional): To override the FS protocol inferred from a URL.
- target_options (:obj:``dict``, optional): Kwargs passed when instantiating the target FS.
- """
- super().__init__(self, **kwargs)
- if mode != "rb":
- raise ValueError("Only read from gzip files accepted")
- self.gzip = fsspec.open(fo, mode=mode, protocol=target_protocol, compression="gzip", **(target_options or {}))
- self.info = self.gzip.fs.info(self.gzip.path)
- self.compressed_name = os.path.basename(self.gzip.path.split("::")[0]).rstrip(".gz")
- self.uncompressed_name = self.compressed_name.rstrip(".gz")
- self.block_size = block_size
- self.dir_cache = None
-
- @classmethod
- def _strip_protocol(cls, path):
- # gzip file paths are always relative to the archive root
- return super()._strip_protocol(path).lstrip("/")
-
- def _get_dirs(self):
- if self.dir_cache is None:
- f = {**self.info, "name": self.uncompressed_name}
- self.dir_cache = {f["name"]: f}
-
- def cat(self, path: str):
- return self.gzip.open().read()
-
- def _open(
- self,
- path: str,
- mode: str = "rb",
- block_size: Optional[int] = None,
- autocommit: bool = True,
- cache_options: Optional[dict] = None,
- **kwargs,
- ):
- path = self._strip_protocol(path)
- if path != self.uncompressed_name:
- raise FileNotFoundError(f"Expected file {self.uncompressed_name} but got {path}")
- return self.gzip.open()
diff --git a/src/datasets/utils/streaming_download_manager.py b/src/datasets/utils/streaming_download_manager.py
index 2768cba7997..1a684a647fc 100644
--- a/src/datasets/utils/streaming_download_manager.py
+++ b/src/datasets/utils/streaming_download_manager.py
@@ -8,6 +8,7 @@
from aiohttp.client_exceptions import ClientError
from .. import config
+from ..filesystems import COMPRESSION_FILESYSTEMS
from .download_manager import DownloadConfig, map_nested
from .file_utils import get_authentication_headers_for_url, is_local_path, is_relative_path, url_or_path_join
from .logging import get_logger
@@ -15,7 +16,17 @@
logger = get_logger(__name__)
BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]
-COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "gz", "lz4", "xz", "zip", "zst"]
+
+COMPRESSION_EXTENSION_TO_PROTOCOL = {
+ # single file compression
+ **{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS},
+ # archive compression
+ "zip": "zip",
+ "tar": "tar",
+ "tgz": "tar",
+}
+
+SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
def xjoin(a, *p):
@@ -34,19 +45,14 @@ def xjoin(a, *p):
Example::
- >>> xjoin("https://host.com/archive.zip", "folder1/file.txt")
+ >>> xjoin("zip://folder1::https://host.com/archive.zip", "file.txt")
zip://folder1/file.txt::https://host.com/archive.zip
"""
a, *b = a.split("::")
if is_local_path(a):
a = Path(a, *p).as_posix()
else:
- compression = fsspec.core.get_compression(a, "infer")
- if compression in ["zip"]:
- b = [a] + b
- a = posixpath.join(f"{compression}://", *p)
- else:
- a = posixpath.join(a, *p)
+ a = posixpath.join(a, *p)
return "::".join([a] + b)
@@ -69,22 +75,9 @@ def read_with_retries(*args, **kwargs):
return out
file_obj.read = read_with_retries
- return file_obj
-
-def _add_retries_to_fsspec_open_file(fsspec_open_file):
- open_ = fsspec_open_file.open
- def open_with_retries():
- file_obj = open_()
- _add_retries_to_file_obj_read_method(file_obj)
- return file_obj
-
- fsspec_open_file.open = open_with_retries
- return fsspec_open_file
-
-
-def xopen(file, mode="r", compression="infer", *args, **kwargs):
+def xopen(file, mode="r", *args, **kwargs):
"""
This function extends the builtin `open` function to support remote files using fsspec.
@@ -93,9 +86,9 @@ def xopen(file, mode="r", compression="infer", *args, **kwargs):
"""
if fsspec.get_fs_token_paths(file)[0].protocol == "https":
kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None))
- fsspec_open_file = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs)
- fsspec_open_file = _add_retries_to_fsspec_open_file(fsspec_open_file)
- return fsspec_open_file
+ file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
+ _add_retries_to_file_obj_read_method(file_obj)
+ return file_obj
class StreamingDownloadManager(object):
@@ -126,30 +119,42 @@ def download(self, url_or_urls):
url_or_urls = map_nested(self._download, url_or_urls, map_tuple=True)
return url_or_urls
- def _download(self, url_or_filename):
- if is_relative_path(url_or_filename):
+ def _download(self, urlpath: str) -> str:
+ if is_relative_path(urlpath):
# append the relative path to the base_path
- url_or_filename = url_or_path_join(self._base_path, url_or_filename)
- return url_or_filename
+ urlpath = url_or_path_join(self._base_path, urlpath)
+ return urlpath
def extract(self, path_or_paths):
urlpaths = map_nested(self._extract, path_or_paths, map_tuple=True)
return urlpaths
- def _extract(self, urlpath):
+ def _extract(self, urlpath: str) -> str:
protocol = self._get_extraction_protocol(urlpath)
if protocol is None:
# no extraction
return urlpath
+ elif protocol in SINGLE_FILE_COMPRESSION_PROTOCOLS:
+ # there is one single file which is the uncompressed file
+ inner_file = os.path.basename(urlpath.split("::")[0])
+ inner_file = inner_file[: inner_file.rindex(".")]
+ # check for tar.gz, tar.bz2 etc.
+ if inner_file.endswith(".tar"):
+ return f"tar://::{urlpath}"
+ else:
+ return f"{protocol}://{inner_file}::{urlpath}"
else:
- return f"{protocol}://*::{urlpath}"
+ return f"{protocol}://::{urlpath}"
- def _get_extraction_protocol(self, urlpath) -> Optional[str]:
+ def _get_extraction_protocol(self, urlpath: str) -> Optional[str]:
path = urlpath.split("::")[0]
- if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS:
+ extension = path.split(".")[-1]
+ if extension in BASE_KNOWN_EXTENSIONS:
return None
- elif path.endswith(".tar"):
- return "tar"
+ elif path.endswith(".tar.gz"):
+        pass  # .tar.gz is not supported: fall through to the NotImplementedError below
+ elif extension in COMPRESSION_EXTENSION_TO_PROTOCOL:
+ return COMPRESSION_EXTENSION_TO_PROTOCOL[extension]
raise NotImplementedError(f"Extraction protocol for file at {urlpath} is not implemented yet")
def download_and_extract(self, url_or_urls):
diff --git a/tests/conftest.py b/tests/conftest.py
index 052d039b91c..3221c2b3f09 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@
import pyarrow.parquet as pq
import pytest
+from datasets import config
from datasets.arrow_dataset import Dataset
from datasets.features import ClassLabel, Features, Sequence, Value
@@ -80,7 +81,7 @@ def text_file(tmp_path_factory):
@pytest.fixture(scope="session")
def xz_file(tmp_path_factory):
- filename = tmp_path_factory.mktemp("data") / "file.xz"
+ filename = tmp_path_factory.mktemp("data") / "file.txt.xz"
data = bytes(FILE_CONTENT, "utf-8")
with lzma.open(filename, "wb") as f:
f.write(data)
@@ -88,16 +89,51 @@ def xz_file(tmp_path_factory):
@pytest.fixture(scope="session")
-def gz_path(tmp_path_factory, text_path):
+def gz_file(tmp_path_factory):
import gzip
- path = str(tmp_path_factory.mktemp("data") / "file.gz")
+ path = str(tmp_path_factory.mktemp("data") / "file.txt.gz")
data = bytes(FILE_CONTENT, "utf-8")
with gzip.open(path, "wb") as f:
f.write(data)
return path
+@pytest.fixture(scope="session")
+def bz2_file(tmp_path_factory):
+ import bz2
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.bz2"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with bz2.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def zstd_file(tmp_path_factory):
+ if config.ZSTANDARD_AVAILABLE:
+ import zstandard as zstd
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.zst"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with zstd.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
+@pytest.fixture(scope="session")
+def lz4_file(tmp_path_factory):
+ if config.LZ4_AVAILABLE:
+ import lz4.frame
+
+ path = tmp_path_factory.mktemp("data") / "file.txt.lz4"
+ data = bytes(FILE_CONTENT, "utf-8")
+ with lz4.frame.open(path, "wb") as f:
+ f.write(data)
+ return path
+
+
@pytest.fixture(scope="session")
def xml_file(tmp_path_factory):
filename = tmp_path_factory.mktemp("data") / "file.xml"
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 909ddf83895..0967f98a670 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -1,25 +1,13 @@
import pytest
-import zstandard as zstd
from datasets.utils.extract import Extractor, ZstdExtractor
+from .utils import require_zstandard
-FILE_CONTENT = """\
- Text data.
- Second line of data."""
-
-@pytest.fixture(scope="session")
-def zstd_path(tmp_path_factory):
- path = tmp_path_factory.mktemp("data") / "file.zstd"
- data = bytes(FILE_CONTENT, "utf-8")
- with zstd.open(path, "wb") as f:
- f.write(data)
- return path
-
-
-def test_zstd_extractor(zstd_path, tmp_path, text_file):
- input_path = zstd_path
+@require_zstandard
+def test_zstd_extractor(zstd_file, tmp_path, text_file):
+ input_path = zstd_file
assert ZstdExtractor.is_extractable(input_path)
output_path = str(tmp_path / "extracted.txt")
ZstdExtractor.extract(input_path, output_path)
@@ -30,21 +18,16 @@ def test_zstd_extractor(zstd_path, tmp_path, text_file):
assert extracted_file_content == expected_file_content
-@pytest.mark.parametrize(
- "compression_format, expected_text_path_name", [("gzip", "text_path"), ("xz", "text_file"), ("zstd", "text_file")]
-)
-def test_extractor(
- compression_format, expected_text_path_name, text_gz_path, xz_file, zstd_path, tmp_path, text_file, text_path
-):
- input_paths = {"gzip": text_gz_path, "xz": xz_file, "zstd": zstd_path}
+@require_zstandard
+@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
+def test_extractor(compression_format, gz_file, xz_file, zstd_file, tmp_path, text_file):
+ input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file}
input_path = str(input_paths[compression_format])
output_path = str(tmp_path / "extracted.txt")
assert Extractor.is_extractable(input_path)
Extractor.extract(input_path, output_path)
with open(output_path) as f:
extracted_file_content = f.read()
- expected_text_paths = {"text_file": text_file, "text_path": text_path}
- expected_text_path = str(expected_text_paths[expected_text_path_name])
- with open(expected_text_path) as f:
+ with open(text_file) as f:
expected_file_content = f.read()
assert extracted_file_content == expected_file_content
diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py
index 283b8b5a3ed..d623e64e8bb 100644
--- a/tests/test_file_utils.py
+++ b/tests/test_file_utils.py
@@ -88,8 +88,8 @@ def gen_random_output():
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
-def test_cached_path_extract(compression_format, gz_path, xz_file, zstd_path, tmp_path, text_file):
- input_paths = {"gzip": gz_path, "xz": xz_file, "zstd": zstd_path}
+def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_file, tmp_path, text_file):
+    input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file}
input_path = str(input_paths[compression_format])
cache_dir = tmp_path / "cache"
download_config = DownloadConfig(cache_dir=cache_dir, extract_compressed_file=True)
diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py
index 50c9e6fbe13..a71da75f516 100644
--- a/tests/test_filesystem.py
+++ b/tests/test_filesystem.py
@@ -5,8 +5,9 @@
import pytest
from moto import mock_s3
-from datasets.filesystems import S3FileSystem, extract_path_from_uri, is_remote_filesystem
-from datasets.filesystems.compression.gzip import GZipFileSystem
+from datasets.filesystems import COMPRESSION_FILESYSTEMS, S3FileSystem, extract_path_from_uri, is_remote_filesystem
+
+from .utils import require_lz4, require_zstandard
@pytest.fixture(scope="function")
@@ -53,10 +54,16 @@ def test_is_remote_filesystem():
assert is_remote is False
-def test_gzip_filesystem(text_gz_path, text_path):
- fs = fsspec.filesystem("gzip", fo=text_gz_path)
- assert isinstance(fs, GZipFileSystem)
- expected_filename = os.path.basename(text_gz_path).rstrip(".gz")
+@require_zstandard
+@require_lz4
+@pytest.mark.parametrize("compression_fs_class", COMPRESSION_FILESYSTEMS)
+def test_compression_filesystems(compression_fs_class, gz_file, bz2_file, lz4_file, zstd_file, xz_file, text_file):
+ input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file}
+ input_path = str(input_paths[compression_fs_class.protocol])
+ fs = fsspec.filesystem(compression_fs_class.protocol, fo=input_path)
+ assert isinstance(fs, compression_fs_class)
+ expected_filename = os.path.basename(input_path)
+ expected_filename = expected_filename[: expected_filename.rindex(".")]
assert fs.ls("/") == [expected_filename]
- with fs.open(expected_filename, "r", encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file:
+ with fs.open(expected_filename, "r", encoding="utf-8") as f, open(text_file, encoding="utf-8") as expected_file:
assert f.read() == expected_file.read()
diff --git a/tests/test_load.py b/tests/test_load.py
index 701ed70be87..c61dd494cb2 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -254,6 +254,9 @@ def test_load_dataset_streaming_gz_json(jsonl_gz_path):
def test_load_dataset_streaming_compressed_files(path):
repo_id = "albertvillanova/datasets-tests-compression"
data_files = f"https://huggingface.co/datasets/{repo_id}/resolve/main/{path}"
+ if data_files[-3:] in ("zip", "tar"): # we need to glob "*" inside archives
+ data_files = data_files[-3:] + "://*::" + data_files
+        return  # TODO(QL, albert): re-add support for streaming ZIP and TAR archives
ds = load_dataset("json", split="train", data_files=data_files, streaming=True)
assert isinstance(ds, IterableDataset)
ds_item = next(iter(ds))
diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py
index 2e46e712aa4..12d7e667850 100644
--- a/tests/test_streaming_download_manager.py
+++ b/tests/test_streaming_download_manager.py
@@ -1,6 +1,10 @@
+import os
+
import pytest
-from .utils import require_streaming
+from datasets.filesystems import COMPRESSION_FILESYSTEMS
+
+from .utils import require_lz4, require_streaming, require_zstandard
TEST_URL = "https://huggingface.co/datasets/lhoestq/test/raw/main/some_text.txt"
@@ -11,7 +15,12 @@
@pytest.mark.parametrize(
"input_path, paths_to_join, expected_path",
[
- ("https://host.com/archive.zip", ("file.txt",), "zip://file.txt::https://host.com/archive.zip"),
+ ("https://host.com/archive.zip", ("file.txt",), "https://host.com/archive.zip/file.txt"),
+ (
+ "zip://::https://host.com/archive.zip",
+ ("file.txt",),
+ "zip://file.txt::https://host.com/archive.zip",
+ ),
(
"zip://folder::https://host.com/archive.zip",
("file.txt",),
@@ -72,14 +81,16 @@ def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath):
@require_streaming
-def test_streaming_dl_manager_extract(text_gz_path):
+def test_streaming_dl_manager_extract(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen
dl_manager = StreamingDownloadManager()
output_path = dl_manager.extract(text_gz_path)
- assert output_path == text_gz_path
- fsspec_open_file = xopen(output_path)
- assert fsspec_open_file.compression == "gzip"
+    path = os.path.basename(text_gz_path)[: -len(".gz")]
+ assert output_path == f"gzip://{path}::{text_gz_path}"
+ fsspec_open_file = xopen(output_path, encoding="utf-8")
+ with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file:
+ assert f.read() == expected_file.read()
@require_streaming
@@ -88,9 +99,9 @@ def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path,
dl_manager = StreamingDownloadManager()
output_path = dl_manager.download_and_extract(text_gz_path)
- assert output_path == text_gz_path
+    path = os.path.basename(text_gz_path)[: -len(".gz")]
+ assert output_path == f"gzip://{path}::{text_gz_path}"
fsspec_open_file = xopen(output_path, encoding="utf-8")
- assert output_path == text_gz_path
with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file:
assert f.read() == expected_file.read()
@@ -107,3 +118,24 @@ def test_streaming_dl_manager_download_and_extract_with_join(input_path, filenam
extracted_path = dl_manager.download_and_extract(input_path)
output_path = xjoin(extracted_path, filename)
assert output_path == expected_path
+
+
+@require_streaming
+@require_zstandard
+@require_lz4
+@pytest.mark.parametrize("compression_fs_class", COMPRESSION_FILESYSTEMS)
+def test_streaming_dl_manager_extract_all_supported_single_file_compression_types(
+ compression_fs_class, gz_file, xz_file, zstd_file, bz2_file, lz4_file, text_file
+):
+ from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen
+
+ input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file}
+ input_path = str(input_paths[compression_fs_class.protocol])
+ dl_manager = StreamingDownloadManager()
+ output_path = dl_manager.extract(input_path)
+ path = os.path.basename(input_path)
+ path = path[: path.rindex(".")]
+ assert output_path == f"{compression_fs_class.protocol}://{path}::{input_path}"
+ fsspec_open_file = xopen(output_path, encoding="utf-8")
+ with fsspec_open_file as f, open(text_file, encoding="utf-8") as expected_file:
+ assert f.read() == expected_file.read()
diff --git a/tests/utils.py b/tests/utils.py
index 010f6a2f93c..ee26201e986 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -137,6 +137,30 @@ def require_jax(test_case):
return test_case
+def require_zstandard(test_case):
+ """
+ Decorator marking a test that requires zstandard.
+
+ These tests are skipped when zstandard isn't installed.
+
+ """
+ if not config.ZSTANDARD_AVAILABLE:
+ test_case = unittest.skip("test requires zstandard")(test_case)
+ return test_case
+
+
+def require_lz4(test_case):
+ """
+ Decorator marking a test that requires lz4.
+
+ These tests are skipped when lz4 isn't installed.
+
+ """
+ if not config.LZ4_AVAILABLE:
+ test_case = unittest.skip("test requires lz4")(test_case)
+ return test_case
+
+
def require_transformers(test_case):
"""
Decorator marking a test that requires transformers.