Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions src/datasets/utils/streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

logger = get_logger(__name__)
BASE_KNOWN_EXTENSIONS = ["txt", "csv", "json", "jsonl", "tsv", "conll", "conllu", "parquet", "pkl", "pickle", "xml"]
COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "lz4", "xz", "zst"]
COMPRESSION_KNOWN_EXTENSIONS = ["bz2", "gz", "lz4", "xz", "zip", "zst"]


def xjoin(a, *p):
Expand All @@ -34,14 +34,19 @@ def xjoin(a, *p):

Example::

>>> xjoin("zip://folder1::https://host.com/archive.zip", "file.txt")
>>> xjoin("https://host.com/archive.zip", "folder1/file.txt")
zip://folder1/file.txt::https://host.com/archive.zip
"""
a, *b = a.split("::")
if is_local_path(a):
a = Path(a, *p).as_posix()
else:
a = posixpath.join(a, *p)
compression = fsspec.core.get_compression(a, "infer")
if compression in ["zip"]:
b = [a] + b
a = posixpath.join(f"{compression}://", *p)
else:
a = posixpath.join(a, *p)
return "::".join([a] + b)


Expand Down Expand Up @@ -79,7 +84,7 @@ def open_with_retries():
return fsspec_open_file


def xopen(file, mode="r", *args, **kwargs):
def xopen(file, mode="r", compression="infer", *args, **kwargs):
"""
This function extends the builtin `open` function to support remote files using fsspec.

Expand All @@ -88,14 +93,9 @@ def xopen(file, mode="r", *args, **kwargs):
"""
if fsspec.get_fs_token_paths(file)[0].protocol == "https":
kwargs["headers"] = get_authentication_headers_for_url(file, use_auth_token=kwargs.pop("use_auth_token", None))
compression = fsspec.core.get_compression(file, "infer")
if not compression or compression in ["gzip", "zip"]:
file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
file_obj = _add_retries_to_file_obj_read_method(file_obj)
else:
file_obj = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs)
file_obj = _add_retries_to_fsspec_open_file(file_obj)
return file_obj
fsspec_open_file = fsspec.open(file, mode=mode, compression=compression, *args, **kwargs)
fsspec_open_file = _add_retries_to_fsspec_open_file(fsspec_open_file)
return fsspec_open_file


class StreamingDownloadManager(object):
Expand Down Expand Up @@ -141,22 +141,15 @@ def _extract(self, urlpath):
if protocol is None:
# no extraction
return urlpath
elif protocol == "gzip":
# there is one single file which is the uncompressed gzip file
return f"{protocol}://{os.path.basename(urlpath.split('::')[0]).rstrip('.gz')}::{urlpath}"
else:
return f"{protocol}://*::{urlpath}"

def _get_extraction_protocol(self, urlpath) -> Optional[str]:
path = urlpath.split("::")[0]
if path.split(".")[-1] in BASE_KNOWN_EXTENSIONS + COMPRESSION_KNOWN_EXTENSIONS:
return None
elif path.endswith(".gz") and not path.endswith(".tar.gz"):
return "gzip"
elif path.endswith(".tar"):
return "tar"
elif path.endswith(".zip"):
return "zip"
raise NotImplementedError(f"Extraction protocol for file at {urlpath} is not implemented yet")

def download_and_extract(self, url_or_urls):
Expand Down
52 changes: 43 additions & 9 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import os

import pytest

from .utils import require_streaming
Expand All @@ -9,6 +7,25 @@
TEST_URL_CONTENT = "foo\nbar\nfoobar"


@require_streaming
@pytest.mark.parametrize(
"input_path, paths_to_join, expected_path",
[
("https://host.com/archive.zip", ("file.txt",), "zip://file.txt::https://host.com/archive.zip"),
(
"zip://folder::https://host.com/archive.zip",
("file.txt",),
"zip://folder/file.txt::https://host.com/archive.zip",
),
],
)
def test_xjoin(input_path, paths_to_join, expected_path):
from datasets.utils.streaming_download_manager import xjoin

output_path = xjoin(input_path, *paths_to_join)
assert output_path == expected_path


@require_streaming
def test_xopen_local(text_path):
from datasets.utils.streaming_download_manager import xopen
Expand Down Expand Up @@ -56,20 +73,37 @@ def test_streaming_dl_manager_download_and_extract_no_extraction(urlpath):

@require_streaming
def test_streaming_dl_manager_extract(text_gz_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
path = os.path.basename(text_gz_path).rstrip(".gz")
assert dl_manager.extract(text_gz_path) == f"gzip://{path}::{text_gz_path}"
output_path = dl_manager.extract(text_gz_path)
assert output_path == text_gz_path
fsspec_open_file = xopen(output_path)
assert fsspec_open_file.compression == "gzip"


@require_streaming
def test_streaming_dl_manager_download_and_extract_with_extraction(text_gz_path, text_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xopen

dl_manager = StreamingDownloadManager()
filename = os.path.basename(text_gz_path).rstrip(".gz")
out = dl_manager.download_and_extract(text_gz_path)
assert out == f"gzip://{filename}::{text_gz_path}"
with xopen(out, encoding="utf-8") as f, open(text_path, encoding="utf-8") as expected_file:
output_path = dl_manager.download_and_extract(text_gz_path)
assert output_path == text_gz_path
fsspec_open_file = xopen(output_path, encoding="utf-8")
assert output_path == text_gz_path
with fsspec_open_file as f, open(text_path, encoding="utf-8") as expected_file:
assert f.read() == expected_file.read()


@require_streaming
@pytest.mark.parametrize(
"input_path, filename, expected_path",
[("https://domain.org/archive.zip", "filename.jsonl", "zip://filename.jsonl::https://domain.org/archive.zip")],
)
def test_streaming_dl_manager_download_and_extract_with_join(input_path, filename, expected_path):
from datasets.utils.streaming_download_manager import StreamingDownloadManager, xjoin

dl_manager = StreamingDownloadManager()
extracted_path = dl_manager.download_and_extract(input_path)
output_path = xjoin(extracted_path, filename)
assert output_path == expected_path