diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7fe210d99fe..b66b12cecd0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -72,10 +72,16 @@ jobs: run: uv pip install --system "datasets[tests] @ ." - name: Install dependencies (latest versions) if: ${{ matrix.deps_versions == 'deps-latest' }} - run: uv pip install --system --upgrade pyarrow huggingface-hub "dill<0.3.9" + run: | + uv pip install --system --upgrade pyarrow huggingface-hub "dill<0.3.9" + # TODO: remove once transformers v5 / huggingface_hub v1 are released officially + uv pip uninstall --system transformers huggingface_hub + uv pip install --system --prerelease=allow git+https://github.com/huggingface/transformers.git - name: Install dependencies (minimum versions) if: ${{ matrix.deps_versions != 'deps-latest' }} - run: uv pip install --system pyarrow==21.0.0 huggingface-hub==0.24.7 transformers dill==0.3.1.1 + run: uv pip install --system pyarrow==21.0.0 huggingface-hub==0.25.0 transformers dill==0.3.1.1 + - name: Print dependencies + run: uv pip list - name: Test with pytest run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ @@ -119,6 +125,8 @@ jobs: run: pip install --upgrade uv - name: Install dependencies run: uv pip install --system "datasets[tests] @ ." + - name: Print dependencies + run: uv pip list - name: Test with pytest run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ @@ -161,7 +169,14 @@ jobs: - name: Install uv run: pip install --upgrade uv - name: Install dependencies - run: uv pip install --system "datasets[tests_numpy2] @ ." + run: | + uv pip install --system "datasets[tests_numpy2] @ ." 
+ # TODO: remove once transformers v5 / huggingface_hub v1 are released officially + uv pip uninstall --system transformers huggingface_hub + uv pip install --system --prerelease=allow git+https://github.com/huggingface/transformers.git + - name: Print dependencies + run: pip list + - name: Test with pytest run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ diff --git a/setup.py b/setup.py index dbf6796698b..0dee50b1f42 100644 --- a/setup.py +++ b/setup.py @@ -118,6 +118,7 @@ "pandas", # for downloading datasets over HTTPS "requests>=2.32.2", + "httpx<1.0.0", # progress bars in downloads and data operations "tqdm>=4.66.3", # for fast hashing @@ -128,7 +129,7 @@ # minimum 2023.1.0 to support protocol=kwargs in fsspec's `open`, `get_fs_token_paths`, etc.: see https://github.com/fsspec/filesystem_spec/pull/1143 "fsspec[http]>=2023.1.0,<=2025.9.0", # To get datasets from the Datasets Hub on huggingface.co - "huggingface-hub>=0.24.0", + "huggingface-hub>=0.25.0,<2.0", # Utilities from PyPA to e.g., compare versions "packaging", # To parse YAML metadata from dataset cards diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 695bf310562..6063a31f1e3 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -68,9 +68,9 @@ DatasetCardData, HfApi, ) -from huggingface_hub.hf_api import HfHubHTTPError, RepoFile, RepositoryNotFoundError +from huggingface_hub.hf_api import RepoFile +from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError from multiprocess import Pool -from requests import HTTPError from tqdm.contrib.concurrent import thread_map from . 
import config @@ -5990,7 +5990,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code == 409 ): # 409 is Conflict (another commit is in progress) @@ -6040,7 +6040,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code in (412, 409) ): # 412 is Precondition failed (parent_commit isn't satisfied) diff --git a/src/datasets/data_files.py b/src/datasets/data_files.py index 087e037a186..9fefd4a4c69 100644 --- a/src/datasets/data_files.py +++ b/src/datasets/data_files.py @@ -352,7 +352,7 @@ def resolve_pattern( protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0] protocol_prefix = protocol + "://" if protocol != "file" else "" glob_kwargs = {} - if protocol == "hf" and config.HF_HUB_VERSION >= version.parse("0.20.0"): + if protocol == "hf": # 10 times faster glob with detail=True (ignores costly info like lastCommit) glob_kwargs["expand_info"] = False matched_paths = [ diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 733b96d0069..63a93429c45 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -26,7 +26,6 @@ ) from huggingface_hub.hf_api import RepoFile from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError -from requests import HTTPError from . 
import config from .arrow_dataset import ( @@ -1917,7 +1916,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code == 409 ): # 409 is Conflict (another commit is in progress) @@ -1967,7 +1966,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code in (412, 409) ): # 412 is Precondition failed (parent_commit isn't satisfied) @@ -2786,7 +2785,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code == 409 ): # 409 is Conflict (another commit is in progress) @@ -2836,7 +2835,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code in (412, 409) ): # 412 is Precondition failed (parent_commit isn't satisfied) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d5e8c6e0a91..2578309bd78 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -30,7 +30,6 @@ from huggingface_hub.hf_api import RepoFile from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError from multiprocess import Pool -from requests import HTTPError from . 
import config from .arrow_dataset import PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED, Dataset, DatasetInfoMixin @@ -4332,7 +4331,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code == 409 ): # 409 is Conflict (another commit is in progress) @@ -4382,7 +4381,7 @@ def get_deletions_and_dataset_card() -> tuple[str, list[CommitOperationDelete], except HfHubHTTPError as err: if ( err.__context__ - and isinstance(err.__context__, HTTPError) + and isinstance(err.__context__, HfHubHTTPError) and err.__context__.response.status_code in (412, 409) ): # 412 is Precondition failed (parent_commit isn't satisfied) diff --git a/src/datasets/load.py b/src/datasets/load.py index bc2b0e679b6..ae3b9825970 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -28,6 +28,7 @@ from typing import Any, Optional, Union import fsspec +import httpx import requests import yaml from fsspec.core import url_to_fs @@ -948,6 +949,8 @@ def dataset_module_factory( OfflineModeIsEnabled, requests.exceptions.Timeout, requests.exceptions.ConnectionError, + httpx.ConnectError, + httpx.TimeoutException, ), ): raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e @@ -963,6 +966,8 @@ def dataset_module_factory( OfflineModeIsEnabled, requests.exceptions.Timeout, requests.exceptions.ConnectionError, + httpx.ConnectError, + httpx.TimeoutException, ) as e: raise ConnectionError(f"Couldn't reach '{path}' on the Hub ({e.__class__.__name__})") from e except GatedRepoError as e: diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index 81be4f295c4..7a07f8cd267 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -27,12 +27,13 @@ from xml.etree import ElementTree as ET import 
fsspec +import httpx import huggingface_hub import huggingface_hub.errors import requests from fsspec.core import strip_protocol, url_to_fs from fsspec.utils import can_be_local -from huggingface_hub.utils import EntryNotFoundError, get_session, insecure_hashlib +from huggingface_hub.utils import get_session, insecure_hashlib from packaging import version from .. import __version__, config @@ -140,7 +141,7 @@ def cached_path( ConnectionError: in case of unreachable url and no cache on disk ValueError: if it couldn't parse the url or filename correctly - requests.exceptions.ConnectionError: in case of internet connection issue + httpx.NetworkError or requests.exceptions.ConnectionError: in case of internet connection issue """ if download_config is None: download_config = DownloadConfig(**download_kwargs) @@ -246,7 +247,7 @@ def cached_path( def get_datasets_user_agent(user_agent: Optional[Union[str, dict]] = None) -> str: ua = f"datasets/{__version__}" ua += f"; python/{config.PY_VERSION}" - ua += f"; huggingface_hub/{huggingface_hub.__version__}" + ua += f"; hf_hub/{huggingface_hub.__version__}" ua += f"; pyarrow/{config.PYARROW_VERSION}" if config.TORCH_AVAILABLE: ua += f"; torch/{config.TORCH_VERSION}" @@ -753,7 +754,7 @@ def xgetsize(path, download_config: Optional[DownloadConfig] = None) -> int: fs, *_ = fs, *_ = url_to_fs(path, **storage_options) try: size = fs.size(main_hop) - except EntryNotFoundError: + except huggingface_hub.utils.EntryNotFoundError: raise FileNotFoundError(f"No such file: {path}") if size is None: # use xopen instead of fs.open to make data fetching more robust @@ -817,6 +818,7 @@ def read_with_retries(*args, **kwargs): asyncio.TimeoutError, requests.exceptions.ConnectionError, requests.exceptions.Timeout, + httpx.RequestError, ) as err: disconnect_err = err logger.warning( @@ -897,9 +899,6 @@ def _prepare_single_hop_path_and_storage_options( "endpoint": config.HF_ENDPOINT, **storage_options, } - # streaming with block_size=0 is only 
implemented in 0.21 (see https://github.com/huggingface/huggingface_hub/pull/1967) - if config.HF_HUB_VERSION < version.parse("0.21.0"): - storage_options["block_size"] = "default" if storage_options: storage_options = {protocol: storage_options} return urlpath, storage_options diff --git a/src/datasets/utils/hub.py b/src/datasets/utils/hub.py index 555157afd52..6d784333b23 100644 --- a/src/datasets/utils/hub.py +++ b/src/datasets/utils/hub.py @@ -1,14 +1,6 @@ from functools import partial from huggingface_hub import hf_hub_url -from huggingface_hub.utils import get_session, hf_raise_for_status hf_dataset_url = partial(hf_hub_url, repo_type="dataset") - - -def check_auth(hf_api, repo_id, token=None): - headers = hf_api._build_hf_headers(token=token) - path = f"{hf_api.endpoint}/api/datasets/{repo_id}/auth-check" - r = get_session().get(path, headers=headers) - hf_raise_for_status(r) diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index c4baa1d733c..a6ba8472f21 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -5,9 +5,8 @@ from typing import Optional import pytest -import requests -from huggingface_hub.hf_api import HfApi, RepositoryNotFoundError -from huggingface_hub.utils import hf_raise_for_status +from huggingface_hub.hf_api import HfApi +from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError from huggingface_hub.utils._headers import _http_user_agent @@ -24,9 +23,14 @@ def ci_hub_config(monkeypatch): monkeypatch.setattr("datasets.config.HF_ENDPOINT", CI_HUB_ENDPOINT) monkeypatch.setattr("datasets.config.HUB_DATASETS_URL", CI_HUB_DATASETS_URL) - monkeypatch.setattr( - "huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE - ) + monkeypatch.setattr("huggingface_hub.constants.HUGGINGFACE_CO_URL_TEMPLATE", CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE) + try: + # for backward compatibility with huggingface_hub 0.x + monkeypatch.setattr( + 
"huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", CI_HFH_HUGGINGFACE_CO_URL_TEMPLATE + ) + except AttributeError: + pass old_environ = dict(os.environ) os.environ["HF_ENDPOINT"] = CI_HUB_ENDPOINT yield @@ -107,18 +111,11 @@ def _hf_gated_dataset_repo_txt_data(hf_api: HfApi, hf_token, text_file_content): repo_id=repo_id, repo_type="dataset", ) - path = f"{hf_api.endpoint}/api/datasets/{repo_id}/settings" - repo_settings = {"gated": "auto"} - r = requests.put( - path, - headers={"authorization": f"Bearer {hf_token}"}, - json=repo_settings, - ) - hf_raise_for_status(r) + hf_api.update_repo_settings(repo_id, token=hf_token, repo_type="dataset", gated="auto") yield repo_id try: hf_api.delete_repo(repo_id, token=hf_token, repo_type="dataset") - except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error + except (HfHubHTTPError, ValueError): # catch http error and token invalid error pass @@ -142,7 +139,7 @@ def hf_private_dataset_repo_txt_data_(hf_api: HfApi, hf_token, text_file_content yield repo_id try: hf_api.delete_repo(repo_id, token=hf_token, repo_type="dataset") - except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error + except (HfHubHTTPError, ValueError): # catch http error and token invalid error pass @@ -166,7 +163,7 @@ def hf_private_dataset_repo_zipped_txt_data_(hf_api: HfApi, hf_token, zip_csv_wi yield repo_id try: hf_api.delete_repo(repo_id, token=hf_token, repo_type="dataset") - except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error + except (HfHubHTTPError, ValueError): # catch http error and token invalid error pass @@ -190,7 +187,7 @@ def hf_private_dataset_repo_zipped_img_data_(hf_api: HfApi, hf_token, zip_image_ yield repo_id try: hf_api.delete_repo(repo_id, token=hf_token, repo_type="dataset") - except (requests.exceptions.HTTPError, ValueError): # catch http error and token invalid error + except (HfHubHTTPError, 
ValueError): # catch http error and token invalid error pass diff --git a/tests/test_load.py b/tests/test_load.py index ed28a506f89..422e6cd3180 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -11,7 +11,6 @@ import dill import pyarrow as pa import pytest -import requests import datasets from datasets import config, load_dataset @@ -767,10 +766,7 @@ def test_load_dataset_from_hub(self): def test_load_dataset_namespace(self): with self.assertRaises(DatasetNotFoundError) as context: datasets.load_dataset("hf-internal-testing/_dummy") - self.assertIn( - "hf-internal-testing/_dummy", - str(context.exception), - ) + self.assertIn("hf-internal-testing/_dummy", str(context.exception)) for offline_simulation_mode in list(OfflineSimulationMode): with offline(offline_simulation_mode): with self.assertRaises(ConnectionError) as context: @@ -1050,19 +1046,16 @@ def test_load_dataset_with_unsupported_extensions(text_dir_with_unsupported_exte @pytest.mark.integration def test_loading_from_the_datasets_hub_with_token(): - true_request = requests.Session().request - - def assert_auth(method, url, *args, headers, **kwargs): - assert headers["authorization"] == "Bearer foo" - return true_request(method, url, *args, headers=headers, **kwargs) + class CustomException(Exception): + pass - with patch("requests.Session.request") as mock_request: - mock_request.side_effect = assert_auth + with patch("huggingface_hub.file_download._get_metadata_or_catch_error") as mock_request: + mock_request.side_effect = CustomException() with tempfile.TemporaryDirectory() as tmp_dir: - with offline(): - with pytest.raises((ConnectionError, requests.exceptions.ConnectionError)): - load_dataset(SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER, cache_dir=tmp_dir, token="foo") - mock_request.assert_called() + with pytest.raises(CustomException): + load_dataset(SAMPLE_NOT_EXISTING_DATASET_IDENTIFIER, cache_dir=tmp_dir, token="foo") + mock_request.assert_called_once() + assert 
mock_request.call_args_list[0][1]["headers"]["authorization"] == "Bearer foo" @pytest.mark.integration diff --git a/tests/test_offline_util.py b/tests/test_offline_util.py index ed8ff49b815..c51f3b0659f 100644 --- a/tests/test_offline_util.py +++ b/tests/test_offline_util.py @@ -1,38 +1,52 @@ from tempfile import NamedTemporaryFile +import httpx import pytest import requests +from huggingface_hub import get_session +from huggingface_hub.errors import OfflineModeIsEnabled from datasets.utils.file_utils import fsspec_get, fsspec_head -from .utils import OfflineSimulationMode, RequestWouldHangIndefinitelyError, offline, require_not_windows +from .utils import ( + IS_HF_HUB_1_x, + OfflineSimulationMode, + RequestWouldHangIndefinitelyError, + offline, + require_not_windows, +) @pytest.mark.integration @require_not_windows # fsspec get keeps a file handle on windows that raises PermissionError def test_offline_with_timeout(): + expected_exception = httpx.ReadTimeout if IS_HF_HUB_1_x else requests.ConnectTimeout with offline(OfflineSimulationMode.CONNECTION_TIMES_OUT): with pytest.raises(RequestWouldHangIndefinitelyError): - requests.request("GET", "https://huggingface.co") - with pytest.raises(requests.exceptions.Timeout): - requests.request("GET", "https://huggingface.co", timeout=1.0) - with pytest.raises(requests.exceptions.Timeout), NamedTemporaryFile() as temp_file: + get_session().request("GET", "https://huggingface.co") + + with pytest.raises(expected_exception): + get_session().request("GET", "https://huggingface.co", timeout=1.0) + + with pytest.raises(expected_exception), NamedTemporaryFile() as temp_file: fsspec_get("hf://dummy", temp_file=temp_file) @pytest.mark.integration @require_not_windows # fsspec get keeps a file handle on windows that raises PermissionError def test_offline_with_connection_error(): + expected_exception = httpx.ConnectError if IS_HF_HUB_1_x else requests.ConnectionError with offline(OfflineSimulationMode.CONNECTION_FAILS): - with 
pytest.raises(requests.exceptions.ConnectionError): - requests.request("GET", "https://huggingface.co") - with pytest.raises(requests.exceptions.ConnectionError), NamedTemporaryFile() as temp_file: + with pytest.raises(expected_exception): + get_session().request("GET", "https://huggingface.co") + + with pytest.raises(expected_exception), NamedTemporaryFile() as temp_file: fsspec_get("hf://dummy", temp_file=temp_file) def test_offline_with_datasets_offline_mode_enabled(): with offline(OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1): - with pytest.raises(ConnectionError): + with pytest.raises(OfflineModeIsEnabled): fsspec_head("hf://dummy") - with pytest.raises(ConnectionError), NamedTemporaryFile() as temp_file: + with pytest.raises(OfflineModeIsEnabled), NamedTemporaryFile() as temp_file: fsspec_get("hf://dummy", temp_file=temp_file) diff --git a/tests/utils.py b/tests/utils.py index 0e411e8734b..166bd4789c2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,8 +11,9 @@ from enum import Enum from importlib.util import find_spec from pathlib import Path -from unittest.mock import patch +from unittest.mock import Mock, patch +import httpx import pyarrow as pa import pytest import requests @@ -67,6 +68,8 @@ def parse_flag_from_env(key, default=False): reason="test requires numpy < 2.0 on windows", ) +IS_HF_HUB_1_x = config.HF_HUB_VERSION >= version.parse("0.99") # clunky but works with pre-releases + def require_regex(test_case): """ @@ -368,55 +371,57 @@ class OfflineSimulationMode(Enum): @contextmanager -def offline(mode=OfflineSimulationMode.CONNECTION_FAILS, timeout=1e-16): +def offline(mode: OfflineSimulationMode): """ Simulate offline mode. - There are three offline simulatiom modes: + There are three offline simulation modes: CONNECTION_FAILS (default mode): a ConnectionError is raised for each network call. - Connection errors are created by mocking socket.socket - CONNECTION_TIMES_OUT: the connection hangs until it times out. 
- The default timeout value is low (1e-16) to speed up the tests. - Timeout errors are created by mocking requests.request - HF_HUB_OFFLINE_SET_TO_1: the HF_HUB_OFFLINE environment variable is set to 1. - This makes the http/ftp calls of the library instantly fail and raise an OfflineModeEmabled error. - """ - online_request = requests.Session().request - - def timeout_request(session, method, url, **kwargs): - # Change the url to an invalid url so that the connection hangs - invalid_url = "https://10.255.255.1" - if kwargs.get("timeout") is None: - raise RequestWouldHangIndefinitelyError( - f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout." - ) - kwargs["timeout"] = timeout - try: - return online_request(method, invalid_url, **kwargs) - except Exception as e: - # The following changes in the error are just here to make the offline timeout error prettier - e.request.url = url - max_retry_error = e.args[0] - max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),) - e.args = (max_retry_error,) - raise - - def raise_connection_error(session, prepared_request, **kwargs): - raise requests.ConnectionError("Offline mode is enabled.", request=prepared_request) - - if mode is OfflineSimulationMode.CONNECTION_FAILS: - with patch("requests.Session.send", raise_connection_error): - yield - elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: - # inspired from https://stackoverflow.com/a/904609 - with patch("requests.Session.request", timeout_request): - yield - elif mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1: + CONNECTION_TIMES_OUT: a ReadTimeout or ConnectTimeout is raised for each network call. + HF_HUB_OFFLINE_SET_TO_1: the HF_HUB_OFFLINE environment variable is set to 1. + This makes the http/ftp calls of the library instantly fail and raise an OfflineModeIsEnabled error. 
+ + The raised exceptions are either from the `requests` library (if `huggingface_hub<1.0.0`) + or from the `httpx` library (if `huggingface_hub>=1.0.0`). + """ + # Enable offline mode + if mode is OfflineSimulationMode.HF_HUB_OFFLINE_SET_TO_1: with patch("datasets.config.HF_HUB_OFFLINE", True): yield - else: - raise ValueError("Please use a value from the OfflineSimulationMode enum.") + return + + # Determine which exception to raise based on mode + + def error_response(*args, **kwargs): + if mode is OfflineSimulationMode.CONNECTION_FAILS: + exc = httpx.ConnectError if IS_HF_HUB_1_x else requests.ConnectionError + elif mode is OfflineSimulationMode.CONNECTION_TIMES_OUT: + if kwargs.get("timeout") is None: + raise RequestWouldHangIndefinitelyError( + "Tried an HTTP call in offline mode with no timeout set. Please set a timeout." + ) + exc = httpx.ReadTimeout if IS_HF_HUB_1_x else requests.ConnectTimeout + else: + raise ValueError("Please use a value from the OfflineSimulationMode enum.") + raise exc(f"Offline mode {mode}") + + # Patch all client methods to raise the appropriate error + client_mock = Mock() + for method in ["head", "get", "post", "put", "delete", "request", "stream"]: + setattr(client_mock, method, Mock(side_effect=error_response)) + + # Patching is slightly different depending on hfh internals + patch_target = ( + {"target": "huggingface_hub.utils._http._GLOBAL_CLIENT", "new": client_mock} + if IS_HF_HUB_1_x + else { + "target": "huggingface_hub.utils._http._get_session_from_cache", + "return_value": client_mock, + } + ) + with patch(**patch_target): + yield @contextmanager @@ -456,12 +461,11 @@ def is_rng_equal(rng1, rng2): def xfail_if_500_502_http_error(func): import decorator - from requests.exceptions import HTTPError def _wrapper(func, *args, **kwargs): try: return func(*args, **kwargs) - except HTTPError as err: + except (requests.HTTPError, httpx.HTTPError) as err: if str(err).startswith("500") or str(err).startswith("502"): 
pytest.xfail(str(err)) raise err