Skip to content

Commit 0d041d4

Browse files
committed
add test
1 parent b429ca2 commit 0d041d4

File tree

3 files changed

+29
-6
lines changed

3 files changed

+29
-6
lines changed

src/datasets/utils/file_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,24 +330,24 @@ def _request_with_retry(
330330

331331
def fsspec_head(url, timeout=10.0):
332332
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
333-
fs, _, paths = fsspec.get_fs_token_paths(url)
333+
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
334334
if len(paths) > 1:
335-
raise ValueError("HEAD can be called with at most one path but was called with {paths}")
335+
raise ValueError(f"HEAD can be called with at most one path but was called with {paths}")
336336
return fs.info(paths[0], timeout=timeout)
337337

338338

339339
def fsspec_get(url, temp_file, timeout=10.0, desc=None):
340340
_raise_if_offline_mode_is_enabled(f"Tried to reach {url}")
341-
fs, _, paths = fsspec.get_fs_token_paths(url)
341+
fs, _, paths = fsspec.get_fs_token_paths(url, storage_options={"requests_timeout": timeout})
342342
if len(paths) > 1:
343-
raise ValueError("GET can be called with at most one path but was called with {paths}")
343+
raise ValueError(f"GET can be called with at most one path but was called with {paths}")
344344
callback = fsspec.callbacks.TqdmCallback(
345345
tqdm_kwargs={
346346
"desc": desc or "Downloading",
347347
"disable": logging.is_progress_bar_enabled(),
348348
}
349349
)
350-
fs.get(paths[0], temp_file, timeout=timeout, callback=callback)
350+
fs.get_file(paths[0], temp_file.name, timeout=timeout, callback=callback)
351351

352352

353353
def ftp_head(url, timeout=10.0):

tests/fixtures/fsspec.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def info(self, path, *args, **kwargs):
4040
out["name"] = out["name"][len(self.local_root_dir) :]
4141
return out
4242

43+
def get_file(self, rpath, lpath, *args, **kwargs):
44+
rpath = posixpath.join(self.local_root_dir, self._strip_protocol(rpath))
45+
return self._fs.get_file(rpath, lpath, *args, **kwargs)
46+
4347
def cp_file(self, path1, path2, *args, **kwargs):
4448
path1 = posixpath.join(self.local_root_dir, self._strip_protocol(path1))
4549
path2 = posixpath.join(self.local_root_dir, self._strip_protocol(path2))

tests/test_file_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
fsspec_head,
1414
ftp_get,
1515
ftp_head,
16+
get_from_cache,
1617
http_get,
1718
http_head,
1819
)
@@ -22,16 +23,25 @@
2223
Text data.
2324
Second line of data."""
2425

26+
FILE_PATH = "file"
27+
2528

2629
@pytest.fixture(scope="session")
2730
def zstd_path(tmp_path_factory):
28-
path = tmp_path_factory.mktemp("data") / "file.zstd"
31+
path = tmp_path_factory.mktemp("data") / FILE_PATH
2932
data = bytes(FILE_CONTENT, "utf-8")
3033
with zstd.open(path, "wb") as f:
3134
f.write(data)
3235
return path
3336

3437

38+
@pytest.fixture
39+
def mockfs_file(mockfs):
40+
with open(os.path.join(mockfs.local_root_dir, FILE_PATH), "w") as f:
41+
f.write(FILE_CONTENT)
42+
return mockfs
43+
44+
3545
@pytest.mark.parametrize("compression_format", ["gzip", "xz", "zstd"])
3646
def test_cached_path_extract(compression_format, gz_file, xz_file, zstd_path, tmp_path, text_file):
3747
input_paths = {"gzip": gz_file, "xz": xz_file, "zstd": zstd_path}
@@ -89,6 +99,15 @@ def test_cached_path_missing_local(tmp_path):
8999
cached_path(missing_file)
90100

91101

102+
def test_get_from_cache_fsspec(mockfs_file):
103+
with patch("datasets.utils.file_utils.fsspec.get_fs_token_paths") as mock_get_fs_token_paths:
104+
mock_get_fs_token_paths.return_value = (mockfs_file, "", [FILE_PATH])
105+
output_path = get_from_cache("mock://huggingface.co")
106+
with open(output_path) as f:
107+
output_file_content = f.read()
108+
assert output_file_content == FILE_CONTENT
109+
110+
92111
@patch("datasets.config.HF_DATASETS_OFFLINE", True)
93112
def test_cached_path_offline():
94113
with pytest.raises(OfflineModeIsEnabled):

0 commit comments

Comments
 (0)