Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
import pyarrow.compute as pc
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi
from multiprocess import Pool
from requests import HTTPError

from . import config
from .arrow_reader import ArrowReader
Expand Down Expand Up @@ -113,7 +112,8 @@
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.deprecation_utils import deprecated
from .utils.file_utils import _retry, estimate_dataset_size
from .utils.file_utils import estimate_dataset_size
from .utils.hub import preupload_lfs_files
from .utils.info_utils import is_small_dataset
from .utils.metadata import MetadataConfigs
from .utils.py_utils import (
Expand Down Expand Up @@ -5203,21 +5203,14 @@ def shards_with_embedded_external_files(shards):
shard.to_parquet(buffer)
uploaded_size += buffer.tell()
shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=buffer)
_retry(
api.preupload_lfs_files,
func_kwargs={
"repo_id": repo_id,
"additions": [shard_addition],
"token": token,
"repo_type": "dataset",
"revision": revision,
"create_pr": create_pr,
},
exceptions=HTTPError,
status_codes=[504],
base_wait_time=2.0,
max_retries=5,
max_wait_time=20.0,
preupload_lfs_files(
api,
repo_id=repo_id,
additions=[shard_addition],
token=token,
repo_type="dataset",
revision=revision,
create_pr=create_pr,
)
additions.append(shard_addition)

Expand Down
1 change: 1 addition & 0 deletions src/datasets/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
FSSPEC_VERSION = version.parse(importlib.metadata.version("fsspec"))
PANDAS_VERSION = version.parse(importlib.metadata.version("pandas"))
PYARROW_VERSION = version.parse(importlib.metadata.version("pyarrow"))
HF_HUB_VERSION = version.parse(importlib.metadata.version("huggingface_hub"))

USE_TF = os.environ.get("USE_TF", "AUTO").upper()
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
Expand Down
8 changes: 4 additions & 4 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,14 +1332,14 @@ def __init__(
increase_load_count(name, resource_type="dataset")

def download_loading_script(self) -> str:
file_path = hf_hub_url(repo_id=self.name, path=self.name.split("/")[-1] + ".py", revision=self.revision)
file_path = hf_hub_url(self.name, self.name.split("/")[-1] + ".py", revision=self.revision)
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = "Downloading builder script"
return cached_path(file_path, download_config=download_config)

def download_dataset_infos_file(self) -> str:
dataset_infos = hf_hub_url(repo_id=self.name, path=config.DATASETDICT_INFOS_FILENAME, revision=self.revision)
dataset_infos = hf_hub_url(self.name, config.DATASETDICT_INFOS_FILENAME, revision=self.revision)
# Download the dataset infos file if available
download_config = self.download_config.copy()
if download_config.download_desc is None:
Expand All @@ -1353,7 +1353,7 @@ def download_dataset_infos_file(self) -> str:
return None

def download_dataset_readme_file(self) -> str:
readme_url = hf_hub_url(repo_id=self.name, path="README.md", revision=self.revision)
readme_url = hf_hub_url(self.name, "README.md", revision=self.revision)
# Download the dataset README file if available
download_config = self.download_config.copy()
if download_config.download_desc is None:
Expand Down Expand Up @@ -1382,7 +1382,7 @@ def get_module(self) -> DatasetModule:
imports = get_imports(local_path)
local_imports = _download_additional_modules(
name=self.name,
base_path=hf_hub_url(repo_id=self.name, path="", revision=self.revision),
base_path=hf_hub_url(self.name, "", revision=self.revision),
imports=imports,
download_config=self.download_config,
)
Expand Down
28 changes: 1 addition & 27 deletions src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from contextlib import closing, contextmanager
from functools import partial
from pathlib import Path
from typing import List, Optional, Type, TypeVar, Union
from typing import Optional, TypeVar, Union
from urllib.parse import urljoin, urlparse

import fsspec
Expand Down Expand Up @@ -279,32 +279,6 @@ def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None):
)


def _retry(
    func,
    func_args: Optional[tuple] = None,
    func_kwargs: Optional[dict] = None,
    exceptions: Type[requests.exceptions.RequestException] = requests.exceptions.RequestException,
    status_codes: Optional[List[int]] = None,
    max_retries: int = 0,
    base_wait_time: float = 0.5,
    max_wait_time: float = 2,
):
    """Call ``func`` and retry it with exponential backoff on selected errors.

    Args:
        func: Callable to invoke.
        func_args: Positional arguments forwarded to ``func``.
        func_kwargs: Keyword arguments forwarded to ``func``.
        exceptions: Exception type(s) that may trigger a retry.
        status_codes: When given, only retry if the caught exception carries a
            response whose HTTP status code is in this list.
        max_retries: Maximum number of retries after the first attempt.
        base_wait_time: Wait time (seconds) before the first retry.
        max_wait_time: Upper bound (seconds) on the backoff wait time.

    Returns:
        Whatever ``func`` returns.

    Raises:
        The caught exception, once retries are exhausted or the error is not retryable.
    """
    func_args = func_args or ()
    func_kwargs = func_kwargs or {}
    retry = 0
    while True:
        try:
            return func(*func_args, **func_kwargs)
        except exceptions as err:
            # `err.response` may be absent or None (e.g. a connection error that
            # never got an HTTP response); a status-code filter cannot match in
            # that case, so re-raise instead of crashing with an AttributeError.
            response = getattr(err, "response", None)
            if retry >= max_retries or (status_codes and (response is None or response.status_code not in status_codes)):
                raise err
            else:
                sleep_time = min(max_wait_time, base_wait_time * 2**retry)  # Exponential backoff
                # Fixed: previously interpolated `{retry/max_retries}`, an integer
                # division (e.g. "0.2"), instead of an "attempt X of Y" counter.
                logger.info(f"{func} failed, retrying in {sleep_time}s... [{retry + 1}/{max_retries}]")
                time.sleep(sleep_time)
                retry += 1


def _request_with_retry(
method: str,
url: str,
Expand Down
51 changes: 43 additions & 8 deletions src/datasets/utils/hub.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,47 @@
from typing import Optional
from urllib.parse import quote
import time
from functools import partial

import huggingface_hub as hfh
from huggingface_hub import HfApi, hf_hub_url
from packaging import version
from requests import HTTPError

from .. import config
from . import logging

def hf_hub_url(repo_id: str, path: str, revision: Optional[str] = None) -> str:
    """Return the resolve-URL of a file hosted in a dataset repo on the Hugging Face Hub.

    Args:
        repo_id: Dataset repository id, e.g. ``"username/dataset_name"``.
        path: Path of the file inside the repo.
        revision: Git revision (branch, tag or commit sha); defaults to the repo's default branch.

    Returns:
        The full URL to the file on the Hub.
    """
    # huggingface_hub < 0.11.0 did not url-encode the file path itself,
    # so quote it here on its behalf.
    legacy_hfh = version.parse(hfh.__version__).release < version.parse("0.11.0").release
    resolved_path = quote(path) if legacy_hfh else path
    return hfh.hf_hub_url(repo_id, resolved_path, repo_type="dataset", revision=revision)

logger = logging.get_logger(__name__)

# Retry `preupload_lfs_files` in `huggingface_hub<0.20.0` on the transient
# "500 (Internal Server Error)", "503 (Service Unavailable)" and
# "504 (Gateway Timeout)" HTTP errors
if config.HF_HUB_VERSION < version.parse("0.20.0"):

    def preupload_lfs_files(hf_api: HfApi, **kwargs):
        """Call ``hf_api.preupload_lfs_files``, retrying transient server errors with exponential backoff.

        Args:
            hf_api: The ``HfApi`` client to call.
            **kwargs: Forwarded verbatim to ``HfApi.preupload_lfs_files``.

        Raises:
            The last caught error, once retries are exhausted or the error is not retryable.
        """
        max_retries = 5
        base_wait_time = 1
        max_wait_time = 8
        status_codes = [500, 503, 504]
        retry = 0
        while True:
            try:
                hf_api.preupload_lfs_files(**kwargs)
            except (RuntimeError, HTTPError) as err:
                # huggingface_hub may wrap the HTTPError in a RuntimeError: unwrap it.
                if isinstance(err, RuntimeError) and isinstance(err.__cause__, HTTPError):
                    err = err.__cause__
                # `err` can still be a bare RuntimeError, or an HTTPError whose
                # `response` is None (no HTTP response received) — neither is
                # retryable, so re-raise instead of crashing on `.status_code`.
                response = getattr(err, "response", None)
                if retry >= max_retries or response is None or response.status_code not in status_codes:
                    raise err
                sleep_time = min(max_wait_time, base_wait_time * 2**retry)  # Exponential backoff
                # Fixed: previously interpolated `{retry/max_retries}`, an integer
                # division (e.g. "0.2"), instead of an "attempt X of Y" counter.
                logger.info(
                    f"{hf_api.preupload_lfs_files} failed with a transient error, retrying in {sleep_time}s... [{retry + 1}/{max_retries}]"
                )
                time.sleep(sleep_time)
                retry += 1
            else:
                break
else:

    def preupload_lfs_files(hf_api: HfApi, **kwargs):
        """Call ``hf_api.preupload_lfs_files`` directly (no extra retry needed for huggingface_hub>=0.20.0)."""
        hf_api.preupload_lfs_files(**kwargs)


# backward compatibility
hf_hub_url = partial(hf_hub_url, repo_type="dataset")
8 changes: 4 additions & 4 deletions tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


@pytest.mark.parametrize("repo_id", ["canonical_dataset_name", "org-name/dataset-name"])
@pytest.mark.parametrize("path", ["filename.csv", "filename with blanks.csv"])
@pytest.mark.parametrize("filename", ["filename.csv", "filename with blanks.csv"])
@pytest.mark.parametrize("revision", [None, "v2"])
def test_hf_hub_url(repo_id, path, revision):
url = hf_hub_url(repo_id=repo_id, path=path, revision=revision)
assert url == f"https://huggingface.co/datasets/{repo_id}/resolve/{revision or 'main'}/{quote(path)}"
def test_hf_hub_url(repo_id, filename, revision):
url = hf_hub_url(repo_id=repo_id, filename=filename, revision=revision)
assert url == f"https://huggingface.co/datasets/{repo_id}/resolve/{revision or 'main'}/{quote(filename)}"