diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 10577836662..ac83f13447b 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -113,7 +113,7 @@ from .utils import tqdm as hf_tqdm from .utils.deprecation_utils import deprecated from .utils.file_utils import estimate_dataset_size -from .utils.hub import preupload_lfs_files +from .utils.hub import list_files_info, preupload_lfs_files from .utils.info_utils import is_small_dataset from .utils.metadata import MetadataConfigs from .utils.py_utils import ( @@ -5379,7 +5379,7 @@ def push_to_hub( deletions, deleted_size = [], 0 repo_splits = [] # use a list to keep the order of the splits repo_files_to_add = [addition.path_in_repo for addition in additions] - for repo_file in api.list_files_info(repo_id, revision=revision, repo_type="dataset", token=token): + for repo_file in list_files_info(api, repo_id=repo_id, revision=revision, repo_type="dataset", token=token): if repo_file.rfilename == "README.md": repo_with_dataset_card = True elif repo_file.rfilename == config.DATASETDICT_INFOS_FILENAME: diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index a94c5410a17..85f401958eb 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -33,6 +33,7 @@ from .utils import logging from .utils.deprecation_utils import deprecated from .utils.doc_utils import is_documented_by +from .utils.hub import list_files_info from .utils.metadata import MetadataConfigs from .utils.py_utils import asdict, glob_pattern_to_regex, string_to_dict from .utils.typing import PathLike @@ -1722,7 +1723,7 @@ def push_to_hub( repo_splits = [] # use a list to keep the order of the splits deletions = [] repo_files_to_add = [addition.path_in_repo for addition in additions] - for repo_file in api.list_files_info(repo_id, revision=revision, repo_type="dataset", token=token): + for repo_file in list_files_info(api, repo_id=repo_id, revision=revision, repo_type="dataset", token=token): if repo_file.rfilename == "README.md": repo_with_dataset_card = True elif repo_file.rfilename == config.DATASETDICT_INFOS_FILENAME: diff --git a/src/datasets/utils/hub.py b/src/datasets/utils/hub.py index 52b48cec0fe..90b9f6ea634 100644 --- a/src/datasets/utils/hub.py +++ b/src/datasets/utils/hub.py @@ -2,6 +2,7 @@ from functools import partial from huggingface_hub import HfApi, hf_hub_url +from huggingface_hub.hf_api import RepoFile from packaging import version from requests import ConnectionError, HTTPError @@ -45,5 +46,19 @@ def preupload_lfs_files(hf_api: HfApi, **kwargs): hf_api.preupload_lfs_files(**kwargs) +# `list_files_info` is deprecated in favor of `list_repo_tree` in `huggingface_hub>=0.20.0` +if config.HF_HUB_VERSION < version.parse("0.20.0"): + + def list_files_info(hf_api: HfApi, **kwargs): + yield from hf_api.list_files_info(**kwargs) +else: + + def list_files_info(hf_api: HfApi, **kwargs): + kwargs = {**kwargs, "recursive": True} + for repo_path in hf_api.list_repo_tree(**kwargs): + if isinstance(repo_path, RepoFile): + yield repo_path + + # bakckward compatibility hf_hub_url = partial(hf_hub_url, repo_type="dataset") diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 9bd8a162da5..76a4ba358f3 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -1,12 +1,12 @@ +import os import time import uuid from contextlib import contextmanager -from pathlib import Path from typing import Optional import pytest import requests -from huggingface_hub.hf_api import HfApi, HfFolder, RepositoryNotFoundError +from huggingface_hub.hf_api import HfApi, RepositoryNotFoundError CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" @@ -16,7 +16,6 @@ CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" CI_HUB_DATASETS_URL = CI_HUB_ENDPOINT + "/datasets/{repo_id}/resolve/{revision}/{path}" CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE = CI_HUB_ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" -CI_HUB_TOKEN_PATH = Path("~/.huggingface/hub_ci_token").expanduser() @pytest.fixture @@ -33,15 +32,12 @@ def ci_hub_config(monkeypatch): @pytest.fixture -def ci_hub_token_path(monkeypatch): - monkeypatch.setattr("huggingface_hub.hf_api.HfFolder.path_token", CI_HUB_TOKEN_PATH) - - -@pytest.fixture -def set_ci_hub_access_token(ci_hub_config, ci_hub_token_path): - HfFolder.save_token(CI_HUB_USER_TOKEN) +def set_ci_hub_access_token(ci_hub_config): + old_environ = dict(os.environ) + os.environ["HF_TOKEN"] = CI_HUB_USER_TOKEN yield - HfFolder.delete_token() + os.environ.clear() + os.environ.update(old_environ) @pytest.fixture(scope="session")