diff --git a/CHANGES/11091.feature.rst b/CHANGES/11091.feature.rst new file mode 100644 index 00000000000..a4db2ddced5 --- /dev/null +++ b/CHANGES/11091.feature.rst @@ -0,0 +1 @@ +Added ``ssl_shutdown_timeout`` parameter to :py:class:`~aiohttp.ClientSession` and :py:class:`~aiohttp.TCPConnector` to control the grace period for SSL shutdown handshake on TLS connections. This helps prevent "connection reset" errors on the server side while avoiding excessive delays during connector cleanup. Note: This parameter only takes effect on Python 3.11+ -- by :user:`bdraco`. diff --git a/CHANGES/11094.feature.rst b/CHANGES/11094.feature.rst new file mode 120000 index 00000000000..a21761406a1 --- /dev/null +++ b/CHANGES/11094.feature.rst @@ -0,0 +1 @@ +11091.feature.rst \ No newline at end of file diff --git a/aiohttp/client.py b/aiohttp/client.py index 3b2cd2796cc..6457248d5ea 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -303,6 +303,7 @@ def __init__( max_field_size: int = 8190, fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", middlewares: Sequence[ClientMiddlewareType] = (), + ssl_shutdown_timeout: Optional[float] = 0.1, ) -> None: # We initialise _connector to None immediately, as it's referenced in __del__() # and could cause issues if an exception occurs during initialisation. @@ -361,7 +362,7 @@ def __init__( ) if connector is None: - connector = TCPConnector(loop=loop) + connector = TCPConnector(ssl_shutdown_timeout=ssl_shutdown_timeout) if connector._loop is not loop: raise RuntimeError("Session and connector has to use same event loop") diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 926a62684f6..6fa75d31a98 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -879,6 +879,12 @@ class TCPConnector(BaseConnector): socket_factory - A SocketFactoryType function that, if supplied, will be used to create sockets given an AddrInfoType. + ssl_shutdown_timeout - Grace period for SSL shutdown handshake on TLS + connections. Default is 0.1 seconds. This usually + allows for a clean SSL shutdown by notifying the + remote peer of connection closure, while avoiding + excessive delays during connector cleanup. + Note: Only takes effect on Python 3.11+. """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) @@ -905,6 +911,7 @@ def __init__( happy_eyeballs_delay: Optional[float] = 0.25, interleave: Optional[int] = None, socket_factory: Optional[SocketFactoryType] = None, + ssl_shutdown_timeout: Optional[float] = 0.1, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -932,6 +939,7 @@ def __init__( self._interleave = interleave self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() self._socket_factory = socket_factory + self._ssl_shutdown_timeout = ssl_shutdown_timeout def _close(self) -> List[Awaitable[object]]: """Close all ongoing DNS calls.""" @@ -1176,6 +1184,13 @@ async def _wrap_create_connection( loop=self._loop, socket_factory=self._socket_factory, ) + # Add ssl_shutdown_timeout for Python 3.11+ when SSL is used + if ( + kwargs.get("ssl") + and self._ssl_shutdown_timeout is not None + and sys.version_info >= (3, 11) + ): + kwargs["ssl_shutdown_timeout"] = self._ssl_shutdown_timeout return await self._loop.create_connection(*args, **kwargs, sock=sock) except cert_errors as exc: raise ClientConnectorCertificateError(req.connection_key, exc) from exc @@ -1314,13 +1329,27 @@ async def _start_tls_connection( timeout.sock_connect, ceil_threshold=timeout.ceil_threshold ): try: - tls_transport = await self._loop.start_tls( - underlying_transport, - tls_proto, - sslcontext, - server_hostname=req.server_hostname or req.host, - ssl_handshake_timeout=timeout.total, - ) + # ssl_shutdown_timeout is only available in Python 3.11+ + if ( + sys.version_info >= (3, 11) + and self._ssl_shutdown_timeout is not None + ): + tls_transport = await self._loop.start_tls( + underlying_transport, + tls_proto, + sslcontext, + server_hostname=req.server_hostname or req.host, + ssl_handshake_timeout=timeout.total, + ssl_shutdown_timeout=self._ssl_shutdown_timeout, + ) + else: + tls_transport = await self._loop.start_tls( + underlying_transport, + tls_proto, + sslcontext, + server_hostname=req.server_hostname or req.host, + ssl_handshake_timeout=timeout.total, + ) except BaseException: # We need to close the underlying transport since # `start_tls()` probably failed before it had a diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 40fd7cdb276..07839686039 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -57,7 +57,8 @@ The client session supports the context manager protocol for self closing. read_bufsize=2**16, \ max_line_size=8190, \ max_field_size=8190, \ - fallback_charset_resolver=lambda r, b: "utf-8") + fallback_charset_resolver=lambda r, b: "utf-8", \ + ssl_shutdown_timeout=0.1) The class for creating client sessions and making requests. @@ -256,6 +257,16 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.8.6 + :param float ssl_shutdown_timeout: Grace period for SSL shutdown handshake on TLS + connections (``0.1`` seconds by default). This usually provides sufficient time + to notify the remote peer of connection closure, helping prevent broken + connections on the server side, while minimizing delays during connector + cleanup. This timeout is passed to the underlying :class:`TCPConnector` + when one is created automatically. Note: This parameter only takes effect + on Python 3.11+. + + .. versionadded:: 3.12.5 + .. attribute:: closed ``True`` if the session has been closed, ``False`` otherwise. @@ -1185,7 +1196,7 @@ is controlled by *force_close* constructor's parameter). force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ happy_eyeballs_delay=0.25, interleave=None, loop=None, \ - socket_factory=None) + socket_factory=None, ssl_shutdown_timeout=0.1) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1312,6 +1323,16 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.12 + :param float ssl_shutdown_timeout: Grace period for SSL shutdown on TLS + connections (``0.1`` seconds by default). This parameter balances two + important considerations: usually providing sufficient time to notify + the remote server (which helps prevent "connection reset" errors), + while avoiding unnecessary delays during connector cleanup. + The default value provides a reasonable compromise for most use cases. + Note: This parameter only takes effect on Python 3.11+. + + .. versionadded:: 3.12.5 + .. attribute:: family *TCP* socket family e.g. :data:`socket.AF_INET` or diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index cb4edd3d1e1..1d91956c4a3 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -12,6 +12,7 @@ import tarfile import time import zipfile +from contextlib import suppress from typing import ( Any, AsyncIterator, @@ -685,6 +686,70 @@ async def handler(request): assert txt == "Test message" +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" +) +async def test_ssl_client_shutdown_timeout( + aiohttp_server: AiohttpServer, + ssl_ctx: ssl.SSLContext, + aiohttp_client: AiohttpClient, + client_ssl_ctx: ssl.SSLContext, +) -> None: + # Test that ssl_shutdown_timeout is properly used during connection closure + + connector = aiohttp.TCPConnector(ssl=client_ssl_ctx, ssl_shutdown_timeout=0.1) + + async def streaming_handler(request: web.Request) -> NoReturn: + # Create a streaming response that continuously sends data + response = web.StreamResponse() + await response.prepare(request) + + # Keep sending data until connection is closed + while True: + await response.write(b"data chunk\n") + await asyncio.sleep(0.01) # Small delay between chunks + + assert False, "not reached" + + app = web.Application() + app.router.add_route("GET", "/stream", streaming_handler) + server = await aiohttp_server(app, ssl=ssl_ctx) + client = await aiohttp_client(server, connector=connector) + + # Verify the connector has the correct timeout + assert connector._ssl_shutdown_timeout == 0.1 + + # Start a streaming request to establish SSL connection with active data transfer + resp = await client.get("/stream") + assert resp.status == 200 + + # Create a background task that continuously reads data + async def read_loop() -> None: + while True: + # Read "data chunk\n" + await resp.content.read(11) + + read_task = asyncio.create_task(read_loop()) + await asyncio.sleep(0) # Yield control to ensure read_task starts + + # Record the time before closing + start_time = time.monotonic() + + # Now close the connector while the stream is still active + # This will test the ssl_shutdown_timeout during an active connection + await connector.close() + + # Verify the connection was closed within a reasonable time + # Should be close to ssl_shutdown_timeout (0.1s) but allow some margin + elapsed = time.monotonic() - start_time + assert elapsed < 0.3, f"Connection closure took too long: {elapsed}s" + + read_task.cancel() + with suppress(asyncio.CancelledError): + await read_task + assert read_task.done(), "Read task should be cancelled after connection closure" + + async def test_ssl_client_alpn( aiohttp_server: AiohttpServer, aiohttp_client: AiohttpClient, diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 56c7a5c0c13..0fdfaee6761 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -310,7 +310,35 @@ async def test_create_connector(create_session, loop, mocker) -> None: assert connector.close.called -def test_connector_loop(loop) -> None: +async def test_ssl_shutdown_timeout_passed_to_connector() -> None: + # Test default value + async with ClientSession() as session: + assert isinstance(session.connector, TCPConnector) + assert session.connector._ssl_shutdown_timeout == 0.1 + + # Test custom value + async with ClientSession(ssl_shutdown_timeout=1.0) as session: + assert isinstance(session.connector, TCPConnector) + assert session.connector._ssl_shutdown_timeout == 1.0 + + # Test None value + async with ClientSession(ssl_shutdown_timeout=None) as session: + assert isinstance(session.connector, TCPConnector) + assert session.connector._ssl_shutdown_timeout is None + + # Test that it doesn't affect when custom connector is provided + custom_conn = TCPConnector(ssl_shutdown_timeout=2.0) + async with ClientSession( + connector=custom_conn, ssl_shutdown_timeout=1.0 + ) as session: + assert session.connector is not None + assert isinstance(session.connector, TCPConnector) + assert ( + session.connector._ssl_shutdown_timeout == 2.0 + ) # Should use connector's value + + +def test_connector_loop(loop: asyncio.AbstractEventLoop) -> None: with contextlib.ExitStack() as stack: another_loop = asyncio.new_event_loop() stack.enter_context(contextlib.closing(another_loop)) diff --git a/tests/test_connector.py b/tests/test_connector.py index f17ded6d960..3b2d28ea46c 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -2002,6 +2002,104 @@ async def test_tcp_connector_ctor() -> None: await conn.close() +async def test_tcp_connector_ssl_shutdown_timeout( + loop: asyncio.AbstractEventLoop, +) -> None: + # Test default value + conn = aiohttp.TCPConnector() + assert conn._ssl_shutdown_timeout == 0.1 + await conn.close() + + # Test custom value + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=1.0) + assert conn._ssl_shutdown_timeout == 1.0 + await conn.close() + + # Test None value + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) + assert conn._ssl_shutdown_timeout is None + await conn.close() + + +@pytest.mark.skipif( + sys.version_info < (3, 11), reason="ssl_shutdown_timeout requires Python 3.11+" +) +async def test_tcp_connector_ssl_shutdown_timeout_passed_to_create_connection( + loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +) -> None: + # Test that ssl_shutdown_timeout is passed to create_connection for SSL connections + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://example.com"), loop=loop) + + with closing(await conn.connect(req, [], ClientTimeout())): + assert create_connection.call_args.kwargs["ssl_shutdown_timeout"] == 2.5 + + await conn.close() + + # Test with None value + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=None) + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://example.com"), loop=loop) + + with closing(await conn.connect(req, [], ClientTimeout())): + # When ssl_shutdown_timeout is None, it should not be in kwargs + assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs + + await conn.close() + + # Test that ssl_shutdown_timeout is NOT passed for non-SSL connections + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("http://example.com"), loop=loop) + + with closing(await conn.connect(req, [], ClientTimeout())): + # For non-SSL connections, ssl_shutdown_timeout should not be passed + assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs + + await conn.close() + + +@pytest.mark.skipif(sys.version_info >= (3, 11), reason="Test for Python < 3.11") +async def test_tcp_connector_ssl_shutdown_timeout_not_passed_pre_311( + loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +) -> None: + # Test that ssl_shutdown_timeout is NOT passed to create_connection on Python < 3.11 + conn = aiohttp.TCPConnector(ssl_shutdown_timeout=2.5) + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + create_connection.return_value = mock.Mock(), mock.Mock() + + # Test with HTTPS + req = ClientRequest("GET", URL("https://example.com"), loop=loop) + with closing(await conn.connect(req, [], ClientTimeout())): + assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs + + # Test with HTTP + req = ClientRequest("GET", URL("http://example.com"), loop=loop) + with closing(await conn.connect(req, [], ClientTimeout())): + assert "ssl_shutdown_timeout" not in create_connection.call_args.kwargs + + await conn.close() + + async def test_tcp_connector_allowed_protocols(loop: asyncio.AbstractEventLoop) -> None: conn = aiohttp.TCPConnector() assert conn.allowed_protocol_schema_set == {"", "tcp", "http", "https", "ws", "wss"} diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 0e73210f58b..f5ebf6adc4f 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -936,13 +936,23 @@ async def make_conn(): connector._create_connection(req, None, aiohttp.ClientTimeout()) ) - self.loop.start_tls.assert_called_with( - mock.ANY, - mock.ANY, - _SSL_CONTEXT_VERIFIED, - server_hostname="www.python.org", - ssl_handshake_timeout=mock.ANY, - ) + if sys.version_info >= (3, 11): + self.loop.start_tls.assert_called_with( + mock.ANY, + mock.ANY, + _SSL_CONTEXT_VERIFIED, + server_hostname="www.python.org", + ssl_handshake_timeout=mock.ANY, + ssl_shutdown_timeout=0.1, + ) + else: + self.loop.start_tls.assert_called_with( + mock.ANY, + mock.ANY, + _SSL_CONTEXT_VERIFIED, + server_hostname="www.python.org", + ssl_handshake_timeout=mock.ANY, + ) self.assertEqual(req.url.path, "/") self.assertEqual(proxy_req.method, "CONNECT")