Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import (
xdirname,
xglob,
xjoin,
xopen,
xpandas_read_csv,
Expand Down Expand Up @@ -47,6 +48,7 @@ def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str,
patch_submodule(module, "open", partial(xopen, use_auth_token=use_auth_token)).start()
else:
patch_submodule(module, "open", xopen).start()
patch_submodule(module, "glob.glob", xglob).start()
# allow to navigate in remote zip files
patch_submodule(module, "os.path.join", xjoin).start()
patch_submodule(module, "os.path.dirname", xdirname).start()
Expand Down
25 changes: 25 additions & 0 deletions src/datasets/utils/streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
import os
import re
import time
Expand Down Expand Up @@ -180,6 +181,30 @@ def xpathopen(path: Path, *args, **kwargs):
return xopen(_as_posix(path), *args, **kwargs)


def xglob(urlpath, *, recursive=False):
    """Extend `glob.glob` function to support remote files.

    Local paths are delegated to :func:`glob.glob`. Any other path is globbed with the
    fsspec filesystem resolved from ``urlpath``; chained hops (the parts after ``"::"``)
    are re-appended unchanged to every match.

    Args:
        urlpath (:obj:`str`): URL path with shell-style wildcard patterns.
        recursive (:obj:`bool`, default `False`): Whether to match the "**" pattern recursively to zero or more
            directories or subdirectories. Only forwarded to `glob.glob` for local paths; for remote paths the
            pattern is handed to `fs.glob` as is.

    Returns:
        :obj:`list` of :obj:`str`
    """
    main_hop, *rest_hops = urlpath.split("::")
    if is_local_path(main_hop):
        return glob.glob(main_hop, recursive=recursive)

    fs, *_ = fsspec.get_fs_token_paths(urlpath)
    # - If there's no "*" in the pattern, get_fs_token_paths() doesn't do any pattern matching
    #   so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
    # - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
    # - If there is "**" in the pattern, `fs.glob` must be called anyway.
    globbed_paths = fs.glob(main_hop)

    # A filesystem may declare several protocol aliases as a tuple (e.g. ("http", "https")).
    # Formatting the tuple directly would yield an invalid URL, so pick a single scheme:
    # the one the caller used when the filesystem knows it, otherwise the first declared one.
    protocols = (fs.protocol,) if isinstance(fs.protocol, str) else tuple(fs.protocol)
    requested_scheme = main_hop.split("://", 1)[0] if "://" in main_hop else ""
    protocol = requested_scheme if requested_scheme in protocols else protocols[0]

    return ["::".join([f"{protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]


def xpathglob(path, pattern):
"""Glob function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs.

Expand Down
64 changes: 49 additions & 15 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import re
from pathlib import Path
from unittest.mock import patch

import pytest
from fsspec.spec import AbstractBufferedFile, AbstractFileSystem
Expand All @@ -12,6 +11,7 @@
StreamingDownloadManager,
_as_posix,
_get_extraction_protocol,
xglob,
xjoin,
xopen,
xpathglob,
Expand Down Expand Up @@ -117,6 +117,13 @@ def _open(
)


@pytest.fixture
def mock_fsspec(monkeypatch):
    """Expose ``DummyTestFS`` under the ``mock`` protocol for the duration of a test.

    A copy of the fsspec registry target seen by the streaming download manager is
    extended with the dummy filesystem and swapped in through ``monkeypatch``.
    """
    registry_attr = "datasets.utils.streaming_download_manager.fsspec.registry.target"
    patched_registry = datasets.utils.streaming_download_manager.fsspec.registry.target.copy()
    patched_registry.update({"mock": DummyTestFS})
    monkeypatch.setattr(registry_attr, patched_registry)


def _readd_double_slash_removed_by_path(path_as_posix: str) -> str:
"""Path(...) on an url path like zip://file.txt::http://host.com/data.zip
converts the :// to :/
Expand Down Expand Up @@ -219,6 +226,41 @@ def test_xopen_remote():
assert list(f) == TEST_URL_CONTENT.splitlines(keepends=True)


@pytest.mark.parametrize(
    "input_path, expected_paths",
    [
        ("tmp_path/*.txt", ["file1.txt", "file2.txt"]),
        ("mock://*", ["mock://glob_test", "mock://misc", "mock://top_level"]),
        ("mock://top_*", ["mock://top_level"]),
        (
            "mock://top_level/second_level/date=2019-10-0[1-4]",
            [
                "mock://top_level/second_level/date=2019-10-01",
                "mock://top_level/second_level/date=2019-10-02",
                "mock://top_level/second_level/date=2019-10-04",
            ],
        ),
        (
            "mock://top_level/second_level/date=2019-10-0[1-4]/*",
            [
                "mock://top_level/second_level/date=2019-10-01/a.parquet",
                "mock://top_level/second_level/date=2019-10-01/b.parquet",
                "mock://top_level/second_level/date=2019-10-02/a.parquet",
                "mock://top_level/second_level/date=2019-10-04/a.parquet",
            ],
        ),
    ],
)
def test_xglob(input_path, expected_paths, tmp_path, mock_fsspec):
    """Check that `xglob` expands wildcard patterns for local paths and for the mock filesystem."""
    uses_local_dir = input_path.startswith("tmp_path")
    if uses_local_dir:
        # Populate the temporary directory, then substitute the "tmp_path" placeholder
        # in both the pattern and the expected results.
        for filename in ("file1.txt", "file2.txt", "README.md"):
            (tmp_path / filename).touch()
        input_path = input_path.replace("/", os.sep).replace("tmp_path", str(tmp_path))
        expected_paths = [str(tmp_path / filename) for filename in expected_paths]
    assert sorted(xglob(input_path)) == expected_paths


@pytest.mark.parametrize(
"input_path, pattern, expected_paths",
[
Expand Down Expand Up @@ -246,20 +288,16 @@ def test_xopen_remote():
),
],
)
def test_xpathglob(input_path, pattern, expected_paths, tmp_path):
def test_xpathglob(input_path, pattern, expected_paths, tmp_path, mock_fsspec):
if input_path == "tmp_path":
input_path = tmp_path
expected_paths = [tmp_path / file for file in expected_paths]
for file in ["file1.txt", "file2.txt", "README.md"]:
(tmp_path / file).touch()
output_path = sorted(xpathglob(input_path, pattern))
else:
dummy_registry = datasets.utils.streaming_download_manager.fsspec.registry.target.copy()
dummy_registry["mock"] = DummyTestFS
expected_paths = [Path(file) for file in expected_paths]
with patch.dict(datasets.utils.streaming_download_manager.fsspec.registry.target, dummy_registry):
output_path = sorted(xpathglob(Path(input_path), pattern))
assert output_path == expected_paths
output_paths = sorted(xpathglob(Path(input_path), pattern))
assert output_paths == expected_paths


@pytest.mark.parametrize(
Expand Down Expand Up @@ -306,22 +344,18 @@ def test_xpathglob(input_path, pattern, expected_paths, tmp_path):
),
],
)
def test_xpathrglob(input_path, pattern, expected_paths, tmp_path):
def test_xpathrglob(input_path, pattern, expected_paths, tmp_path, mock_fsspec):
if input_path == "tmp_path":
input_path = tmp_path
dir_path = tmp_path / "dir"
dir_path.mkdir()
expected_paths = [dir_path / file for file in expected_paths]
for file in ["file1.txt", "file2.txt", "README.md"]:
(dir_path / file).touch()
output_path = sorted(xpathrglob(input_path, pattern))
else:
dummy_registry = datasets.utils.streaming_download_manager.fsspec.registry.target.copy()
dummy_registry["mock"] = DummyTestFS
expected_paths = [Path(file) for file in expected_paths]
with patch.dict(datasets.utils.streaming_download_manager.fsspec.registry.target, dummy_registry):
output_path = sorted(xpathrglob(Path(input_path), pattern))
assert output_path == expected_paths
output_paths = sorted(xpathrglob(Path(input_path), pattern))
assert output_paths == expected_paths


@pytest.mark.parametrize(
Expand Down