Skip to content

Commit eb81144

Browse files
Fix prepare_single_hop_path_and_storage_options (#7068)
* Transform all HF HTTP URLs to HF protocol
* Fix test URL
* Remove HF headers for non-HF HTTP URLs
* Fix for HTTP storage_options without 'headers'
* Remove unused cookies
* Refactor
* Refactor list to set to check membership
* Refactor to add protocol key to storage_options only at the end
* Fix overwriting storage_options nested values
* Add tests
* Revert "Transform all HF HTTP URLs to HF protocol"
  This reverts commit a337212.
* Test that DownloadConfig.storage_options are not modified
* Fix so DownloadConfig.storage_options are not modified
* Refactor fix
* Test also GitHub URL
* Fix DownloadConfig.storage_options for GitHub URL
1 parent 7666c77 commit eb81144

File tree

2 files changed

+86
-23
lines changed

2 files changed

+86
-23
lines changed

src/datasets/utils/file_utils.py

Lines changed: 15 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1160,7 +1160,7 @@ def _prepare_single_hop_path_and_storage_options(
11601160
urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
11611161
protocol = urlpath.split("://")[0] if "://" in urlpath else "file"
11621162
if download_config is not None and protocol in download_config.storage_options:
1163-
storage_options = download_config.storage_options[protocol]
1163+
storage_options = download_config.storage_options[protocol].copy()
11641164
elif download_config is not None and protocol not in download_config.storage_options:
11651165
storage_options = {
11661166
option_name: option_value
@@ -1169,40 +1169,34 @@ def _prepare_single_hop_path_and_storage_options(
11691169
}
11701170
else:
11711171
storage_options = {}
1172-
if storage_options:
1173-
storage_options = {protocol: storage_options}
1174-
if protocol in ["http", "https"]:
1175-
storage_options[protocol] = {
1176-
"headers": {
1177-
**get_authentication_headers_for_url(urlpath, token=token),
1178-
"user-agent": get_datasets_user_agent(),
1179-
},
1180-
"client_kwargs": {"trust_env": True}, # Enable reading proxy env variables.
1181-
**(storage_options.get(protocol, {})),
1182-
}
1172+
if protocol in {"http", "https"}:
1173+
client_kwargs = storage_options.pop("client_kwargs", {})
1174+
storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs} # Enable reading proxy env variables
11831175
if "drive.google.com" in urlpath:
11841176
response = http_head(urlpath)
1185-
cookies = None
11861177
for k, v in response.cookies.items():
11871178
if k.startswith("download_warning"):
11881179
urlpath += "&confirm=" + v
11891180
cookies = response.cookies
1190-
storage_options[protocol] = {"cookies": cookies, **storage_options.get(protocol, {})}
1191-
# Fix Google Drive URL to avoid Virus scan warning
1192-
if "drive.google.com" in urlpath and "confirm=" not in urlpath:
1193-
urlpath += "&confirm=t"
1181+
storage_options = {"cookies": cookies, **storage_options}
1182+
# Fix Google Drive URL to avoid Virus scan warning
1183+
if "confirm=" not in urlpath:
1184+
urlpath += "&confirm=t"
11941185
if urlpath.startswith("https://raw.githubusercontent.com/"):
11951186
# Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389
1196-
storage_options[protocol]["headers"]["Accept-Encoding"] = "identity"
1187+
headers = storage_options.pop("headers", {})
1188+
storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
11971189
elif protocol == "hf":
1198-
storage_options[protocol] = {
1190+
storage_options = {
11991191
"token": token,
12001192
"endpoint": config.HF_ENDPOINT,
1201-
**storage_options.get(protocol, {}),
1193+
**storage_options,
12021194
}
12031195
# streaming with block_size=0 is only implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967)
12041196
if config.HF_HUB_VERSION < version.parse("0.21.0"):
1205-
storage_options[protocol]["block_size"] = "default"
1197+
storage_options["block_size"] = "default"
1198+
if storage_options:
1199+
storage_options = {protocol: storage_options}
12061200
return urlpath, storage_options
12071201

12081202

tests/test_file_utils.py

Lines changed: 71 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from datasets.utils.file_utils import (
1313
OfflineModeIsEnabled,
1414
_get_extraction_protocol,
15+
_prepare_single_hop_path_and_storage_options,
1516
cached_path,
1617
fsspec_get,
1718
fsspec_head,
@@ -47,7 +48,7 @@
4748

4849
FILE_PATH = "file"
4950

50-
TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/raw/main/some_text.txt"
51+
TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt"
5152
TEST_URL_CONTENT = "foo\nbar\nfoobar"
5253

5354
TEST_GG_DRIVE_FILENAME = "train.tsv"
@@ -90,7 +91,6 @@ def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
9091
urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
9192
url = urls[protocol]
9293
_ = cached_path(url, download_config=download_config)
93-
assert True
9494
for mock in [mock_fsspec_head, mock_fsspec_get]:
9595
assert mock.called
9696
assert mock.call_count == 1
@@ -197,6 +197,75 @@ def test_fsspec_offline(tmp_path_factory):
197197
fsspec_head("s3://huggingface.co")
198198

199199

200+
@pytest.mark.parametrize(
201+
"urlpath, download_config, expected_urlpath, expected_storage_options",
202+
[
203+
(
204+
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
205+
DownloadConfig(),
206+
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
207+
{"hf": {"endpoint": "https://huggingface.co", "token": None}},
208+
),
209+
(
210+
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
211+
DownloadConfig(token="MY-TOKEN"),
212+
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
213+
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN"}},
214+
),
215+
(
216+
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
217+
DownloadConfig(token="MY-TOKEN", storage_options={"hf": {"on_error": "omit"}}),
218+
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
219+
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN", "on_error": "omit"}},
220+
),
221+
(
222+
"https://domain.org/data.txt",
223+
DownloadConfig(),
224+
"https://domain.org/data.txt",
225+
{"https": {"client_kwargs": {"trust_env": True}}},
226+
),
227+
(
228+
"https://domain.org/data.txt",
229+
DownloadConfig(storage_options={"https": {"block_size": "omit"}}),
230+
"https://domain.org/data.txt",
231+
{"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}},
232+
),
233+
(
234+
"https://domain.org/data.txt",
235+
DownloadConfig(storage_options={"https": {"client_kwargs": {"raise_for_status": True}}}),
236+
"https://domain.org/data.txt",
237+
{"https": {"client_kwargs": {"trust_env": True, "raise_for_status": True}}},
238+
),
239+
(
240+
"https://domain.org/data.txt",
241+
DownloadConfig(storage_options={"https": {"client_kwargs": {"trust_env": False}}}),
242+
"https://domain.org/data.txt",
243+
{"https": {"client_kwargs": {"trust_env": False}}},
244+
),
245+
(
246+
"https://raw.githubusercontent.com/data.txt",
247+
DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}}),
248+
"https://raw.githubusercontent.com/data.txt",
249+
{
250+
"https": {
251+
"client_kwargs": {"trust_env": True},
252+
"headers": {"x-test": "true", "Accept-Encoding": "identity"},
253+
}
254+
},
255+
),
256+
],
257+
)
258+
def test_prepare_single_hop_path_and_storage_options(
259+
urlpath, download_config, expected_urlpath, expected_storage_options
260+
):
261+
original_download_config_storage_options = str(download_config.storage_options)
262+
prepared_urlpath, storage_options = _prepare_single_hop_path_and_storage_options(urlpath, download_config)
263+
assert prepared_urlpath == expected_urlpath
264+
assert storage_options == expected_storage_options
265+
# Check that DownloadConfig.storage_options are not modified:
266+
assert str(download_config.storage_options) == original_download_config_storage_options
267+
268+
200269
class DummyTestFS(AbstractFileSystem):
201270
protocol = "mock"
202271
_file_class = AbstractBufferedFile

0 commit comments

Comments
 (0)