Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/datasets/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .utils.patching import patch_submodule
from .utils.streaming_download_manager import (
xdirname,
xglob,
xjoin,
xopen,
xpandas_read_csv,
Expand Down Expand Up @@ -47,6 +48,7 @@ def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str,
patch_submodule(module, "open", partial(xopen, use_auth_token=use_auth_token)).start()
else:
patch_submodule(module, "open", xopen).start()
patch_submodule(module, "glob.glob", xglob).start()
# allow to navigate in remote zip files
patch_submodule(module, "os.path.join", xjoin).start()
patch_submodule(module, "os.path.dirname", xdirname).start()
Expand Down
23 changes: 23 additions & 0 deletions src/datasets/utils/streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
import os
import re
import time
Expand Down Expand Up @@ -180,6 +181,28 @@ def xpathopen(path: Path, *args, **kwargs):
return xopen(_as_posix(path), *args, **kwargs)


def xglob(urlpath):
"""Extend `glob.glob` function to support remote files.

Args:
urlpath (:obj:`str`): URL path with shell-style wildcard patterns.

Returns:
:obj:`list` of :obj:`str`
"""
main_hop, *rest_hops = urlpath.split("::")
if is_local_path(main_hop):
return glob.glob(main_hop)
else:
fs, *_ = fsspec.get_fs_token_paths(urlpath)
# - If there's no "*" in the pattern, get_fs_token_paths() doesn't do any pattern matching
# so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
# - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
# - If there is "**" in the pattern, `fs.glob` must be called anyway.
globbed_paths = fs.glob(main_hop)
return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]


def xpathglob(path, pattern):
"""Glob function for argument of type :obj:`~pathlib.Path` that supports both local paths end remote URLs.

Expand Down
41 changes: 41 additions & 0 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StreamingDownloadManager,
_as_posix,
_get_extraction_protocol,
xglob,
xjoin,
xopen,
xpathglob,
Expand Down Expand Up @@ -219,6 +220,46 @@ def test_xopen_remote():
assert list(f) == TEST_URL_CONTENT.splitlines(keepends=True)


@pytest.mark.parametrize(
"input_path, expected_paths",
[
("tmp_path/*.txt", ["file1.txt", "file2.txt"]),
("mock://*", ["mock://glob_test", "mock://misc", "mock://top_level"]),
("mock://top_*", ["mock://top_level"]),
(
"mock://top_level/second_level/date=2019-10-0[1-4]",
[
"mock://top_level/second_level/date=2019-10-01",
"mock://top_level/second_level/date=2019-10-02",
"mock://top_level/second_level/date=2019-10-04",
],
),
(
"mock://top_level/second_level/date=2019-10-0[1-4]/*",
[
"mock://top_level/second_level/date=2019-10-01/a.parquet",
"mock://top_level/second_level/date=2019-10-01/b.parquet",
"mock://top_level/second_level/date=2019-10-02/a.parquet",
"mock://top_level/second_level/date=2019-10-04/a.parquet",
],
),
],
)
def test_xglob(input_path, expected_paths, tmp_path):
if input_path.startswith("tmp_path"):
input_path = input_path.replace("/", os.sep).replace("tmp_path", str(tmp_path))
expected_paths = [str(tmp_path / file) for file in expected_paths]
for file in ["file1.txt", "file2.txt", "README.md"]:
(tmp_path / file).touch()
output_path = sorted(xglob(input_path))
else:
dummy_registry = datasets.utils.streaming_download_manager.fsspec.registry.target.copy()
dummy_registry["mock"] = DummyTestFS
with patch.dict(datasets.utils.streaming_download_manager.fsspec.registry.target, dummy_registry):
output_path = sorted(xglob(input_path))
assert output_path == expected_paths


@pytest.mark.parametrize(
"input_path, pattern, expected_paths",
[
Expand Down