diff --git a/src/datasets/streaming.py b/src/datasets/streaming.py index 50adc0a3981..e428b357ca7 100644 --- a/src/datasets/streaming.py +++ b/src/datasets/streaming.py @@ -5,7 +5,7 @@ from .utils.logging import get_logger from .utils.patching import patch_submodule -from .utils.streaming_download_manager import xdirname, xjoin, xopen, xpathjoin, xpathopen +from .utils.streaming_download_manager import xdirname, xjoin, xopen, xpathjoin, xpathopen, xpathstem, xpathsuffix logger = get_logger(__name__) @@ -43,3 +43,5 @@ def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, patch.object(module.Path, "joinpath", xpathjoin).start() patch.object(module.Path, "__truediv__", xpathjoin).start() patch.object(module.Path, "open", xpathopen).start() + patch.object(module.Path, "stem", property(fget=xpathstem)).start() + patch.object(module.Path, "suffix", property(fget=xpathsuffix)).start() diff --git a/src/datasets/utils/streaming_download_manager.py b/src/datasets/utils/streaming_download_manager.py index 391179426db..bd5baf31b22 100644 --- a/src/datasets/utils/streaming_download_manager.py +++ b/src/datasets/utils/streaming_download_manager.py @@ -1,7 +1,7 @@ import os import re import time -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import Optional, Tuple import fsspec @@ -180,6 +180,30 @@ def xpathopen(path: Path, **kwargs): return xopen(_as_posix(path), **kwargs) +def xpathstem(path: Path): + """Stem function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs. + + Args: + path (:obj:`~pathlib.Path`): Calling Path instance. + + Returns: + :obj:`str` + """ + return PurePosixPath(_as_posix(path).split("::")[0]).stem + + +def xpathsuffix(path: Path): + """Suffix function for argument of type :obj:`~pathlib.Path` that supports both local paths and remote URLs. + + Args: + path (:obj:`~pathlib.Path`): Calling Path instance. 
+ + Returns: + :obj:`str` + """ + return PurePosixPath(_as_posix(path).split("::")[0]).suffix + + class StreamingDownloadManager(object): """ Download manager that uses the "::" separator to navigate through (possibly remote) compressed archives. diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index cbf494d6eb4..0a24ca52d05 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -13,6 +13,8 @@ xopen, xpathjoin, xpathopen, + xpathstem, + xpathsuffix, ) from .utils import require_lz4, require_zstandard @@ -124,6 +126,30 @@ def test_xopen_remote(): assert list(f) == TEST_URL_CONTENT.splitlines(keepends=True) +@pytest.mark.parametrize( + "input_path, expected", + [ + ("zip://file.txt::https://host.com/archive.zip", "file"), + ("file.txt", "file"), + ((Path().resolve() / "file.txt").as_posix(), "file"), + ], +) +def test_xpathstem(input_path, expected): + assert xpathstem(Path(input_path)) == expected + + +@pytest.mark.parametrize( + "input_path, expected", + [ + ("zip://file.txt::https://host.com/archive.zip", ".txt"), + ("file.txt", ".txt"), + ((Path().resolve() / "file.txt").as_posix(), ".txt"), + ], +) +def test_xpathsuffix(input_path, expected): + assert xpathsuffix(Path(input_path)) == expected + + @pytest.mark.parametrize("urlpath", [r"C:\\foo\bar.txt", "/foo/bar.txt", "https://f.oo/bar.txt"]) def test_streaming_dl_manager_download_dummy_path(urlpath): dl_manager = StreamingDownloadManager()