Skip to content

Commit 219dfe6

Browse files
Merge pull request #53 from mdsol/fix/support-sse-connections
Fix ASGI event handling for long-lived connections
2 parents 21579c9 + 09332f3 commit 219dfe6

File tree

4 files changed

+79
-10
lines changed

4 files changed

+79
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# 1.6.6
2+
- Support long-lived connections in ASGI middleware
3+
14
# 1.6.5
25
- Resolved dependabot identified security issues
36
- Removed build status icon from travis (not used for CI any longer)

mauth_client/middlewares/asgi.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def __call__(
6262
scope_copy[ENV_APP_UUID] = signed.app_uuid
6363
scope_copy[ENV_AUTHENTIC] = True
6464
scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version()
65-
await self.app(scope_copy, self._fake_receive(events), send)
65+
await self.app(scope_copy, self._fake_receive(events, receive), send)
6666
else:
6767
await self._send_response(send, status, message)
6868

@@ -100,18 +100,26 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) ->
100100
"body": json.dumps(body).encode("utf-8"),
101101
})
102102

103-
def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable:
103+
def _fake_receive(self, events: List[ASGIReceiveEvent],
104+
original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable:
104105
"""
105-
Create a fake, async receive function using an iterator of the events
106-
we've already read. This will be passed to downstream middlewares/apps
107-
instead of the usual receive fn, so that they can also "receive" the
108-
body events.
106+
Create a fake receive function that replays cached body events.
107+
108+
After the middleware consumes request body events for authentication,
109+
this allows downstream apps to also "receive" those events. Once all
110+
cached events are exhausted, delegates to the original receive to
111+
properly forward lifecycle events (like http.disconnect).
112+
113+
This is essential for long-lived connections (SSE, streaming responses)
114+
that need to detect client disconnects.
109115
"""
110116
events_iter = iter(events)
111117

112118
async def _receive() -> ASGIReceiveEvent:
113119
try:
114120
return next(events_iter)
115121
except StopIteration:
116-
pass
122+
# After body events are consumed, delegate to original receive
123+
# This allows proper handling of disconnects for SSE connections
124+
return await original_receive()
117125
return _receive

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "mauth-client"
3-
version = "1.6.5"
3+
version = "1.6.6"
44
description = "MAuth Client for Python"
55
repository = "https://github.com/mdsol/mauth-client-python"
66
authors = ["Medidata Solutions <[email protected]>"]

tests/middlewares/asgi_test.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import unittest
2-
from unittest.mock import patch
3-
42
from fastapi import FastAPI, Request
53
from fastapi.testclient import TestClient
64
from fastapi.websockets import WebSocket
5+
from unittest.mock import AsyncMock
6+
from unittest.mock import patch
77
from uuid import uuid4
88

99
from mauth_client.authenticator import LocalAuthenticator
@@ -220,3 +220,61 @@ def is_authentic_effect(self):
220220
self.client.get("/sub_app/path")
221221

222222
self.assertEqual(request_url, "/sub_app/path")
223+
224+
225+
class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase):
226+
def setUp(self):
227+
self.app = FastAPI()
228+
Config.APP_UUID = str(uuid4())
229+
Config.MAUTH_URL = "https://mauth.com"
230+
Config.MAUTH_API_VERSION = "v1"
231+
Config.PRIVATE_KEY = "key"
232+
233+
@patch.object(LocalAuthenticator, "is_authentic")
234+
async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock):
235+
"""Test that after body events are consumed, _fake_receive delegates to original receive"""
236+
is_authentic_mock.return_value = (True, 200, "")
237+
238+
# Track that original receive was called after body events exhausted
239+
call_order = []
240+
241+
async def mock_app(scope, receive, send):
242+
# First receive should get body event
243+
event1 = await receive()
244+
call_order.append(("body", event1["type"]))
245+
246+
# Second receive should delegate to original receive
247+
event2 = await receive()
248+
call_order.append(("disconnect", event2["type"]))
249+
250+
await send({"type": "http.response.start", "status": 200, "headers": []})
251+
await send({"type": "http.response.body", "body": b""})
252+
253+
middleware = MAuthASGIMiddleware(mock_app)
254+
255+
# Mock receive that returns body then disconnect
256+
receive_calls = 0
257+
258+
async def mock_receive():
259+
nonlocal receive_calls
260+
receive_calls += 1
261+
if receive_calls == 1:
262+
return {"type": "http.request", "body": b"test", "more_body": False}
263+
return {"type": "http.disconnect"}
264+
265+
send_mock = AsyncMock()
266+
scope = {
267+
"type": "http",
268+
"method": "POST",
269+
"path": "/test",
270+
"query_string": b"",
271+
"headers": []
272+
}
273+
274+
await middleware(scope, mock_receive, send_mock)
275+
276+
# Verify events were received in correct order
277+
self.assertEqual(len(call_order), 2)
278+
self.assertEqual(call_order[0], ("body", "http.request"))
279+
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
280+
self.assertEqual(receive_calls, 2) # Called once for auth, once from app

0 commit comments

Comments
 (0)