Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ def _prepare_single_hop_path_and_storage_options(
urlpath = "hf://" + urlpath[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
protocol = urlpath.split("://")[0] if "://" in urlpath else "file"
if download_config is not None and protocol in download_config.storage_options:
storage_options = download_config.storage_options[protocol]
storage_options = download_config.storage_options[protocol].copy()
elif download_config is not None and protocol not in download_config.storage_options:
storage_options = {
option_name: option_value
Expand All @@ -1159,40 +1159,34 @@ def _prepare_single_hop_path_and_storage_options(
}
else:
storage_options = {}
if storage_options:
storage_options = {protocol: storage_options}
if protocol in ["http", "https"]:
storage_options[protocol] = {
"headers": {
**get_authentication_headers_for_url(urlpath, token=token),
"user-agent": get_datasets_user_agent(),
},
"client_kwargs": {"trust_env": True}, # Enable reading proxy env variables.
**(storage_options.get(protocol, {})),
}
if protocol in {"http", "https"}:
client_kwargs = storage_options.pop("client_kwargs", {})
storage_options["client_kwargs"] = {"trust_env": True, **client_kwargs} # Enable reading proxy env variables
if "drive.google.com" in urlpath:
response = http_head(urlpath)
cookies = None
for k, v in response.cookies.items():
if k.startswith("download_warning"):
urlpath += "&confirm=" + v
cookies = response.cookies
storage_options[protocol] = {"cookies": cookies, **storage_options.get(protocol, {})}
# Fix Google Drive URL to avoid Virus scan warning
if "drive.google.com" in urlpath and "confirm=" not in urlpath:
urlpath += "&confirm=t"
storage_options = {"cookies": cookies, **storage_options}
# Fix Google Drive URL to avoid Virus scan warning
if "confirm=" not in urlpath:
urlpath += "&confirm=t"
if urlpath.startswith("https://raw.githubusercontent.com/"):
# Workaround for served data with gzip content-encoding: https://github.com/fsspec/filesystem_spec/issues/389
storage_options[protocol]["headers"]["Accept-Encoding"] = "identity"
headers = storage_options.pop("headers", {})
storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
elif protocol == "hf":
storage_options[protocol] = {
storage_options = {
"token": token,
"endpoint": config.HF_ENDPOINT,
**storage_options.get(protocol, {}),
**storage_options,
}
# streaming with block_size=0 is only implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967)
if config.HF_HUB_VERSION < version.parse("0.21.0"):
storage_options[protocol]["block_size"] = "default"
storage_options["block_size"] = "default"
if storage_options:
storage_options = {protocol: storage_options}
return urlpath, storage_options


Expand Down
73 changes: 71 additions & 2 deletions tests/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datasets.utils.file_utils import (
OfflineModeIsEnabled,
_get_extraction_protocol,
_prepare_single_hop_path_and_storage_options,
cached_path,
fsspec_get,
fsspec_head,
Expand Down Expand Up @@ -47,7 +48,7 @@

FILE_PATH = "file"

TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/raw/main/some_text.txt"
TEST_URL = "https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt"
TEST_URL_CONTENT = "foo\nbar\nfoobar"

TEST_GG_DRIVE_FILENAME = "train.tsv"
Expand Down Expand Up @@ -90,7 +91,6 @@ def test_cached_path_protocols(protocol, monkeypatch, tmp_path):
urls = {"hf": "hf://datasets/org-name/ds-name@main/filename.ext", "s3": "s3://bucket-name/filename.ext"}
url = urls[protocol]
_ = cached_path(url, download_config=download_config)
assert True
for mock in [mock_fsspec_head, mock_fsspec_get]:
assert mock.called
assert mock.call_count == 1
Expand Down Expand Up @@ -197,6 +197,75 @@ def test_fsspec_offline(tmp_path_factory):
fsspec_head("s3://huggingface.co")


@pytest.mark.parametrize(
"urlpath, download_config, expected_urlpath, expected_storage_options",
[
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": None}},
),
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(token="MY-TOKEN"),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN"}},
),
(
"https://huggingface.co/datasets/hf-internal-testing/dataset_with_script/resolve/main/some_text.txt",
DownloadConfig(token="MY-TOKEN", storage_options={"hf": {"on_error": "omit"}}),
"hf://datasets/hf-internal-testing/dataset_with_script@main/some_text.txt",
{"hf": {"endpoint": "https://huggingface.co", "token": "MY-TOKEN", "on_error": "omit"}},
),
(
"https://domain.org/data.txt",
DownloadConfig(),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True}}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"block_size": "omit"}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"client_kwargs": {"raise_for_status": True}}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": True, "raise_for_status": True}}},
),
(
"https://domain.org/data.txt",
DownloadConfig(storage_options={"https": {"client_kwargs": {"trust_env": False}}}),
"https://domain.org/data.txt",
{"https": {"client_kwargs": {"trust_env": False}}},
),
(
"https://raw.githubusercontent.com/data.txt",
DownloadConfig(storage_options={"https": {"headers": {"x-test": "true"}}}),
"https://raw.githubusercontent.com/data.txt",
{
"https": {
"client_kwargs": {"trust_env": True},
"headers": {"x-test": "true", "Accept-Encoding": "identity"},
}
},
),
],
)
def test_prepare_single_hop_path_and_storage_options(
urlpath, download_config, expected_urlpath, expected_storage_options
):
original_download_config_storage_options = str(download_config.storage_options)
prepared_urlpath, storage_options = _prepare_single_hop_path_and_storage_options(urlpath, download_config)
assert prepared_urlpath == expected_urlpath
assert storage_options == expected_storage_options
# Check that DownloadConfig.storage_options are not modified:
assert str(download_config.storage_options) == original_download_config_storage_options


class DummyTestFS(AbstractFileSystem):
protocol = "mock"
_file_class = AbstractBufferedFile
Expand Down