Skip to content
27 changes: 21 additions & 6 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import importlib
from functools import partial
from typing import Optional, Union
from unittest.mock import patch

from .utils.logging import get_logger
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import xjoin, xopen
from .utils.streaming_download_manager import xjoin, xopen, xpathjoin, xpathopen


logger = get_logger(__name__)


def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None):
"""
Extend the `open` and `os.path.join` functions of the module to support data streaming.
They rare replaced by `xopen` and `xjoin` defined to work with the StreamingDownloadManager.
"""Extend the module to support streaming.

We patch some functions in the module to use `fsspec` to support data streaming:
- We use `fsspec.open` to open and read remote files. We patch the module function:
- `open`
- We use the "::" hop separator to join paths and navigate remote compressed/archive files. We patch the module
functions:
- `os.path.join`
- `pathlib.Path.joinpath` and `pathlib.Path.__truediv__` (called when using the "/" operator)

The patched functions are replaced with custom functions defined to work with the
:class:`~utils.streaming_download_manager.StreamingDownloadManager`.

We use fsspec to extend `open` to be able to read remote files.
To join paths and navigate in remote compressed archives, we use the "::" separator.
Args:
module_path: Path to the module to be extended.
use_auth_token: Whether to use authentication token.
"""

module = importlib.import_module(module_path)
Expand All @@ -27,3 +38,7 @@ def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str,
patch_submodule(module, "open", xopen).start()
# allow to navigate in remote zip files
patch_submodule(module, "os.path.join", xjoin).start()
if hasattr(module, "Path"):
patch.object(module.Path, "joinpath", xpathjoin).start()
patch.object(module.Path, "__truediv__", xpathjoin).start()
patch.object(module.Path, "open", xpathopen).start()
12 changes: 12 additions & 0 deletions src/datasets/utils/streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def xjoin(a, *p):
return "::".join([a] + b)


def _as_posix(path):
return path.as_posix().replace(":/", "://")


def xpathjoin(a, *p):
return type(a)(xjoin(_as_posix(a), *p))


def _add_retries_to_file_obj_read_method(file_obj):
read = file_obj.read
max_retries = config.STREAMING_READ_MAX_RETRIES
Expand Down Expand Up @@ -110,6 +118,10 @@ def xopen(file, mode="r", *args, **kwargs):
return file_obj


def xpathopen(path, **kwargs):
return xopen(_as_posix(path), **kwargs)


class StreamingDownloadManager(object):
"""
Download manager that uses the "::" separator to navigate through (possibly remote) compressed archives.
Expand Down
39 changes: 15 additions & 24 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import os
from pathlib import Path

import pytest

from datasets.filesystems import COMPRESSION_FILESYSTEMS
from datasets.utils.streaming_download_manager import xopen
from datasets.utils.streaming_download_manager import (
StreamingDownloadManager,
_get_extraction_protocol,
xjoin,
xopen,
xpathjoin,
xpathopen,
)

from .utils import require_lz4, require_zstandard

Expand All @@ -29,36 +37,33 @@
],
)
def test_xjoin(input_path, paths_to_join, expected_path):
from datasets.utils.streaming_download_manager import xjoin

output_path = xjoin(input_path, *paths_to_join)
assert output_path == expected_path
output_path = xpathjoin(Path(input_path), *paths_to_join)
assert output_path == Path(expected_path)


def test_xopen_local(text_path):

with xopen(text_path, encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file:
assert list(f) == list(expected_file)
with xpathopen(Path(text_path), encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file:
assert list(f) == list(expected_file)


def test_xopen_remote():
from datasets.utils.streaming_download_manager import xopen

with xopen(TEST_URL, encoding="utf-8") as f:
assert list(f) == TEST_URL_CONTENT.splitlines(keepends=True)
with xpathopen(Path(TEST_URL), encoding="utf-8") as f:
assert list(f) == TEST_URL_CONTENT.splitlines(keepends=True)


@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_dummy_path(urlpath):
from datasets.utils.streaming_download_manager import StreamingDownloadManager

dl_manager = StreamingDownloadManager()
assert dl_manager.download(urlpath) == urlpath


def test_streaming_dl_manager_download(text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
out = dl_manager.download(text_path)
assert out == text_path
Expand All @@ -68,15 +73,11 @@ def test_streaming_dl_manager_download(text_path):

@pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"])
def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath):
from datasets.utils.streaming_download_manager import StreamingDownloadManager

dl_manager = StreamingDownloadManager()
assert dl_manager.download_and_extract(urlpath) == urlpath


def test_streaming_dl_manager_extract(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
output_path = dl_manager.extract(text_gz_path)
path = os.path.basename(text_gz_path).rstrip(".gz")
Expand All @@ -87,8 +88,6 @@ def test_streaming_dl_manager_extract(text_gz_path, text_path):


def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
output_path = dl_manager.download_and_extract(text_gz_path)
path = os.path.basename(text_gz_path).rstrip(".gz")
Expand All @@ -103,8 +102,6 @@ def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path,
[("https://domain.org/archive.zip", "filename.jsonl", "zip://filename.jsonl::https://domain.org/archive.zip")],
)
def test_streaming_dl_manager_download_and_extract_with_join(input_path, filename, expected_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xjoin

dl_manager = StreamingDownloadManager()
extracted_path = dl_manager.download_and_extract(input_path)
output_path = xjoin(extracted_path, filename)
Expand All @@ -117,8 +114,6 @@ def test_streaming_dl_manager_download_and_extract_with_join(input_path, filenam
def test_streaming_dl_manager_extract_all_supported_single_file_compression_types(
compression_fs_class, gz_file, xz_file, zstd_file, bz2_file, lz4_file, text_file
):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_file, "bz2": bz2_file, "lz4": lz4_file}
input_path = str(input_paths[compression_fs_class.protocol])
dl_manager = StreamingDownloadManager()
Expand All @@ -140,8 +135,6 @@ def test_streaming_dl_manager_extract_all_supported_single_file_compression_type
],
)
def test_streaming_dl_manager_get_extraction_protocol(urlpath, expected_protocol):
from datasets.utils.streaming_download_manager import _get_extraction_protocol

assert _get_extraction_protocol(urlpath) == expected_protocol


Expand All @@ -155,6 +148,4 @@ def test_streaming_dl_manager_get_extraction_protocol(urlpath, expected_protocol
)
@pytest.mark.xfail(raises=NotImplementedError)
def test_streaming_dl_manager_get_extraction_protocol_throws(urlpath):
from datasets.utils.streaming_download_manager import _get_extraction_protocol

_get_extraction_protocol(urlpath)