Support plain and chained HTTP(S) URLs in the streaming path helpers.

- xisfile / xgetsize: plain http(s) URLs and chained http(s) hops are both
  prepared through _prepare_http_url_kwargs before the filesystem is resolved.
- xisdir / xlistdir / xglob / xwalk: plain http(s) URLs raise
  NotImplementedError; chained URLs operate on the path inside the first hop.
- xisdir returns True for an empty inner path; xlistdir raises
  FileNotFoundError and xwalk yields nothing when a non-empty inner path is
  not a directory.
- xwalk is a generator, so its early exit is a bare `return`.
- Tests: private-repo tests for each helper plus a parametrized test_xwalk;
  the zipped private-repo fixture now uploads zip_csv_with_dir_path.

diff --git a/src/datasets/download/streaming_download_manager.py b/src/datasets/download/streaming_download_manager.py
index f0aa67dae73..697f5122b44 100644
--- a/src/datasets/download/streaming_download_manager.py
+++ b/src/datasets/download/streaming_download_manager.py
@@ -234,10 +234,14 @@ def xisfile(path, use_auth_token: Optional[Union[str, bool]] = None) -> bool:
     if is_local_path(main_hop):
         return os.path.isfile(path)
     else:
-        if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
-            storage_options = {
-                "https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
-            }
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, use_auth_token=use_auth_token)
+            storage_options = http_kwargs
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+            url = rest_hops[0]
+            url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
+            storage_options = {"https": http_kwargs}
+            path = "::".join([main_hop, url, *rest_hops[1:]])
         else:
             storage_options = None
         fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
@@ -257,10 +261,14 @@ def xgetsize(path, use_auth_token: Optional[Union[str, bool]] = None) -> int:
     if is_local_path(main_hop):
         return os.path.getsize(path)
     else:
-        if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
-            storage_options = {
-                "https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
-            }
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            main_hop, http_kwargs = _prepare_http_url_kwargs(main_hop, use_auth_token=use_auth_token)
+            storage_options = http_kwargs
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+            url = rest_hops[0]
+            url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
+            storage_options = {"https": http_kwargs}
+            path = "::".join([main_hop, url, *rest_hops[1:]])
         else:
             storage_options = None
         fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
@@ -285,14 +293,20 @@ def xisdir(path, use_auth_token: Optional[Union[str, bool]] = None) -> bool:
     if is_local_path(main_hop):
         return os.path.isdir(path)
     else:
-        if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
-            storage_options = {
-                "https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
-            }
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            raise NotImplementedError("os.path.isdir is not extended to support URLs in streaming mode")
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+            url = rest_hops[0]
+            url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
+            storage_options = {"https": http_kwargs}
+            path = "::".join([main_hop, url, *rest_hops[1:]])
         else:
             storage_options = None
         fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
-        return fs.isdir(main_hop)
+        inner_path = main_hop.split("://")[1]
+        if not inner_path.strip("/"):
+            return True
+        return fs.isdir(inner_path)
 
 
 def xrelpath(path, start=None):
@@ -463,14 +477,20 @@ def xlistdir(path: str, use_auth_token: Optional[Union[str, bool]] = None) -> Li
         return os.listdir(path)
     else:
         # globbing inside a zip in a private repo requires authentication
-        if rest_hops and fsspec.get_fs_token_paths(rest_hops[0])[0].protocol == "https":
-            storage_options = {
-                "https": {"headers": get_authentication_headers_for_url(rest_hops[0], use_auth_token=use_auth_token)}
-            }
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            raise NotImplementedError("os.listdir is not extended to support URLs in streaming mode")
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+            url = rest_hops[0]
+            url, http_kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
+            storage_options = {"https": http_kwargs}
+            path = "::".join([main_hop, url, *rest_hops[1:]])
         else:
             storage_options = None
         fs, *_ = fsspec.get_fs_token_paths(path, storage_options=storage_options)
-        objects = fs.listdir(main_hop.split("://")[1])
+        inner_path = main_hop.split("://")[1]
+        if inner_path.strip("/") and not fs.isdir(inner_path):
+            raise FileNotFoundError(f"Directory doesn't exist: {path}")
+        objects = fs.listdir(inner_path)
         return [os.path.basename(obj["name"]) for obj in objects]
 
 
@@ -490,7 +510,9 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool]
         return glob.glob(main_hop, recursive=recursive)
     else:
         # globbing inside a zip in a private repo requires authentication
-        if rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            raise NotImplementedError("glob.glob is not extended to support URLs in streaming mode")
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
             url = rest_hops[0]
             url, kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
             storage_options = {"https": kwargs}
@@ -502,7 +524,8 @@ def xglob(urlpath, *, recursive=False, use_auth_token: Optional[Union[str, bool]
         # so to be able to glob patterns like "[0-9]", we have to call `fs.glob`.
         # - Also "*" in get_fs_token_paths() only matches files: we have to call `fs.glob` to match directories.
         # - If there is "**" in the pattern, `fs.glob` must be called anyway.
-        globbed_paths = fs.glob(main_hop)
+        inner_path = main_hop.split("://")[1]
+        globbed_paths = fs.glob(inner_path)
         return ["::".join([f"{fs.protocol}://{globbed_path}"] + rest_hops) for globbed_path in globbed_paths]
 
 
@@ -522,7 +545,9 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None):
         yield from os.walk(main_hop)
     else:
         # walking inside a zip in a private repo requires authentication
-        if rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
+        if not rest_hops and (main_hop.startswith("http://") or main_hop.startswith("https://")):
+            raise NotImplementedError("os.walk is not extended to support URLs in streaming mode")
+        elif rest_hops and (rest_hops[0].startswith("http://") or rest_hops[0].startswith("https://")):
             url = rest_hops[0]
             url, kwargs = _prepare_http_url_kwargs(url, use_auth_token=use_auth_token)
             storage_options = {"https": kwargs}
@@ -530,7 +555,10 @@ def xwalk(urlpath, use_auth_token: Optional[Union[str, bool]] = None):
         else:
             storage_options = None
         fs, *_ = fsspec.get_fs_token_paths(urlpath, storage_options=storage_options)
-        for dirpath, dirnames, filenames in fs.walk(main_hop):
+        inner_path = main_hop.split("://")[1]
+        if inner_path.strip("/") and not fs.isdir(inner_path):
+            return
+        for dirpath, dirnames, filenames in fs.walk(inner_path):
             yield "::".join([f"{fs.protocol}://{dirpath}"] + rest_hops), dirnames, filenames
 
 
diff --git a/tests/hub_fixtures.py b/tests/hub_fixtures.py
index 6bbc4f27e18..d710a410a44 100644
--- a/tests/hub_fixtures.py
+++ b/tests/hub_fixtures.py
@@ -60,13 +60,13 @@ def hf_private_dataset_repo_txt_data(hf_private_dataset_repo_txt_data_):
 
 
 @pytest.fixture(scope="session")
-def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_path):
+def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_with_dir_path):
     repo_name = f"repo_zipped_txt_data-{int(time.time() * 10e3)}"
     create_repo(hf_api, repo_name, token=hf_token, organization=USER, repo_type="dataset", private=True)
     repo_id = f"{USER}/{repo_name}"
     hf_api.upload_file(
         token=hf_token,
-        path_or_fileobj=str(zip_csv_path),
+        path_or_fileobj=str(zip_csv_with_dir_path),
         path_in_repo="data.zip",
         repo_id=repo_id,
         repo_type="dataset",
diff --git a/tests/test_streaming_download_manager.py b/tests/test_streaming_download_manager.py
index f57f5d47e3d..5da2298323a 100644
--- a/tests/test_streaming_download_manager.py
+++ b/tests/test_streaming_download_manager.py
@@ -23,8 +23,10 @@
     xrelpath,
     xsplit,
     xsplitext,
+    xwalk,
 )
 from datasets.filesystems import COMPRESSION_FILESYSTEMS
+from datasets.utils.file_utils import hf_hub_url
 
 from .utils import require_lz4, require_zstandard, slow
 
@@ -302,6 +304,16 @@ def test_xlistdir(input_path, expected_paths, tmp_path, mock_fsspec):
     assert output_paths == expected_paths
 
 
+def test_xlistdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
+    assert len(xlistdir("zip://::" + root_url, use_auth_token=hf_token)) == 1
+    assert len(xlistdir("zip://main_dir::" + root_url, use_auth_token=hf_token)) == 2
+    with pytest.raises(FileNotFoundError):
+        xlistdir("zip://qwertyuiop::" + root_url, use_auth_token=hf_token)
+    with pytest.raises(NotImplementedError):
+        xlistdir(root_url, use_auth_token=hf_token)
+
+
 @pytest.mark.parametrize(
     "input_path, isdir",
     [
@@ -319,6 +331,15 @@ def test_xisdir(input_path, isdir, tmp_path, mock_fsspec):
     assert xisdir(input_path) == isdir
 
 
+def test_xisdir_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
+    assert xisdir("zip://::" + root_url, use_auth_token=hf_token) is True
+    assert xisdir("zip://main_dir::" + root_url, use_auth_token=hf_token) is True
+    assert xisdir("zip://qwertyuiop::" + root_url, use_auth_token=hf_token) is False
+    with pytest.raises(NotImplementedError):
+        xisdir(root_url, use_auth_token=hf_token)
+
+
 @pytest.mark.parametrize(
     "input_path, isfile",
     [
@@ -335,6 +356,12 @@ def test_xisfile(input_path, isfile, tmp_path, mock_fsspec):
     assert xisfile(input_path) == isfile
 
 
+def test_xisfile_private(hf_private_dataset_repo_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
+    assert xisfile(root_url + "data/text_data.txt", use_auth_token=hf_token) is True
+    assert xisfile(root_url + "qwertyuiop", use_auth_token=hf_token) is False
+
+
 @pytest.mark.parametrize(
     "input_path, size",
     [
@@ -351,6 +378,13 @@ def test_xgetsize(input_path, size, tmp_path, mock_fsspec):
     assert xgetsize(input_path) == size
 
 
+def test_xgetsize_private(hf_private_dataset_repo_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_txt_data, "")
+    assert xgetsize(root_url + "data/text_data.txt", use_auth_token=hf_token) == 39
+    with pytest.raises(FileNotFoundError):
+        xgetsize(root_url + "qwertyuiop", use_auth_token=hf_token)
+
+
 @pytest.mark.parametrize(
     "input_path, expected_paths",
     [
@@ -386,6 +420,50 @@ def test_xglob(input_path, expected_paths, tmp_path, mock_fsspec):
     assert output_paths == expected_paths
 
 
+def test_xglob_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
+    assert len(xglob("zip://**::" + root_url, use_auth_token=hf_token)) == 3
+    assert len(xglob("zip://qwertyuiop/*::" + root_url, use_auth_token=hf_token)) == 0
+
+
+@pytest.mark.parametrize(
+    "input_path, expected_outputs",
+    [
+        ("tmp_path", [("", [], ["file1.txt", "file2.txt", "README.md"])]),
+        (
+            "mock://top_level/second_level",
+            [
+                ("mock://top_level/second_level", ["date=2019-10-01", "date=2019-10-02", "date=2019-10-04"], []),
+                ("mock://top_level/second_level/date=2019-10-01", [], ["a.parquet", "b.parquet"]),
+                ("mock://top_level/second_level/date=2019-10-02", [], ["a.parquet"]),
+                ("mock://top_level/second_level/date=2019-10-04", [], ["a.parquet"]),
+            ],
+        ),
+    ],
+)
+def test_xwalk(input_path, expected_outputs, tmp_path, mock_fsspec):
+    if input_path.startswith("tmp_path"):
+        input_path = input_path.replace("/", os.sep).replace("tmp_path", str(tmp_path))
+        expected_outputs = sorted(
+            [
+                (str(tmp_path / dirpath).rstrip("/"), sorted(dirnames), sorted(filenames))
+                for dirpath, dirnames, filenames in expected_outputs
+            ]
+        )
+        for file in ["file1.txt", "file2.txt", "README.md"]:
+            (tmp_path / file).touch()
+    outputs = sorted(xwalk(input_path))
+    outputs = [(dirpath, sorted(dirnames), sorted(filenames)) for dirpath, dirnames, filenames in outputs]
+    assert outputs == expected_outputs
+
+
+def test_xwalk_private(hf_private_dataset_repo_zipped_txt_data, hf_token):
+    root_url = hf_hub_url(hf_private_dataset_repo_zipped_txt_data, "data.zip")
+    assert len(list(xwalk("zip://::" + root_url, use_auth_token=hf_token))) == 2
+    assert len(list(xwalk("zip://main_dir::" + root_url, use_auth_token=hf_token))) == 1
+    assert len(list(xwalk("zip://qwertyuiop::" + root_url, use_auth_token=hf_token))) == 0
+
+
 @pytest.mark.parametrize(
     "input_path, start_path, expected_path",
     [