diff --git a/src/datasets/utils/streaming_download_manager.py b/src/datasets/utils/streaming_download_manager.py index 616ca2250b0..2768cba7997 100644 --- a/src/datasets/utils/streaming_download_manager.py +++ b/src/datasets/utils/streaming_download_manager.py @@ -15,7 +15,7 @@ logger = get_logger(__name__) BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"] -COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "lz4", "xz", "zst"] +COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "gz", "lz4", "xz", "zip", "zst"] def xjoin(a, *p): @@ -34,14 +34,19 @@ def xjoin(a, *p): Example:: - >>> xjoin("zip://folder1::https://host.com/archive.zip", "file.txt") + >>> xjoin("https://host.com/archive.zip", "folder1/file.txt") zip://folder1/file.txt::https://host.com/archive.zip """ a, *b = a.split("::") if is_local_path(a): a = Path(a, *p).as_posix() else: - a = posixpath.join(a, *p) + compression = fsspec.core.get_compression(a, "infer") + if compression in ["zip"]: + b = [a] + b + a = posixpath.join(f"{compression}://", *p) + else: + a = posixpath.join(a, *p) return "::".join([a] + b) @@ -79,7 +84,7 @@ def open_with_retries(): return fsspec_open_file -def xopen(file, mode="r", *args, **kwargs): +def xopen(file, mode="r", compression="infer", *args, **kwargs): """ This function extends the builtin `open` function to support remote files using fsspec. @@ -88,14 +93,9 @@ def xopen(file, mode="r", *args, **kwargs): """ if fsspec.get_fs_token_paths(file)[0].protocol == "https": kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None)) - compression = fsspec.core.get_compression(file, "infer") - if not compression or compression in ["gzip", "zip"]: - file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open() - file_obj = _add_retries_to_file_obj_read_method(file_obj) - else: - file_obj = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs) - file_obj = _add_retries_to_fsspec_open_file(file_obj) - return file_obj + fsspec_open_file = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs) + fsspec_open_file = _add_retries_to_fsspec_open_file(fsspec_open_file) + return fsspec_open_file class StreamingDownloadManager(object): @@ -141,9 +141,6 @@ def _extract(self, urlpath): if protocol is None: # no extraction return urlpath - elif protocol == "gzip": - # there is one single file which is the uncompressed gzip file - return f"{protocol}://{os.path.basename(urlpath.split('::')[0]).rstrip('.gz')}::{urlpath}" else: return f"{protocol}://*::{urlpath}" @@ -151,12 +148,8 @@ def _get_extraction_protocol(self, urlpath) -> Optional[str]: path = urlpath.split("::")[0] if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS: return None - elif path.endswith(".gz") and not path.endswith(".tar.gz"): - return "gzip" elif path.endswith(".tar"): return "tar" - elif path.endswith(".zip"): - return "zip" raise NotImplementedError(f"Extraction protocol for file at {urlpath} is not implemented yet") def download_and_extract(self, url_or_urls): diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py index c3d9d62df0a..2e46e712aa4 100644 --- a/tests/test_streaming_download_manager.py +++ b/tests/test_streaming_download_manager.py @@ -1,5 +1,3 @@ -import os - import pytest from .utils import require_streaming @@ -9,6 +7,25 @@ TEST_URL_CONTENT = "foo\nbar\nfoobar" +@require_streaming +@pytest.mark.parametrize( + "input_path, paths_to_join, expected_path", + [ + ("https://host.com/archive.zip", ("file.txt",), "zip://file.txt::https://host.com/archive.zip"), + ( + "zip://folder::https://host.com/archive.zip", + ("file.txt",), + "zip://folder/file.txt::https://host.com/archive.zip", + ), + ], +) +def test_xjoin(input_path, paths_to_join, expected_path): + from datasets.utils.streaming_download_manager import xjoin + + output_path = xjoin(input_path, *paths_to_join) + assert output_path == expected_path + + @require_streaming def test_xopen_local(text_path): from datasets.utils.streaming_download_manager import xopen @@ -56,11 +73,13 @@ def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath): @require_streaming def test_streaming_dl_manager_extract(text_gz_path): - from datasets.utils.streaming_download_manager import StreamingDownloadManager + from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen dl_manager = StreamingDownloadManager() - path = os.path.basename(text_gz_path).rstrip(".gz") - assert dl_manager.extract(text_gz_path) == f"gzip://{path}::{text_gz_path}" + output_path = dl_manager.extract(text_gz_path) + assert output_path == text_gz_path + fsspec_open_file = xopen(output_path) + assert fsspec_open_file.compression == "gzip" @require_streaming @@ -68,8 +87,23 @@ def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path, from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen dl_manager = StreamingDownloadManager() - filename = os.path.basename(text_gz_path).rstrip(".gz") - out = dl_manager.download_and_extract(text_gz_path) - assert out == f"gzip://{filename}::{text_gz_path}" - with xopen(out, encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file: + output_path = dl_manager.download_and_extract(text_gz_path) + assert output_path == text_gz_path + fsspec_open_file = xopen(output_path, encoding="utf-8") + assert output_path == text_gz_path + with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file: assert f.read() == expected_file.read() + + +@require_streaming +@pytest.mark.parametrize( + "input_path, filename, expected_path", + [("https://domain.org/archive.zip", "filename.jsonl", "zip://filename.jsonl::https://domain.org/archive.zip")], +) +def test_streaming_dl_manager_download_and_extract_with_join(input_path, filename, expected_path): + from datasets.utils.streaming_download_manager import StreamingDownloadManager, xjoin + + dl_manager = StreamingDownloadManager() + extracted_path = dl_manager.download_and_extract(input_path) + output_path = xjoin(extracted_path, filename) + assert output_path == expected_path