From 0d95a7615c88a76428b12c5463a098a9ef3b938e Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Wed, 15 Oct 2025 12:30:48 +0200 Subject: [PATCH 1/2] Fix hf_raise_for_status on async stream + tests --- src/huggingface_hub/utils/_http.py | 9 ++- tests/test_utils_http.py | 102 ++++++++++++++++++++++++++++- 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index f5533316f8..4f943652a7 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -626,8 +626,13 @@ def _format(error_type: type[HfHubHTTPError], custom_message: str, response: htt try: data = response.json() except httpx.ResponseNotRead: - response.read() # In case of streaming response, we need to read the response first - data = response.json() + try: + response.read() # In case of streaming response, we need to read the response first + data = response.json() + except RuntimeError: + data = {} # In case of async streaming response, we can't read the response here => skip + # ^ TODO: find a better way to handle async streaming response. In practice, async stream must be read in an async content + # so either before hf_raise_for_status or by making an async version of hf_raise_for_status. 
error = data.get("error") if error is not None: diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index 0b20f78c75..4383bbd14b 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -2,9 +2,11 @@ import threading import time import unittest +from http.server import BaseHTTPRequestHandler, HTTPServer from multiprocessing import Process, Queue from typing import Generator, Optional from unittest.mock import Mock, call, patch +from urllib.parse import urlparse from uuid import UUID import httpx @@ -12,13 +14,14 @@ from httpx import ConnectTimeout, HTTPError from huggingface_hub.constants import ENDPOINT -from huggingface_hub.errors import OfflineModeIsEnabled +from huggingface_hub.errors import HfHubHTTPError, OfflineModeIsEnabled from huggingface_hub.utils._http import ( _adjust_range_header, default_client_factory, fix_hf_endpoint_in_url, get_async_session, get_session, + hf_raise_for_status, http_backoff, set_client_factory, ) @@ -378,3 +381,100 @@ async def test_async_client_get_request(): client = get_async_session() response = await client.get("https://huggingface.co") assert response.status_code == 200 + + +class FakeServerHandler(BaseHTTPRequestHandler): + """Fake server handler to test client behavior.""" + + def do_GET(self): + parsed = urlparse(self.path) + + # Health check endpoint (always succeeds) + if parsed.path == "/health": + self._send_response(200, b"OK") + return + + # Main endpoint (always fails with 500) + self._send_response(500, b"This is a 500 error") + + def _send_response(self, status_code, body): + self.send_response(status_code) + self.send_header("Content-Type", "text/plain") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + +@pytest.fixture(scope="module", autouse=True) +def fake_server(): + # Find a free port + host, port = "127.0.0.1", 8000 + for port in range(port, 8100): + try: + server = HTTPServer((host, port), FakeServerHandler) + break + except 
OSError: + continue + else: + raise RuntimeError("Could not find a free port") + + url = f"http://{host}:{port}" + + # Start server in a separate thread and wait until it's ready + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + for _ in range(1000): # up to 10 seconds + try: + if httpx.get(f"{url}/health", timeout=0.01).status_code == 200: + break + except httpx.HTTPError: + pass + time.sleep(0.01) + else: + server.shutdown() + raise RuntimeError("Fake server failed to start") + + yield url + server.shutdown() + + +def _assert_500_error(exc_info: pytest.ExceptionInfo, expect_message: bool = True): + """Common assertions for 500 error tests.""" + assert isinstance(exc_info.value, HfHubHTTPError) + assert exc_info.value.response.status_code == 500 + if expect_message: + assert "This is a 500 error" in str(exc_info.value) + else: + assert "This is a 500 error" not in str(exc_info.value) + + +def test_raise_on_status_sync_non_stream(fake_server: str): + response = get_session().get(fake_server) + with pytest.raises(HTTPError) as exc_info: + hf_raise_for_status(response) + _assert_500_error(exc_info) + + +def test_raise_on_status_sync_stream(fake_server: str): + with get_session().stream("GET", fake_server) as response: + with pytest.raises(HTTPError) as exc_info: + hf_raise_for_status(response) + _assert_500_error(exc_info) + + +@pytest.mark.asyncio +async def test_raise_on_status_async_non_stream(fake_server: str): + response = await get_async_session().get(fake_server) + with pytest.raises(HTTPError) as exc_info: + hf_raise_for_status(response) + _assert_500_error(exc_info) + + +@pytest.mark.asyncio +async def test_raise_on_status_async_stream(fake_server: str): + async with get_async_session().stream("GET", fake_server) as response: + with pytest.raises(HTTPError) as exc_info: + hf_raise_for_status(response) + # Async streaming response does not support reading the content + _assert_500_error(exc_info, expect_message=False) 
From ea58d141acd9b67be2d175a09955e3c906bd2f0c Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Wed, 15 Oct 2025 13:38:37 +0200 Subject: [PATCH 2/2] use response hook --- src/huggingface_hub/utils/_http.py | 25 +++++++++++++++++++++---- tests/test_utils_http.py | 27 ++++++++------------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 4f943652a7..7cc3951f27 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -109,6 +109,20 @@ async def async_hf_request_event_hook(request: httpx.Request) -> None: return hf_request_event_hook(request) +async def async_hf_response_event_hook(response: httpx.Response) -> None: + if response.status_code >= 400: + # If response will raise, read content from stream to have it available when raising the exception + # If content-length is not set or is too large, skip reading the content to avoid OOM + if "Content-length" in response.headers: + try: + length = int(response.headers["Content-length"]) + except ValueError: + return + + if length < 1_000_000: + await response.aread() + + def default_client_factory() -> httpx.Client: """ Factory function to create a `httpx.Client` with the default transport. @@ -125,7 +139,7 @@ def default_async_client_factory() -> httpx.AsyncClient: Factory function to create a `httpx.AsyncClient` with the default transport. 
""" return httpx.AsyncClient( - event_hooks={"request": [async_hf_request_event_hook]}, + event_hooks={"request": [async_hf_request_event_hook], "response": [async_hf_response_event_hook]}, follow_redirects=True, timeout=httpx.Timeout(constants.DEFAULT_REQUEST_TIMEOUT, write=60.0), ) @@ -630,9 +644,12 @@ def _format(error_type: type[HfHubHTTPError], custom_message: str, response: htt response.read() # In case of streaming response, we need to read the response first data = response.json() except RuntimeError: - data = {} # In case of async streaming response, we can't read the response here => skip - # ^ TODO: find a better way to handle async streaming response. In practice, async stream must be read in an async content - # so either before hf_raise_for_status or by making an async version of hf_raise_for_status. + # In case of async streaming response, we can't read the stream here. + # In practice if user is using the default async client from `get_async_session`, the stream will have + # already been read in the async event hook `async_hf_response_event_hook`. + # + # Here, we skip reading the response to avoid a RuntimeError. This only happens when an async streaming response comes from a directly-instantiated httpx.AsyncClient. 
+ data = {} error = data.get("error") if error is not None: diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index 4383bbd14b..1fc04be802 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -439,42 +439,31 @@ def fake_server(): server.shutdown() -def _assert_500_error(exc_info: pytest.ExceptionInfo, expect_message: bool = True): +def _check_raise_status(response: httpx.Response): """Common assertions for 500 error tests.""" - assert isinstance(exc_info.value, HfHubHTTPError) + with pytest.raises(HfHubHTTPError) as exc_info: + hf_raise_for_status(response) assert exc_info.value.response.status_code == 500 - if expect_message: - assert "This is a 500 error" in str(exc_info.value) - else: - assert "This is a 500 error" not in str(exc_info.value) + assert "This is a 500 error" in str(exc_info.value) def test_raise_on_status_sync_non_stream(fake_server: str): response = get_session().get(fake_server) - with pytest.raises(HTTPError) as exc_info: - hf_raise_for_status(response) - _assert_500_error(exc_info) + _check_raise_status(response) def test_raise_on_status_sync_stream(fake_server: str): with get_session().stream("GET", fake_server) as response: - with pytest.raises(HTTPError) as exc_info: - hf_raise_for_status(response) - _assert_500_error(exc_info) + _check_raise_status(response) @pytest.mark.asyncio async def test_raise_on_status_async_non_stream(fake_server: str): response = await get_async_session().get(fake_server) - with pytest.raises(HTTPError) as exc_info: - hf_raise_for_status(response) - _assert_500_error(exc_info) + _check_raise_status(response) @pytest.mark.asyncio async def test_raise_on_status_async_stream(fake_server: str): async with get_async_session().stream("GET", fake_server) as response: - with pytest.raises(HTTPError) as exc_info: - hf_raise_for_status(response) - # Async streaming response does not support reading the content - _assert_500_error(exc_info, expect_message=False) + _check_raise_status(response)