Skip to content
27 changes: 21 additions & 6 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import importlib
from functools import partial
from typing import Optional, Union
from unittest.mock import patch

from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import xjoin, xopen
from .utils.streaming_download_manager import xjoin, xopen, xpathjoin, xpathopen


logger = get_logger(__name__)


def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None):
"""
Extend the `open` and `os.path.join` functions of the module to support data streaming.
They are replaced by `xopen` and `xjoin`, which are defined to work with the StreamingDownloadManager.
"""Extend the module to support streaming.

We patch some functions in the module to use `fsspec` to support data streaming:
- We use `fsspec.open` to open and read remote files. We patch the module function:
- `open`
- We use the "::" hop separator to join paths and navigate remote compressed/archive files. We patch the module
functions:
- `os.path.join`
- `pathlib.Path.joinpath` and `pathlib.Path.__truediv__` (called when using the "/" operator)

The patched functions are replaced with custom functions defined to work with the
:class:`~utils.streaming_download_manager.StreamingDownloadManager`.

We use fsspec to extend `open` to be able to read remote files.
To join paths and navigate in remote compressed archives, we use the "::" separator.
Args:
module_path: Path to the module to be extended.
use_auth_token: Whether to use authentication token.
"""

module = importlib.import_module(module_path)
Expand All @@ -27,3 +38,7 @@ def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str,
patch_submodule(module, "open", xopen).start()
# allow to navigate in remote zip files
patch_submodule(module, "os.path.join", xjoin).start()
if hasattr(module, "Path"):
patch.object(module.Path, "joinpath", xpathjoin).start()
patch.object(module.Path, "__truediv__", xpathjoin).start()
patch.object(module.Path, "open", xpathopen).start()
45 changes: 42 additions & 3 deletions src/datasets/utils/streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import re
import time
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple

import fsspec
import posixpath
Expand All @@ -15,8 +16,8 @@


logger = get_logger(__name__)
BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]

BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]
COMPRESSION_EXTENSION_TO_PROTOCOL = {
# single file compression
**{fs_class.extension.lstrip("."): fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS},
Expand All @@ -25,8 +26,8 @@
"tar": "tar",
"tgz": "tar",
}

SINGLE_FILE_COMPRESSION_PROTOCOLS = {fs_class.protocol for fs_class in COMPRESSION_FILESYSTEMS}
SINGLE_SLASH_AFTER_PROTOCOL_PATTERN = re.compile(r"(?<!:):/")


def xjoin(a, *p):
Expand Down Expand Up @@ -56,6 +57,31 @@ def xjoin(a, *p):
return "::".join([a] + b)


def _as_posix(path: Path):
    """Render a path as a POSIX string with the slash after the protocol restored.

    The output of :meth:`pathlib.PurePath.as_posix` is post-processed: every
    match of ``SINGLE_SLASH_AFTER_PROTOCOL_PATTERN`` is replaced with ``"://"``.

    Args:
        path (:obj:`~pathlib.Path`): Calling Path instance.

    Returns:
        :obj:`str`
    """
    posix_path = path.as_posix()
    return SINGLE_SLASH_AFTER_PROTOCOL_PATTERN.sub("://", posix_path)


def xpathjoin(a: Path, *p: str):
    """Extend :func:`xjoin` to support argument of type :obj:`~pathlib.Path`.

    Args:
        a (:obj:`~pathlib.Path`): Calling Path instance.
        *p (:obj:`str`): Other path components, forwarded to :func:`xjoin`.

    Returns:
        :obj:`~pathlib.Path`: New instance of the same class as ``a``, built
        from the joined path.
    """
    # Join on the string form of `a` as produced by `_as_posix`.
    joined = xjoin(_as_posix(a), *p)
    # Rebuild with type(a) so the caller gets back the same class it passed in.
    return type(a)(joined)


def _add_retries_to_file_obj_read_method(file_obj):
read = file_obj.read
max_retries = config.STREAMING_READ_MAX_RETRIES
Expand Down Expand Up @@ -110,6 +136,19 @@ def xopen(file, mode="r", *args, **kwargs):
return file_obj


def xpathopen(path: Path, *args, **kwargs):
    """Extend :func:`xopen` to support argument of type :obj:`~pathlib.Path`.

    This function is patched in place of ``Path.open``, so it must accept the
    same call shapes: ``path.open("rb")`` passes the mode positionally.

    Args:
        path (:obj:`~pathlib.Path`): Calling Path instance.
        *args: Positional arguments (e.g. ``mode``) forwarded to :func:`xopen`.
        **kwargs: Keyword arguments forwarded to :func:`xopen`.

    Returns:
        File-like object, as returned by :func:`xopen`.
    """
    return xopen(_as_posix(path), *args, **kwargs)


class StreamingDownloadManager(object):
"""
Download manager that uses the "::" separator to navigate through (possibly remote) compressed archives.
Expand Down
58 changes: 34 additions & 24 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import os
from pathlib import Path

import pytest

from datasets.filesystems import COMPRESSION_FILESYSTEMS
from datasets.utils.streaming_download_manager import xopen
from datasets.utils.streaming_download_manager import (
StreamingDownloadManager,
_as_posix,
_get_extraction_protocol,
xjoin,
xopen,
xpathjoin,
xpathopen,
)

from .utils import require_lz4, require_zstandard

Expand All @@ -12,6 +21,14 @@
TEST_URL_CONTENT = "foo\nbar\nfoobar"


@pytest.mark.parametrize(
    "input_path, expected_path",
    [("zip:/test.txt::/Users/username/bar.zip", "zip://test.txt::/Users/username/bar.zip")],
)
def test_as_posix(input_path, expected_path):
    """Check that `_as_posix` turns the single slash after the protocol into "://"."""
    posix_path = _as_posix(Path(input_path))
    assert posix_path == expected_path


@pytest.mark.parametrize(
"input_path, paths_to_join, expected_path",
[
Expand All @@ -26,39 +43,46 @@
("file.txt",),
"zip://folder/file.txt::https://host.com/archive.zip",
),
(
".",
("file.txt",),
"file.txt",
),
(
Path().resolve().as_posix(),
("file.txt",),
(Path().resolve() / "file.txt").as_posix(),
),
],
)
def test_xjoin(input_path, paths_to_join, expected_path):
from datasets.utils.streaming_download_manager import xjoin

output_path = xjoin(input_path, *paths_to_join)
assert output_path == expected_path
output_path = xpathjoin(Path(input_path), *paths_to_join)
assert output_path == Path(expected_path)


def test_xopen_local(text_path):
    """Check that `xopen` and `xpathopen` yield the same lines as the builtin `open` on a local file."""
    with open(text_path, encoding="utf-8") as expected_file:
        expected_lines = list(expected_file)

    # String path through xopen.
    with xopen(text_path, encoding="utf-8") as streamed:
        assert list(streamed) == expected_lines

    # Same file through the pathlib variant.
    with xpathopen(Path(text_path), encoding="utf-8") as streamed:
        assert list(streamed) == expected_lines


def test_xopen_remote():
    """Check that `xopen` and `xpathopen` read TEST_URL and yield TEST_URL_CONTENT line by line."""
    # `xopen` and `xpathopen` come from the module-level import; no local import needed.
    expected_lines = TEST_URL_CONTENT.splitlines(keepends=True)
    with xopen(TEST_URL, encoding="utf-8") as f:
        assert list(f) == expected_lines
    with xpathopen(Path(TEST_URL), encoding="utf-8") as f:
        assert list(f) == expected_lines


@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_dummy_path(urlpath):
    """Check that `download` returns the given path or URL unchanged."""
    # StreamingDownloadManager comes from the module-level import.
    dl_manager = StreamingDownloadManager()
    assert dl_manager.download(urlpath) == urlpath


def test_streaming_dl_manager_download(text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
out = dl_manager.download(text_path)
assert out == text_path
Expand All @@ -68,15 +92,11 @@ def test_streaming_dl_manager_download(text_path):

@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath):
    """Check that `download_and_extract` returns these paths and URLs unchanged."""
    # StreamingDownloadManager comes from the module-level import.
    dl_manager = StreamingDownloadManager()
    assert dl_manager.download_and_extract(urlpath) == urlpath


def test_streaming_dl_manager_extract(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
output_path = dl_manager.extract(text_gz_path)
path = os.path.basename(text_gz_path).rstrip(".gz")
Expand All @@ -87,8 +107,6 @@ def test_streaming_dl_manager_extract(text_gz_path, text_path):


def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
output_path = dl_manager.download_and_extract(text_gz_path)
path = os.path.basename(text_gz_path).rstrip(".gz")
Expand All @@ -103,8 +121,6 @@ def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path,
[("https://domain.org/archive.zip", "filename.jsonl", "zip://filename.jsonl::https://domain.org/archive.zip")],
)
def test_streaming_dl_manager_download_and_extract_with_join(input_path, filename, expected_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xjoin

dl_manager = StreamingDownloadManager()
extracted_path = dl_manager.download_and_extract(input_path)
output_path = xjoin(extracted_path, filename)
Expand All @@ -117,8 +133,6 @@ def test_streaming_dl_manager_download_and_extract_with_join(input_path, filenam
def test_streaming_dl_manager_extract_all_supported_single_file_compression_types(
compression_fs_class, gz_file, xz_file, zstd_file, bz2_file, lz4_file, text_file
):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file}
input_path = str(input_paths[compression_fs_class.protocol])
dl_manager = StreamingDownloadManager()
Expand All @@ -140,8 +154,6 @@ def test_streaming_dl_manager_extract_all_supported_single_file_compression_type
],
)
def test_streaming_dl_manager_get_extraction_protocol(urlpath, expected_protocol):
from datasets.utils.streaming_download_manager import _get_extraction_protocol

assert _get_extraction_protocol(urlpath) == expected_protocol


Expand All @@ -155,6 +167,4 @@ def test_streaming_dl_manager_get_extraction_protocol(urlpath, expected_protocol
)
@pytest.mark.xfail(raises=NotImplementedError)
def test_streaming_dl_manager_get_extraction_protocol_throws(urlpath):
from datasets.utils.streaming_download_manager import _get_extraction_protocol

_get_extraction_protocol(urlpath)