diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 472a09c9a6e..5ee1afae51e 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -61,7 +61,6 @@ import pyarrow.compute as pc from huggingface_hub import CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi from multiprocess import Pool -from requests import HTTPError from . import config from .arrow_reader import ArrowReader @@ -113,7 +112,8 @@ from .utils import logging from .utils import tqdm as hf_tqdm from .utils.deprecation_utils import deprecated -from .utils.file_utils import _retry, estimate_dataset_size +from .utils.file_utils import estimate_dataset_size +from .utils.hub import preupload_lfs_files from .utils.info_utils import is_small_dataset from .utils.metadata import MetadataConfigs from .utils.py_utils import ( @@ -5203,21 +5203,14 @@ def shards_with_embedded_external_files(shards): shard.to_parquet(buffer) uploaded_size += buffer.tell() shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=buffer) - _retry( - api.preupload_lfs_files, - func_kwargs={ - "repo_id": repo_id, - "additions": [shard_addition], - "token": token, - "repo_type": "dataset", - "revision": revision, - "create_pr": create_pr, - }, - exceptions=HTTPError, - status_codes=[504], - base_wait_time=2.0, - max_retries=5, - max_wait_time=20.0, + preupload_lfs_files( + api, + repo_id=repo_id, + additions=[shard_addition], + token=token, + repo_type="dataset", + revision=revision, + create_pr=create_pr, ) additions.append(shard_addition) diff --git a/src/datasets/config.py b/src/datasets/config.py index 851eb7d899f..a8a90b9dbc9 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -41,6 +41,7 @@ FSSPEC_VERSION = version.parse(importlib.metadata.version("fsspec")) PANDAS_VERSION = version.parse(importlib.metadata.version("pandas")) PYARROW_VERSION = version.parse(importlib.metadata.version("pyarrow")) 
+HF_HUB_VERSION = version.parse(importlib.metadata.version("huggingface_hub")) USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() diff --git a/src/datasets/load.py b/src/datasets/load.py index 6893e94242a..7bd78dc46cb 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -1332,14 +1332,14 @@ def __init__( increase_load_count(name, resource_type="dataset") def download_loading_script(self) -> str: - file_path = hf_hub_url(repo_id=self.name, path=self.name.split("/")[-1] + ".py", revision=self.revision) + file_path = hf_hub_url(self.name, self.name.split("/")[-1] + ".py", revision=self.revision) download_config = self.download_config.copy() if download_config.download_desc is None: download_config.download_desc = "Downloading builder script" return cached_path(file_path, download_config=download_config) def download_dataset_infos_file(self) -> str: - dataset_infos = hf_hub_url(repo_id=self.name, path=config.DATASETDICT_INFOS_FILENAME, revision=self.revision) + dataset_infos = hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision) # Download the dataset infos file if available download_config = self.download_config.copy() if download_config.download_desc is None: @@ -1353,7 +1353,7 @@ def download_dataset_infos_file(self) -> str: return None def download_dataset_readme_file(self) -> str: - readme_url = hf_hub_url(repo_id=self.name, path="README.md", revision=self.revision) + readme_url = hf_hub_url(self.name, "README.md", revision=self.revision) # Download the dataset infos file if available download_config = self.download_config.copy() if download_config.download_desc is None: @@ -1382,7 +1382,7 @@ def get_module(self) -> DatasetModule: imports = get_imports(local_path) local_imports = _download_additional_modules( name=self.name, - base_path=hf_hub_url(repo_id=self.name, path="", revision=self.revision), + base_path=hf_hub_url(self.name, "", revision=self.revision), 
imports=imports, download_config=self.download_config, ) diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 03cd91d0bbe..e96a0b9db02 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -18,7 +18,7 @@ from contextlib import closing, contextmanager from functools import partial from pathlib import Path -from typing import List, Optional, Type, TypeVar, Union +from typing import Optional, TypeVar, Union from urllib.parse import urljoin, urlparse import fsspec @@ -279,32 +279,6 @@ def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None): ) -def _retry( - func, - func_args: Optional[tuple] = None, - func_kwargs: Optional[dict] = None, - exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException, - status_codes: Optional[List[int]] = None, - max_retries: int = 0, - base_wait_time: float = 0.5, - max_wait_time: float = 2, -): - func_args = func_args or () - func_kwargs = func_kwargs or {} - retry = 0 - while True: - try: - return func(*func_args, **func_kwargs) - except exceptions as err: - if retry >= max_retries or (status_codes and err.response.status_code not in status_codes): - raise err - else: - sleep_time = min(max_wait_time, base_wait_time * 2**retry) # Exponential backoff - logger.info(f"{func} timed out, retrying in {sleep_time}s... [{retry/max_retries}]") - time.sleep(sleep_time) - retry += 1 - - def _request_with_retry( method: str, url: str, diff --git a/src/datasets/utils/hub.py b/src/datasets/utils/hub.py index 46402accfbb..0925d075ee0 100644 --- a/src/datasets/utils/hub.py +++ b/src/datasets/utils/hub.py @@ -1,12 +1,47 @@ -from typing import Optional -from urllib.parse import quote +import time +from functools import partial -import huggingface_hub as hfh +from huggingface_hub import HfApi, hf_hub_url from packaging import version +from requests import HTTPError +from .. import config +from . 
import logging -def hf_hub_url(repo_id: str, path: str, revision: Optional[str] = None) -> str: - if version.parse(hfh.__version__).release < version.parse("0.11.0").release: - # old versions of hfh don't url-encode the file path - path = quote(path) - return hfh.hf_hub_url(repo_id, path, repo_type="dataset", revision=revision) + +logger = logging.get_logger(__name__) + +# Retry `preupload_lfs_files` in `huggingface_hub<0.20.0` on the "500 (Internal Server Error)" and "503 (Service Unavailable)" HTTP errors +if config.HF_HUB_VERSION < version.parse("0.20.0"): + + def preupload_lfs_files(hf_api: HfApi, **kwargs): + max_retries = 5 + base_wait_time = 1 + max_wait_time = 8 + status_codes = [500, 503] + retry = 0 + while True: + try: + hf_api.preupload_lfs_files(**kwargs) + except (RuntimeError, HTTPError) as err: + if isinstance(err, RuntimeError) and isinstance(err.__cause__, HTTPError): + err = err.__cause__ + if retry >= max_retries or err.response.status_code not in status_codes: + raise err + else: + sleep_time = min(max_wait_time, base_wait_time * 2**retry) # Exponential backoff + logger.info( + f"{hf_api.preupload_lfs_files} timed out, retrying in {sleep_time}s... 
[{retry}/{max_retries}]" + ) + time.sleep(sleep_time) + retry += 1 + else: + break +else: + + def preupload_lfs_files(hf_api: HfApi, **kwargs): + hf_api.preupload_lfs_files(**kwargs) + + +# backward compatibility +hf_hub_url = partial(hf_hub_url, repo_type="dataset") diff --git a/tests/test_hub.py b/tests/test_hub.py index d7cdbd2843c..e940d7b8b29 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -6,8 +6,8 @@ @pytest.mark.parametrize("repo_id", ["canonical_dataset_name", "org-name/dataset-name"]) -@pytest.mark.parametrize("path", ["filename.csv", "filename with blanks.csv"]) +@pytest.mark.parametrize("filename", ["filename.csv", "filename with blanks.csv"]) @pytest.mark.parametrize("revision", [None, "v2"]) -def test_hf_hub_url(repo_id, path, revision): - url = hf_hub_url(repo_id=repo_id, path=path, revision=revision) - assert url == f"https://huggingface.co/datasets/{repo_id}/resolve/{revision or 'main'}/{quote(path)}" +def test_hf_hub_url(repo_id, filename, revision): - url = hf_hub_url(repo_id=repo_id, filename=filename, revision=revision) + assert url == f"https://huggingface.co/datasets/{repo_id}/resolve/{revision or 'main'}/{quote(filename)}"