Skip to content

Commit afed732

Browse files
authored
Raise Disconnect on send() when client disconnected (#2218)
* Raise `Disconnect` on `send()` when client disconnected * Remove unnucessary variable * Remove unnucessary sleep * Undo transport close changes * Rename Disconnect to ClientDisconnect
1 parent baf4ea4 commit afed732

File tree

5 files changed

+110
-52
lines changed

5 files changed

+110
-52
lines changed

tests/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,12 @@ def unused_tcp_port() -> int:
253253
marks=pytest.mark.skipif(
254254
not importlib.util.find_spec("wsproto"), reason="wsproto not installed."
255255
),
256+
id="wsproto",
257+
),
258+
pytest.param(
259+
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
260+
id="websockets",
256261
),
257-
"uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol",
258262
]
259263
)
260264
def ws_protocol_cls(request: pytest.FixtureRequest):
@@ -269,8 +273,9 @@ def ws_protocol_cls(request: pytest.FixtureRequest):
269273
not importlib.util.find_spec("httptools"),
270274
reason="httptools not installed.",
271275
),
276+
id="httptools",
272277
),
273-
"uvicorn.protocols.http.h11_impl:H11Protocol",
278+
pytest.param("uvicorn.protocols.http.h11_impl:H11Protocol", id="h11"),
274279
]
275280
)
276281
def http_protocol_cls(request: pytest.FixtureRequest):

tests/protocols/test_websocket.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,42 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
762762
assert got_disconnect_event_before_shutdown is True
763763

764764

765+
@pytest.mark.anyio
766+
async def test_client_connection_lost_on_send(
767+
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",
768+
http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]",
769+
unused_tcp_port: int,
770+
):
771+
disconnect = asyncio.Event()
772+
got_disconnect_event = False
773+
774+
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
775+
nonlocal got_disconnect_event
776+
message = await receive()
777+
if message["type"] == "websocket.connect":
778+
await send({"type": "websocket.accept"})
779+
try:
780+
await disconnect.wait()
781+
await send({"type": "websocket.send", "text": "123"})
782+
except IOError:
783+
got_disconnect_event = True
784+
785+
config = Config(
786+
app=app,
787+
ws=ws_protocol_cls,
788+
http=http_protocol_cls,
789+
lifespan="off",
790+
port=unused_tcp_port,
791+
)
792+
async with run_server(config):
793+
url = f"ws://127.0.0.1:{unused_tcp_port}"
794+
async with websockets.client.connect(url):
795+
await asyncio.sleep(0.1)
796+
disconnect.set()
797+
798+
assert got_disconnect_event is True
799+
800+
765801
@pytest.mark.anyio
766802
async def test_connection_lost_before_handshake_complete(
767803
ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]",

uvicorn/protocols/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
from uvicorn._types import WWWScope
77

88

9+
class ClientDisconnected(IOError):
10+
...
11+
12+
913
def get_remote_addr(transport: asyncio.Transport) -> tuple[str, int] | None:
1014
socket_info = transport.get_extra_info("socket")
1115
if socket_info is not None:

uvicorn/protocols/websockets/websockets_impl.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from uvicorn.config import Config
3939
from uvicorn.logging import TRACE_LOG_LEVEL
4040
from uvicorn.protocols.utils import (
41+
ClientDisconnected,
4142
get_local_addr,
4243
get_path_with_query_string,
4344
get_remote_addr,
@@ -252,6 +253,9 @@ async def run_asgi(self) -> None:
252253
"""
253254
try:
254255
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
256+
except ClientDisconnected:
257+
self.closed_event.set()
258+
self.transport.close()
255259
except BaseException as exc:
256260
self.closed_event.set()
257261
msg = "Exception in ASGI application\n"
@@ -336,26 +340,29 @@ async def asgi_send(self, message: "ASGISendEvent") -> None:
336340
elif not self.closed_event.is_set() and self.initial_response is None:
337341
await self.handshake_completed_event.wait()
338342

339-
if message_type == "websocket.send":
340-
message = cast("WebSocketSendEvent", message)
341-
bytes_data = message.get("bytes")
342-
text_data = message.get("text")
343-
data = text_data if bytes_data is None else bytes_data
344-
await self.send(data) # type: ignore[arg-type]
345-
346-
elif message_type == "websocket.close":
347-
message = cast("WebSocketCloseEvent", message)
348-
code = message.get("code", 1000)
349-
reason = message.get("reason", "") or ""
350-
await self.close(code, reason)
351-
self.closed_event.set()
343+
try:
344+
if message_type == "websocket.send":
345+
message = cast("WebSocketSendEvent", message)
346+
bytes_data = message.get("bytes")
347+
text_data = message.get("text")
348+
data = text_data if bytes_data is None else bytes_data
349+
await self.send(data) # type: ignore[arg-type]
350+
351+
elif message_type == "websocket.close":
352+
message = cast("WebSocketCloseEvent", message)
353+
code = message.get("code", 1000)
354+
reason = message.get("reason", "") or ""
355+
await self.close(code, reason)
356+
self.closed_event.set()
352357

353-
else:
354-
msg = (
355-
"Expected ASGI message 'websocket.send' or 'websocket.close',"
356-
" but got '%s'."
357-
)
358-
raise RuntimeError(msg % message_type)
358+
else:
359+
msg = (
360+
"Expected ASGI message 'websocket.send' or 'websocket.close',"
361+
" but got '%s'."
362+
)
363+
raise RuntimeError(msg % message_type)
364+
except ConnectionClosed as exc:
365+
raise ClientDisconnected from exc
359366

360367
elif self.initial_response is not None:
361368
if message_type == "websocket.http.response.body":

uvicorn/protocols/websockets/wsproto_impl.py

Lines changed: 37 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from wsproto import ConnectionType, events
1111
from wsproto.connection import ConnectionState
1212
from wsproto.extensions import Extension, PerMessageDeflate
13-
from wsproto.utilities import RemoteProtocolError
13+
from wsproto.utilities import LocalProtocolError, RemoteProtocolError
1414

1515
from uvicorn._types import (
1616
ASGISendEvent,
@@ -25,6 +25,7 @@
2525
from uvicorn.config import Config
2626
from uvicorn.logging import TRACE_LOG_LEVEL
2727
from uvicorn.protocols.utils import (
28+
ClientDisconnected,
2829
get_local_addr,
2930
get_path_with_query_string,
3031
get_remote_addr,
@@ -236,6 +237,8 @@ def send_500_response(self) -> None:
236237
async def run_asgi(self) -> None:
237238
try:
238239
result = await self.app(self.scope, self.receive, self.send)
240+
except ClientDisconnected:
241+
self.transport.close()
239242
except BaseException:
240243
self.logger.exception("Exception in ASGI application\n")
241244
self.send_500_response()
@@ -325,36 +328,39 @@ async def send(self, message: ASGISendEvent) -> None:
325328
raise RuntimeError(msg % message_type)
326329

327330
elif not self.close_sent and not self.response_started:
328-
if message_type == "websocket.send":
329-
message = typing.cast(WebSocketSendEvent, message)
330-
bytes_data = message.get("bytes")
331-
text_data = message.get("text")
332-
data = text_data if bytes_data is None else bytes_data
333-
output = self.conn.send(
334-
wsproto.events.Message(data=data) # type: ignore[type-var]
335-
)
336-
if not self.transport.is_closing():
337-
self.transport.write(output)
338-
339-
elif message_type == "websocket.close":
340-
message = typing.cast(WebSocketCloseEvent, message)
341-
self.close_sent = True
342-
code = message.get("code", 1000)
343-
reason = message.get("reason", "") or ""
344-
self.queue.put_nowait({"type": "websocket.disconnect", "code": code})
345-
output = self.conn.send(
346-
wsproto.events.CloseConnection(code=code, reason=reason)
347-
)
348-
if not self.transport.is_closing():
349-
self.transport.write(output)
350-
self.transport.close()
351-
352-
else:
353-
msg = (
354-
"Expected ASGI message 'websocket.send' or 'websocket.close',"
355-
" but got '%s'."
356-
)
357-
raise RuntimeError(msg % message_type)
331+
try:
332+
if message_type == "websocket.send":
333+
message = typing.cast(WebSocketSendEvent, message)
334+
bytes_data = message.get("bytes")
335+
text_data = message.get("text")
336+
data = text_data if bytes_data is None else bytes_data
337+
output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore
338+
if not self.transport.is_closing():
339+
self.transport.write(output)
340+
341+
elif message_type == "websocket.close":
342+
message = typing.cast(WebSocketCloseEvent, message)
343+
self.close_sent = True
344+
code = message.get("code", 1000)
345+
reason = message.get("reason", "") or ""
346+
self.queue.put_nowait(
347+
{"type": "websocket.disconnect", "code": code}
348+
)
349+
output = self.conn.send(
350+
wsproto.events.CloseConnection(code=code, reason=reason)
351+
)
352+
if not self.transport.is_closing():
353+
self.transport.write(output)
354+
self.transport.close()
355+
356+
else:
357+
msg = (
358+
"Expected ASGI message 'websocket.send' or 'websocket.close',"
359+
" but got '%s'."
360+
)
361+
raise RuntimeError(msg % message_type)
362+
except LocalProtocolError as exc:
363+
raise ClientDisconnected from exc
358364
elif self.response_started:
359365
if message_type == "websocket.http.response.body":
360366
message = typing.cast("WebSocketResponseBodyEvent", message)

0 commit comments

Comments
 (0)