Skip to content

Commit 145e041

Browse files
Support HTTP authentication in non-streaming mode (#7082)
* Refactor cached_path
* Fix for empty storage_options
* Allow passing HTTP storage options
* Test HTTP storage_options passed by cached_path to get_from_cache
  Test cached_path passes HTTP storage options to get_from_cache only if passed in DownloadConfig
* Test HTTP fsspec is called only if passed HTTP storage_options
1 parent b288977 commit 145e041

File tree

2 files changed

+83
-12
lines changed

2 files changed

+83
-12
lines changed

src/datasets/utils/file_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,13 @@ def cached_path(
201201
url_or_filename, storage_options = _prepare_path_and_storage_options(
202202
url_or_filename, download_config=download_config
203203
)
204+
# Pass HTTP storage_options to get_from_cache only if passed HTTP download_config.storage_options
205+
if (
206+
storage_options
207+
and storage_options.keys() < {"http", "https"}
208+
and not (download_config.storage_options and download_config.storage_options.keys() < {"http", "https"})
209+
):
210+
storage_options = {}
204211
output_path = get_from_cache(
205212
url_or_filename,
206213
cache_dir=cache_dir,
@@ -525,6 +532,8 @@ def get_from_cache(
525532
ConnectionError: in case of unreachable url
526533
and no cache on disk
527534
"""
535+
if storage_options is None:
536+
storage_options = {}
528537
if use_auth_token != "deprecated":
529538
warnings.warn(
530539
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
@@ -570,7 +579,7 @@ def get_from_cache(
570579
scheme = urlparse(url).scheme
571580
if scheme == "ftp":
572581
connected = ftp_head(url)
573-
elif scheme not in ("http", "https"):
582+
elif scheme not in {"http", "https"} or storage_options.get(scheme):
574583
response = fsspec_head(url, storage_options=storage_options)
575584
# s3fs uses "ETag", gcsfs uses "etag"
576585
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None
@@ -676,7 +685,7 @@ def temp_file_manager(mode="w+b"):
676685
# GET file object
677686
if scheme == "ftp":
678687
ftp_get(url, temp_file)
679-
elif scheme not in ("http", "https"):
688+
elif scheme not in {"http", "https"} or storage_options.get(scheme):
680689
fsspec_get(
681690
url, temp_file, storage_options=storage_options, desc=download_desc, disable_tqdm=disable_tqdm
682691
)

tests/test_file_utils.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import re
3+
from dataclasses import dataclass, field
34
from pathlib import Path
45
from unittest.mock import MagicMock, patch
56

@@ -78,24 +79,85 @@ def tmpfs_file(tmpfs):
7879
return FILE_PATH
7980

8081

81-
@pytest.mark.parametrize("protocol", ["hf", "s3"])
82-
def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
82+
@pytest.mark.parametrize(
83+
"protocol, download_config_storage_options, expected_fsspec_called",
84+
[
85+
("hf", {}, True),
86+
("s3", {"s3": {"anon": True}}, True),
87+
# HTTP calls fsspec only if passed HTTP download_config.storage_options:
88+
("https", {"https": {"block_size": "omit"}}, True),
89+
("https", {}, False),
90+
],
91+
)
92+
def test_cached_path_calls_fsspec_for_protocols(
93+
protocol, download_config_storage_options, expected_fsspec_called, monkeypatch, tmp_path
94+
):
8395
# GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf'
96+
# fsspec_head/get:
8497
mock_fsspec_head = MagicMock(return_value={})
8598
mock_fsspec_get = MagicMock(return_value=None)
8699
monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head)
87100
monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get)
101+
102+
# http_head_get:
103+
@dataclass
104+
class Response:
105+
status_code: int
106+
headers: dict = field(default_factory=dict)
107+
cookies: dict = field(default_factory=dict)
108+
109+
mock_http_head = MagicMock(return_value=Response(status_code=200))
110+
mock_http_get = MagicMock(return_value=None)
111+
monkeypatch.setattr("datasets.utils.file_utils.http_head", mock_http_head)
112+
monkeypatch.setattr("datasets.utils.file_utils.http_get", mock_http_get)
113+
# Test:
88114
cache_dir = tmp_path / "cache"
89-
storage_options = {} if protocol == "hf" else {"s3": {"anon": True}}
90-
download_config = DownloadConfig(cache_dir=cache_dir, storage_options=storage_options)
91-
urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
115+
download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options)
116+
urls = {
117+
"hf": "hf://datasets/org-name/ds-name@main/filename.ext",
118+
"https": "https://doamin.org/filename.ext",
119+
"s3": "s3://bucket-name/filename.ext",
120+
}
92121
url = urls[protocol]
93122
_ = cached_path(url, download_config=download_config)
94-
for mock in [mock_fsspec_head, mock_fsspec_get]:
95-
assert mock.called
96-
assert mock.call_count == 1
97-
assert mock.call_args.args[0] == url
98-
assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol]
123+
if expected_fsspec_called:
124+
for mock in [mock_fsspec_head, mock_fsspec_get]:
125+
assert mock.called
126+
assert mock.call_count == 1
127+
assert mock.call_args.args[0] == url
128+
assert list(mock.call_args.kwargs["storage_options"].keys()) == [protocol]
129+
for mock in [mock_http_head, mock_http_get]:
130+
assert not mock.called
131+
else:
132+
for mock in [mock_fsspec_head, mock_fsspec_get]:
133+
assert not mock.called
134+
for mock in [mock_http_head, mock_http_get]:
135+
assert mock.called
136+
assert mock.call_count == 1
137+
assert mock.call_args.args[0] == url
138+
139+
140+
@pytest.mark.parametrize(
141+
"download_config_storage_options, expected_storage_options_passed_to_get_from_catch",
142+
[
143+
({}, {}), # No DownloadConfig.storage_options
144+
({"https": {"block_size": "omit"}}, {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}}),
145+
],
146+
)
147+
def test_cached_path_passes_http_storage_options_to_get_from_cache_only_if_present_in_download_config(
148+
download_config_storage_options, expected_storage_options_passed_to_get_from_catch, monkeypatch, tmp_path
149+
):
150+
# Test cached_path passes HTTP storage_options to get_from_cache only if passed HTTP download_config.storage_options
151+
mock_get_from_catch = MagicMock(return_value=None)
152+
monkeypatch.setattr("datasets.utils.file_utils.get_from_cache", mock_get_from_catch)
153+
url = "https://domain.org/data.txt"
154+
cache_dir = tmp_path / "cache"
155+
download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options)
156+
_ = cached_path(url, download_config=download_config)
157+
assert mock_get_from_catch.called
158+
assert mock_get_from_catch.call_count == 1
159+
assert mock_get_from_catch.call_args.args[0] == url
160+
assert mock_get_from_catch.call_args.kwargs["storage_options"] == expected_storage_options_passed_to_get_from_catch
99161

100162

101163
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])

0 commit comments

Comments
 (0)