|
10 | 10 | from wsproto import ConnectionType, events |
11 | 11 | from wsproto.connection import ConnectionState |
12 | 12 | from wsproto.extensions import Extension, PerMessageDeflate |
13 | | -from wsproto.utilities import RemoteProtocolError |
| 13 | +from wsproto.utilities import LocalProtocolError, RemoteProtocolError |
14 | 14 |
|
15 | 15 | from uvicorn._types import ( |
16 | 16 | ASGISendEvent, |
|
25 | 25 | from uvicorn.config import Config |
26 | 26 | from uvicorn.logging import TRACE_LOG_LEVEL |
27 | 27 | from uvicorn.protocols.utils import ( |
| 28 | + ClientDisconnected, |
28 | 29 | get_local_addr, |
29 | 30 | get_path_with_query_string, |
30 | 31 | get_remote_addr, |
@@ -236,6 +237,8 @@ def send_500_response(self) -> None: |
236 | 237 | async def run_asgi(self) -> None: |
237 | 238 | try: |
238 | 239 | result = await self.app(self.scope, self.receive, self.send) |
| 240 | + except ClientDisconnected: |
| 241 | + self.transport.close() |
239 | 242 | except BaseException: |
240 | 243 | self.logger.exception("Exception in ASGI application\n") |
241 | 244 | self.send_500_response() |
@@ -325,36 +328,39 @@ async def send(self, message: ASGISendEvent) -> None: |
325 | 328 | raise RuntimeError(msg % message_type) |
326 | 329 |
|
327 | 330 | elif not self.close_sent and not self.response_started: |
328 | | - if message_type == "websocket.send": |
329 | | - message = typing.cast(WebSocketSendEvent, message) |
330 | | - bytes_data = message.get("bytes") |
331 | | - text_data = message.get("text") |
332 | | - data = text_data if bytes_data is None else bytes_data |
333 | | - output = self.conn.send( |
334 | | - wsproto.events.Message(data=data) # type: ignore[type-var] |
335 | | - ) |
336 | | - if not self.transport.is_closing(): |
337 | | - self.transport.write(output) |
338 | | - |
339 | | - elif message_type == "websocket.close": |
340 | | - message = typing.cast(WebSocketCloseEvent, message) |
341 | | - self.close_sent = True |
342 | | - code = message.get("code", 1000) |
343 | | - reason = message.get("reason", "") or "" |
344 | | - self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) |
345 | | - output = self.conn.send( |
346 | | - wsproto.events.CloseConnection(code=code, reason=reason) |
347 | | - ) |
348 | | - if not self.transport.is_closing(): |
349 | | - self.transport.write(output) |
350 | | - self.transport.close() |
351 | | - |
352 | | - else: |
353 | | - msg = ( |
354 | | - "Expected ASGI message 'websocket.send' or 'websocket.close'," |
355 | | - " but got '%s'." |
356 | | - ) |
357 | | - raise RuntimeError(msg % message_type) |
| 331 | + try: |
| 332 | + if message_type == "websocket.send": |
| 333 | + message = typing.cast(WebSocketSendEvent, message) |
| 334 | + bytes_data = message.get("bytes") |
| 335 | + text_data = message.get("text") |
| 336 | + data = text_data if bytes_data is None else bytes_data |
| 337 | + output = self.conn.send(wsproto.events.Message(data=data)) # type: ignore |
| 338 | + if not self.transport.is_closing(): |
| 339 | + self.transport.write(output) |
| 340 | + |
| 341 | + elif message_type == "websocket.close": |
| 342 | + message = typing.cast(WebSocketCloseEvent, message) |
| 343 | + self.close_sent = True |
| 344 | + code = message.get("code", 1000) |
| 345 | + reason = message.get("reason", "") or "" |
| 346 | + self.queue.put_nowait( |
| 347 | + {"type": "websocket.disconnect", "code": code} |
| 348 | + ) |
| 349 | + output = self.conn.send( |
| 350 | + wsproto.events.CloseConnection(code=code, reason=reason) |
| 351 | + ) |
| 352 | + if not self.transport.is_closing(): |
| 353 | + self.transport.write(output) |
| 354 | + self.transport.close() |
| 355 | + |
| 356 | + else: |
| 357 | + msg = ( |
| 358 | + "Expected ASGI message 'websocket.send' or 'websocket.close'," |
| 359 | + " but got '%s'." |
| 360 | + ) |
| 361 | + raise RuntimeError(msg % message_type) |
| 362 | + except LocalProtocolError as exc: |
| 363 | + raise ClientDisconnected from exc |
358 | 364 | elif self.response_started: |
359 | 365 | if message_type == "websocket.http.response.body": |
360 | 366 | message = typing.cast("WebSocketResponseBodyEvent", message) |
|
0 commit comments