|
1 | 1 | import os |
2 | 2 | import re |
| 3 | +from dataclasses import dataclass, field |
3 | 4 | from pathlib import Path |
4 | 5 | from unittest.mock import MagicMock, patch |
5 | 6 |
|
@@ -78,24 +79,85 @@ def tmpfs_file(tmpfs): |
78 | 79 | return FILE_PATH |
79 | 80 |
|
80 | 81 |
|
@pytest.mark.parametrize(
    "protocol, download_config_storage_options, expected_fsspec_called",
    [
        ("hf", {}, True),
        ("s3", {"s3": {"anon": True}}, True),
        # HTTP calls fsspec only if passed HTTP download_config.storage_options:
        ("https", {"https": {"block_size": "omit"}}, True),
        ("https", {}, False),
    ],
)
def test_cached_path_calls_fsspec_for_protocols(
    protocol, download_config_storage_options, expected_fsspec_called, monkeypatch, tmp_path
):
    """Check which download backend ``cached_path`` uses for each protocol.

    The fsspec helpers (``fsspec_head``/``fsspec_get``) and the HTTP helpers
    (``http_head``/``http_get``) of ``datasets.utils.file_utils`` are replaced
    by mocks, so no network access happens. Depending on
    ``expected_fsspec_called``, exactly one of the two pairs must have been
    called once with the URL, and the other pair must not have been called.
    """
    # GH-6598: Test no TypeError: __init__() got an unexpected keyword argument 'hf'
    # fsspec_head/get:
    mock_fsspec_head = MagicMock(return_value={})
    mock_fsspec_get = MagicMock(return_value=None)
    monkeypatch.setattr("datasets.utils.file_utils.fsspec_head", mock_fsspec_head)
    monkeypatch.setattr("datasets.utils.file_utils.fsspec_get", mock_fsspec_get)

    # http_head/get:
    @dataclass
    class Response:
        # Minimal stand-in for the object returned by the mocked http_head.
        status_code: int
        headers: dict = field(default_factory=dict)
        cookies: dict = field(default_factory=dict)

    mock_http_head = MagicMock(return_value=Response(status_code=200))
    mock_http_get = MagicMock(return_value=None)
    monkeypatch.setattr("datasets.utils.file_utils.http_head", mock_http_head)
    monkeypatch.setattr("datasets.utils.file_utils.http_get", mock_http_get)

    # Test:
    cache_dir = tmp_path / "cache"
    download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options)
    urls = {
        "hf": "hf://datasets/org-name/ds-name@main/filename.ext",
        "https": "https://domain.org/filename.ext",
        "s3": "s3://bucket-name/filename.ext",
    }
    url = urls[protocol]
    _ = cached_path(url, download_config=download_config)

    fsspec_mocks = [mock_fsspec_head, mock_fsspec_get]
    http_mocks = [mock_http_head, mock_http_get]
    if expected_fsspec_called:
        for mock in fsspec_mocks:
            mock.assert_called_once()
            assert mock.call_args.args[0] == url
            # Only the options of the URL's own protocol must be forwarded.
            assert list(mock.call_args.kwargs["storage_options"]) == [protocol]
        for mock in http_mocks:
            mock.assert_not_called()
    else:
        for mock in fsspec_mocks:
            mock.assert_not_called()
        for mock in http_mocks:
            mock.assert_called_once()
            assert mock.call_args.args[0] == url
| 138 | + |
| 139 | + |
@pytest.mark.parametrize(
    "download_config_storage_options, expected_storage_options_passed_to_get_from_cache",
    [
        ({}, {}),  # No DownloadConfig.storage_options
        ({"https": {"block_size": "omit"}}, {"https": {"client_kwargs": {"trust_env": True}, "block_size": "omit"}}),
    ],
)
def test_cached_path_passes_http_storage_options_to_get_from_cache_only_if_present_in_download_config(
    download_config_storage_options, expected_storage_options_passed_to_get_from_cache, monkeypatch, tmp_path
):
    """Check the ``storage_options`` that ``cached_path`` hands to ``get_from_cache``.

    ``get_from_cache`` is replaced by a mock so that only the forwarded
    arguments are inspected. With no ``DownloadConfig.storage_options`` the
    forwarded options are expected to be empty; with HTTPS options the test
    expects them to arrive together with ``client_kwargs``.
    """
    # Test cached_path passes HTTP storage_options to get_from_cache only if passed HTTP download_config.storage_options
    mock_get_from_cache = MagicMock(return_value=None)
    monkeypatch.setattr("datasets.utils.file_utils.get_from_cache", mock_get_from_cache)
    url = "https://domain.org/data.txt"
    cache_dir = tmp_path / "cache"
    download_config = DownloadConfig(cache_dir=cache_dir, storage_options=download_config_storage_options)
    _ = cached_path(url, download_config=download_config)
    mock_get_from_cache.assert_called_once()
    assert mock_get_from_cache.call_args.args[0] == url
    assert mock_get_from_cache.call_args.kwargs["storage_options"] == expected_storage_options_passed_to_get_from_cache
99 | 161 |
|
100 | 162 |
|
101 | 163 | @pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"]) |
|
0 commit comments