Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 50 additions & 22 deletions src/datasets/download/streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,14 @@ def xisfile(path, use_auth_token: Optional[Union[str, bool]] = None) -> bool:
if is_local_path(main_hop):
return os.path.isfile(path)
else:
if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
storage_options = {
"https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
}
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, use_auth_token=use_auth_token)
storage_options = http_kwargs
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": http_kwargs}
path = "::".join([main_hop, url, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
Expand All @@ -257,10 +261,14 @@ def xgetsize(path, use_auth_token: Optional[Union[str, bool]] = None) -> int:
if is_local_path(main_hop):
return os.path.getsize(path)
else:
if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
storage_options = {
"https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
}
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, use_auth_token=use_auth_token)
storage_options = http_kwargs
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": http_kwargs}
path = "::".join([main_hop, url, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
Expand All @@ -285,14 +293,20 @@ def xisdir(path, use_auth_token: Optional[Union[str, bool]] = None) -> bool:
if is_local_path(main_hop):
return os.path.isdir(path)
else:
if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
storage_options = {
"https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
}
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
raise NotImplementedError("os.path.isdir is not extended to support URLs in streaming mode")
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": http_kwargs}
path = "::".join([main_hop, url, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
return fs.isdir(main_hop)
inner_path = main_hop.split("://")[1]
if not inner_path.strip("/"):
return True
return fs.isdir(inner_path)


def xrelpath(path, start=None):
Expand Down Expand Up @@ -463,14 +477,20 @@ def xlistdir(path: str, use_auth_token: Optional[Union[str, bool]] = None) -> Li
return os.listdir(path)
else:
# globbing inside a zip in a private repo requires authentication
if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
storage_options = {
"https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
}
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
raise NotImplementedError("os.listdir is not extended to support URLs in streaming mode")
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": http_kwargs}
path = "::".join([main_hop, url, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
objects = fs.listdir(main_hop.split("://")[1])
inner_path = main_hop.split("://")[1]
if inner_path.strip("/") and not fs.isdir(inner_path):
raise FileNotFoundError(f"Directory doesn't exist: {path}")
objects = fs.listdir(inner_path)
return [os.path.basename(obj["name"]) for obj in objects]


Expand All @@ -490,7 +510,9 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool]
return glob.glob(main_hop, recursive=recursive)
else:
# globbing inside a zip in a private repo requires authentication
if rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
raise NotImplementedError("glob.glob is not extended to support URLs in streaming mode")
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": kwargs}
Expand All @@ -502,7 +524,8 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool]
# so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
# - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
# - If there is "**" in the pattern, `fs.glob` must be called anyway.
globbed_paths = fs.glob(main_hop)
inner_path = main_hop.split("://")[1]
globbed_paths = fs.glob(inner_path)
return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]


Expand All @@ -522,15 +545,20 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None):
yield from os.walk(main_hop)
else:
# walking inside a zip in a private repo requires authentication
if rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
raise NotImplementedError("os.walk is not extended to support URLs in streaming mode")
elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
url = rest_hops[0]
url, kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
storage_options = {"https": kwargs}
urlpath = "::".join([main_hop, url, *rest_hops[1:]])
else:
storage_options = None
fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
for dirpath, dirnames, filenames in fs.walk(main_hop):
inner_path = main_hop.split("://")[1]
if inner_path.strip("/") and not fs.isdir(inner_path):
return []
for dirpath, dirnames, filenames in fs.walk(inner_path):
yield "::".join([f"{fs.protocol}://{dirpath}"] + rest_hops), dirnames, filenames


Expand Down
4 changes: 2 additions & 2 deletions tests/hub_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def hf_private_dataset_repo_txt_data(hf_private_dataset_repo_txt_data_):


@pytest.fixture(scope="session")
def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_path):
def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_with_dir_path):
repo_name = f"repo_zipped_txt_data-{int(time.time() * 10e3)}"
create_repo(hf_api, repo_name, token=hf_token, organization=USER, repo_type="dataset", private=True)
repo_id = f"{USER}/{repo_name}"
hf_api.upload_file(
token=hf_token,
path_or_fileobj=str(zip_csv_path),
path_or_fileobj=str(zip_csv_with_dir_path),
path_in_repo="data.zip",
repo_id=repo_id,
repo_type="dataset",
Expand Down
78 changes: 78 additions & 0 deletions tests/test_streaming_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
xrelpath,
xsplit,
xsplitext,
xwalk,
)
from datasets.filesystems import COMPRESSION_FILESYSTEMS
from datasets.utils.file_utils import hf_hub_url

from .utils import require_lz4, require_zstandard, slow

Expand Down Expand Up @@ -302,6 +304,16 @@ def test_xlistdir(input_path, expected_paths, tmp_path, mock_fsspec):
assert output_paths == expected_paths


def test_xlistdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(xlistdir("zip://::" + root_url, use_auth_token=hf_token)) == 1
assert len(xlistdir("zip://main_dir::" + root_url, use_auth_token=hf_token)) == 2
with pytest.raises(FileNotFoundError):
xlistdir("zip://qwertyuiop::" + root_url, use_auth_token=hf_token)
with pytest.raises(NotImplementedError):
xlistdir(root_url, use_auth_token=hf_token)


@pytest.mark.parametrize(
"input_path, isdir",
[
Expand All @@ -319,6 +331,15 @@ def test_xisdir(input_path, isdir, tmp_path, mock_fsspec):
assert xisdir(input_path) == isdir


def test_xisdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert xisdir("zip://::" + root_url, use_auth_token=hf_token) is True
assert xisdir("zip://main_dir::" + root_url, use_auth_token=hf_token) is True
assert xisdir("zip://qwertyuiop::" + root_url, use_auth_token=hf_token) is False
with pytest.raises(NotImplementedError):
xisdir(root_url, use_auth_token=hf_token)


@pytest.mark.parametrize(
"input_path, isfile",
[
Expand All @@ -335,6 +356,12 @@ def test_xisfile(input_path, isfile, tmp_path, mock_fsspec):
assert xisfile(input_path) == isfile


def test_xisfile_private(hf_private_dataset_repo_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
assert xisfile(root_url + "data/text_data.txt", use_auth_token=hf_token) is True
assert xisfile(root_url + "qwertyuiop", use_auth_token=hf_token) is False


@pytest.mark.parametrize(
"input_path, size",
[
Expand All @@ -351,6 +378,13 @@ def test_xgetsize(input_path, size, tmp_path, mock_fsspec):
assert xgetsize(input_path) == size


def test_xgetsize_private(hf_private_dataset_repo_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
assert xgetsize(root_url + "data/text_data.txt", use_auth_token=hf_token) == 39
with pytest.raises(FileNotFoundError):
xgetsize(root_url + "qwertyuiop", use_auth_token=hf_token)


@pytest.mark.parametrize(
"input_path, expected_paths",
[
Expand Down Expand Up @@ -386,6 +420,50 @@ def test_xglob(input_path, expected_paths, tmp_path, mock_fsspec):
assert output_paths == expected_paths


def test_xglob_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(xglob("zip://**::" + root_url, use_auth_token=hf_token)) == 3
assert len(xglob("zip://qwertyuiop/*::" + root_url, use_auth_token=hf_token)) == 0


@pytest.mark.parametrize(
"input_path, expected_outputs",
[
("tmp_path", [("", [], ["file1.txt", "file2.txt", "README.md"])]),
(
"mock://top_level/second_level",
[
("mock://top_level/second_level", ["date=2019-10-01", "date=2019-10-02", "date=2019-10-04"], []),
("mock://top_level/second_level/date=2019-10-01", [], ["a.parquet", "b.parquet"]),
("mock://top_level/second_level/date=2019-10-02", [], ["a.parquet"]),
("mock://top_level/second_level/date=2019-10-04", [], ["a.parquet"]),
],
),
],
)
def test_xwalk(input_path, expected_outputs, tmp_path, mock_fsspec):
if input_path.startswith("tmp_path"):
input_path = input_path.replace("/", os.sep).replace("tmp_path", str(tmp_path))
expected_outputs = sorted(
[
(str(tmp_path / dirpath).rstrip("/"), sorted(dirnames), sorted(filenames))
for dirpath, dirnames, filenames in expected_outputs
]
)
for file in ["file1.txt", "file2.txt", "README.md"]:
(tmp_path / file).touch()
outputs = sorted(xwalk(input_path))
outputs = [(dirpath, sorted(dirnames), sorted(filenames)) for dirpath, dirnames, filenames in outputs]
assert outputs == expected_outputs


def test_xwalk_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
assert len(list(xwalk("zip://::" + root_url, use_auth_token=hf_token))) == 2
assert len(list(xwalk("zip://main_dir::" + root_url, use_auth_token=hf_token))) == 1
assert len(list(xwalk("zip://qwertyuiop::" + root_url, use_auth_token=hf_token))) == 0


@pytest.mark.parametrize(
"input_path, start_path, expected_path",
[
Expand Down