Skip to content

Commit 3443957

Browse files
committed
fixes
1 parent 6efd70a commit 3443957

File tree

4 files changed

+184
-12
lines changed

4 files changed

+184
-12
lines changed

python-client/river/session.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,8 @@ def send(self, partial: PartialTransportMessage) -> tuple[bool, str]:
141141
if self.state == SessionState.CONNECTED and self._ws is not None:
142142
ok, result = self._send_over_wire(msg)
143143
if not ok:
144-
# Roll back: remove the unsendable message from the buffer
145-
# and restore seq so subsequent messages don't have a gap.
146-
self.send_buffer = [m for m in self.send_buffer if m.id != msg.id]
147-
self.seq = msg.seq # restore to the seq we consumed
144+
# Send failure is fatal — the caller (transport)
145+
# is expected to destroy the session.
148146
return False, result
149147
return True, msg.id
150148

python-client/river/streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def close(self, value: T | None = None) -> None:
214214
if value is not None:
215215
self._write_cb(value)
216216
self._closed = True
217-
# Nullify callbacks after invocation to prevent reuse (matches TS)
217+
# Nullify callbacks after invocation to prevent reuse
218218
self._write_cb = lambda _: None # type: ignore[assignment]
219219
if self._close_cb:
220220
self._close_cb()

python-client/river/transport.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,9 +280,14 @@ async def _do_handshake(self, session: Session, ws: Any, to: str) -> None:
280280
hs_msg = session.create_handshake_request(metadata=self._handshake_metadata)
281281
ok, buf = self._codec_adapter.to_buffer(hs_msg)
282282
if not ok:
283+
# Handshake send failure is fatal — destroy session
283284
logger.error("Failed to encode handshake: %s", buf)
284285
await ws.close()
285-
self._on_connection_failed(to)
286+
self._events.dispatch(
287+
"protocolError",
288+
{"type": "message_send_failure", "message": buf},
289+
)
290+
self._delete_session(to)
286291
return
287292

288293
await ws.send(buf)
@@ -303,19 +308,21 @@ async def _do_handshake(self, session: Session, ws: Any, to: str) -> None:
303308

304309
ok, result = self._codec_adapter.from_buffer(response_bytes)
305310
if not ok:
311+
# Invalid handshake response is fatal
306312
logger.error("Failed to decode handshake response: %s", result)
307313
await ws.close()
308-
self._on_connection_failed(to)
314+
self._delete_session(to)
309315
return
310316

311317
response_msg: TransportMessage = result # type: ignore[assignment]
312318
payload = response_msg.payload
313319

314320
# Validate handshake response
315321
if not isinstance(payload, dict) or payload.get("type") != "HANDSHAKE_RESP":
322+
# Invalid handshake schema is fatal
316323
logger.error("Invalid handshake response payload")
317324
await ws.close()
318-
self._on_connection_failed(to)
325+
self._delete_session(to)
319326
return
320327

321328
status = payload.get("status", {})
@@ -411,10 +418,12 @@ def _on_message_data(self, session: Session, raw: bytes, to: str) -> None:
411418
"""Handle raw bytes received from the WebSocket."""
412419
ok, result = self._codec_adapter.from_buffer(raw)
413420
if not ok:
421+
# Invalid message is fatal — destroy the session
414422
self._events.dispatch(
415423
"protocolError",
416424
{"type": "invalid_message", "message": result},
417425
)
426+
self._delete_session(to)
418427
return
419428

420429
msg: TransportMessage = result # type: ignore[assignment]
@@ -506,6 +515,12 @@ def _send(msg: PartialTransportMessage) -> str:
506515

507516
ok, result = session.send(msg)
508517
if not ok:
518+
# Send failure is fatal — destroy session
519+
self._events.dispatch(
520+
"protocolError",
521+
{"type": "message_send_failure", "message": result},
522+
)
523+
self._delete_session(to)
509524
raise RuntimeError(f"Send failed: {result}")
510525
return result
511526

python-client/tests/test_e2e.py

Lines changed: 163 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,12 +1405,12 @@ async def test_finalize_after_explicit_close(self, server_url: str):
14051405

14061406

14071407
class TestProtocolConformance:
1408-
"""Tests verifying Python client matches TS protocol behavior."""
1408+
"""Tests verifying protocol-level conformance."""
14091409

14101410
def test_handshake_stream_id_is_random(self):
14111411
"""Handshake streamId should be a random ID, not a fixed string.
14121412
1413-
TS uses generateId() for handshake streamId; Python must match.
1413+
The protocol requires a random streamId for handshakes.
14141414
"""
14151415
from river.codec import CodecMessageAdapter, NaiveJsonCodec
14161416
from river.session import Session
@@ -1446,7 +1446,7 @@ def test_readable_push_after_break_is_noop(self):
14461446
def test_writable_close_nullifies_callbacks(self):
14471447
"""After close(), write/close callbacks should not be invocable.
14481448
1449-
TS nullifies callbacks after close to prevent reuse.
1449+
Callbacks should be nullified after close to prevent reuse.
14501450
"""
14511451
from river.streams import Writable
14521452

@@ -1467,7 +1467,7 @@ def test_writable_close_nullifies_callbacks(self):
14671467
assert close_count[0] == 1
14681468

14691469
def test_heartbeat_stream_id_is_fixed(self):
1470-
"""Heartbeat streamId should be 'heartbeat' (matching TS)."""
1470+
"""Heartbeat streamId should be the fixed string 'heartbeat'."""
14711471
from river.types import heartbeat_message
14721472

14731473
hb = heartbeat_message()
@@ -1501,3 +1501,162 @@ def test_handshake_payload_omits_metadata_when_none(self):
15011501
metadata=None,
15021502
)
15031503
assert "metadata" not in payload
1504+
1505+
1506+
class TestFatalErrorPaths:
1507+
"""Regression tests for fatal error paths that must destroy the session.
1508+
1509+
Certain errors are not retryable and must immediately destroy
1510+
the session.
1511+
"""
1512+
1513+
def test_failed_send_destroys_session(self):
1514+
"""Send failure on a connected session destroys it."""
1515+
from unittest.mock import AsyncMock
1516+
1517+
from river.codec import CodecMessageAdapter, NaiveJsonCodec
1518+
from river.session import Session, SessionState
1519+
from river.transport import WebSocketClientTransport
1520+
1521+
transport = WebSocketClientTransport(
1522+
ws_url="ws://127.0.0.1:1",
1523+
client_id="client",
1524+
server_id="server",
1525+
codec=NaiveJsonCodec(),
1526+
)
1527+
codec = CodecMessageAdapter(NaiveJsonCodec())
1528+
session = Session("s1", "client", "server", codec)
1529+
session.state = SessionState.CONNECTED
1530+
session._ws = AsyncMock()
1531+
transport.sessions["server"] = session
1532+
1533+
send_fn = transport.get_session_bound_send_fn("server", "s1")
1534+
1535+
# A payload that can't be serialized (set is not JSON-serializable)
1536+
from river.types import PartialTransportMessage
1537+
1538+
try:
1539+
send_fn(
1540+
PartialTransportMessage(
1541+
payload={"bad": {1, 2}},
1542+
stream_id="x",
1543+
control_flags=0,
1544+
)
1545+
)
1546+
except RuntimeError:
1547+
pass
1548+
1549+
# Session must be destroyed
1550+
assert transport.sessions.get("server") is None
1551+
1552+
def test_failed_send_seq_consumed(self):
1553+
"""Send failure does not roll back seq.
1554+
1555+
The seq is consumed and the session is destroyed instead.
1556+
"""
1557+
from unittest.mock import AsyncMock
1558+
1559+
from river.codec import CodecMessageAdapter, NaiveJsonCodec
1560+
from river.session import Session, SessionState
1561+
from river.types import PartialTransportMessage
1562+
1563+
codec = CodecMessageAdapter(NaiveJsonCodec())
1564+
session = Session("s1", "client", "server", codec)
1565+
session.state = SessionState.CONNECTED
1566+
session._ws = AsyncMock()
1567+
1568+
initial_seq = session.seq
1569+
1570+
ok, _ = session.send(
1571+
PartialTransportMessage(
1572+
payload={"bad": {1, 2}},
1573+
stream_id="x",
1574+
control_flags=0,
1575+
)
1576+
)
1577+
1578+
assert not ok
1579+
# seq was consumed (not rolled back)
1580+
assert session.seq == initial_seq + 1
1581+
1582+
def test_invalid_message_destroys_session(self):
1583+
"""Receiving a corrupt message destroys the session."""
1584+
from river.codec import CodecMessageAdapter, NaiveJsonCodec
1585+
from river.session import Session, SessionState
1586+
from river.transport import WebSocketClientTransport
1587+
1588+
transport = WebSocketClientTransport(
1589+
ws_url="ws://127.0.0.1:1",
1590+
client_id="client",
1591+
server_id="server",
1592+
codec=NaiveJsonCodec(),
1593+
)
1594+
codec = CodecMessageAdapter(NaiveJsonCodec())
1595+
session = Session("s1", "client", "server", codec)
1596+
session.state = SessionState.CONNECTED
1597+
transport.sessions["server"] = session
1598+
1599+
errors: list[dict] = []
1600+
transport.add_event_listener("protocolError", lambda e: errors.append(e))
1601+
1602+
# Feed garbage bytes
1603+
transport._on_message_data(session, b"not valid json", "server")
1604+
1605+
# Session must be destroyed
1606+
assert transport.sessions.get("server") is None
1607+
assert len(errors) == 1
1608+
assert errors[0]["type"] == "invalid_message"
1609+
1610+
def test_readable_broken_after_async_for_break(self):
1611+
"""Breaking out of async for marks readable as broken."""
1612+
from river.streams import Readable
1613+
1614+
r: Readable = Readable()
1615+
r._push_value({"ok": True, "payload": 1})
1616+
1617+
# Simulate what async for + break does: create iterator, get
1618+
# one value, then let the iterator be GC'd
1619+
it = r.__aiter__()
1620+
# The __del__ should mark broken
1621+
del it
1622+
1623+
assert r._broken
1624+
# Subsequent pushes should be no-ops
1625+
r._push_value({"ok": True, "payload": 2})
1626+
assert not r._has_values_in_queue()
1627+
1628+
def test_frozen_session_options(self):
1629+
"""SessionOptions is frozen — mutation raises."""
1630+
from river.session import SessionOptions
1631+
1632+
opts = SessionOptions()
1633+
try:
1634+
opts.heartbeat_interval_ms = 999 # type: ignore[misc]
1635+
raise AssertionError("should have raised FrozenInstanceError")
1636+
except AttributeError:
1637+
pass # frozen dataclass raises AttributeError on mutation
1638+
1639+
def test_json_codec_large_int_encoding(self):
1640+
"""Large ints beyond JS safe integer range are encoded as $b."""
1641+
from river.codec import NaiveJsonCodec
1642+
1643+
codec = NaiveJsonCodec()
1644+
large = 2**53 + 1
1645+
buf = codec.to_buffer({"n": large})
1646+
decoded = codec.from_buffer(buf)
1647+
assert decoded["n"] == large
1648+
1649+
# Normal ints should NOT be encoded as $b
1650+
buf2 = codec.to_buffer({"n": 42})
1651+
raw = buf2.decode("utf-8")
1652+
assert "$b" not in raw
1653+
1654+
def test_json_codec_negative_large_int(self):
1655+
"""Negative large ints are also encoded as $b."""
1656+
from river.codec import NaiveJsonCodec
1657+
1658+
codec = NaiveJsonCodec()
1659+
large_neg = -(2**53 + 1)
1660+
buf = codec.to_buffer({"n": large_neg})
1661+
decoded = codec.from_buffer(buf)
1662+
assert decoded["n"] == large_neg

0 commit comments

Comments
 (0)