Skip to content

Commit 27c04c7

Browse files
committed
fixes and regression tests
1 parent 630f0e8 commit 27c04c7

File tree

9 files changed

+252
-19
lines changed

9 files changed

+252
-19
lines changed

python-client/river/client.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from typing_extensions import TypedDict
1515

16+
from river.session import SessionState
1617
from river.streams import Readable, Writable
1718
from river.transport import WebSocketClientTransport
1819
from river.types import (
@@ -259,8 +260,30 @@ def _handle_proc(
259260
if self._connect_on_invoke:
260261
transport.connect(to)
261262

262-
# Get the session and a send function
263+
# Get the session and a send function.
264+
# If connect() couldn't start (retry budget exhausted, transport
265+
# closing, etc.) the session will be in NO_CONNECTION with no
266+
# connect task in flight — fail immediately instead of hanging.
263267
session = transport._get_or_create_session(to)
268+
connect_task = transport._connect_tasks.get(to)
269+
has_active_connect = connect_task is not None and not connect_task.done()
270+
if session.state == SessionState.NO_CONNECTION and not has_active_connect:
271+
transport._delete_session(to, emit_closing=False)
272+
res_readable = Readable()
273+
res_readable._push_value(
274+
err_result(
275+
UNEXPECTED_DISCONNECT_CODE,
276+
f"{to} connection failed",
277+
)
278+
)
279+
res_readable._trigger_close()
280+
req_writable = Writable(write_cb=lambda _: None, close_cb=None)
281+
req_writable._closed = True
282+
return {
283+
"res_readable": res_readable,
284+
"req_writable": req_writable,
285+
}
286+
264287
session_id = session.id
265288
try:
266289
send_fn = transport.get_session_bound_send_fn(to, session_id)

python-client/river/codegen/emitter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,17 @@ def render_service_client(svc: ServiceDef, ir: SchemaIR, import_prefix: str) ->
132132
has_upload = "upload" in proc_types
133133
has_subscription = "subscription" in proc_types
134134

135+
# Check if any annotation references Literal (e.g. const schemas)
136+
all_annotations = []
137+
for p in svc.procedures:
138+
all_annotations.append(p.init_type.annotation)
139+
all_annotations.append(p.output_type.annotation)
140+
if p.input_type:
141+
all_annotations.append(p.input_type.annotation)
142+
if p.error_type:
143+
all_annotations.append(p.error_type.annotation)
144+
needs_literal = any("Literal[" in a for a in all_annotations)
145+
135146
return _env.get_template("service_client.py.j2").render(
136147
service=svc,
137148
type_names=type_names,
@@ -140,6 +151,7 @@ def render_service_client(svc: ServiceDef, ir: SchemaIR, import_prefix: str) ->
140151
has_stream=has_stream,
141152
has_upload=has_upload,
142153
has_subscription=has_subscription,
154+
needs_literal=needs_literal,
143155
)
144156

145157

python-client/river/codegen/schema.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def __init__(self) -> None:
151151

152152
def convert(self, raw: dict) -> SchemaIR:
153153
"""Convert the top-level serialized schema dict to IR."""
154+
self._typedicts = []
154155
services: list[ServiceDef] = []
155156
for svc_name, svc_data in raw.get("services", {}).items():
156157
svc_def = self._convert_service(svc_name, svc_data)
@@ -247,7 +248,11 @@ def _schema_to_typeref(self, schema: dict, name_hint: str) -> TypeRef:
247248
if "const" in schema:
248249
val = schema["const"]
249250
if isinstance(val, str):
250-
return TypeRef(annotation=f'Literal["{val}"]')
251+
# Use repr to handle all escaping (quotes, backslashes,
252+
# control chars) then unwrap the outer quotes and re-wrap
253+
# with double quotes for Literal["..."] syntax.
254+
escaped = repr(val)[1:-1].replace('"', '\\"')
255+
return TypeRef(annotation=f'Literal["{escaped}"]')
251256
return TypeRef(annotation=f"Literal[{val!r}]")
252257

253258
# anyOf (union)

python-client/river/codegen/templates/service_client.py.j2

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from __future__ import annotations
44

55
import asyncio
6+
{% if needs_literal %}
7+
from typing import Any, Literal
8+
{% else %}
69
from typing import Any
10+
{% endif %}
711

812
from river.client import (
913
ErrResult,

python-client/river/streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def is_readable(self) -> bool:
5555
return not self._locked and not self._broken
5656

5757
def is_closed(self) -> bool:
58-
"""Whether the stream has been closed."""
58+
"""Whether the stream is fully consumed (closed and queue drained)."""
5959
return self._closed and len(self._queue) == 0
6060

6161
def _has_values_in_queue(self) -> bool:

python-client/river/transport.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def connect(self, to: str | None = None) -> None:
235235
session.state = SessionState.BACKING_OFF
236236

237237
async def _do_connect():
238+
ws = None
238239
try:
239240
if backoff_ms > 0:
240241
await asyncio.sleep(backoff_ms / 1000.0)
@@ -252,7 +253,9 @@ async def _do_connect():
252253
session.state = SessionState.HANDSHAKING
253254
await self._do_handshake(session, ws, to)
254255
except asyncio.CancelledError:
255-
pass
256+
# Clean up socket if we got cancelled mid-handshake
257+
if ws is not None and session._ws is not ws:
258+
await ws.close()
256259
except Exception as e:
257260
logger.debug("Connection attempt failed for %s: %s", to, e)
258261
if not session._destroyed:
@@ -484,9 +487,12 @@ def _on_connection_failed(self, to: str) -> None:
484487

485488
# Transition to NoConnection with grace period so the session
486489
# is eventually destroyed if reconnect doesn't succeed.
490+
# Only start the grace period if one isn't already running,
491+
# so repeated failures don't keep extending the deadline.
487492
loop = self._get_loop()
488493
session.state = SessionState.NO_CONNECTION
489-
session.start_grace_period(loop)
494+
if session._grace_period_task is None or session._grace_period_task.done():
495+
session.start_grace_period(loop)
490496

491497
if self._reconnect_on_connection_drop:
492498
self._try_reconnecting(to)

python-client/tests/conftest.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import os
1010
import re
11+
import selectors
1112
import signal
1213
import subprocess
1314
import sys
@@ -119,20 +120,44 @@ def _start_server(
119120
port = None
120121
deadline = time.monotonic() + 30
121122
assert proc.stdout is not None
122-
while time.monotonic() < deadline:
123-
line = proc.stdout.readline().decode("utf-8").strip()
124-
if not line:
125-
if proc.poll() is not None:
126-
stderr = proc.stderr.read().decode("utf-8") if proc.stderr else ""
127-
raise RuntimeError(
128-
f"{label} exited with code {proc.returncode}.\nstderr: {stderr}"
129-
)
130-
time.sleep(0.1)
131-
continue
132-
m = re.match(r"RIVER_PORT=(\d+)", line)
133-
if m:
134-
port = int(m.group(1))
135-
break
123+
sel = selectors.DefaultSelector()
124+
sel.register(proc.stdout, selectors.EVENT_READ)
125+
buf = b""
126+
try:
127+
while time.monotonic() < deadline:
128+
remaining = deadline - time.monotonic()
129+
if remaining <= 0:
130+
break
131+
ready = sel.select(timeout=min(remaining, 1.0))
132+
if not ready:
133+
if proc.poll() is not None:
134+
stderr = proc.stderr.read().decode("utf-8") if proc.stderr else ""
135+
raise RuntimeError(
136+
f"{label} exited with code {proc.returncode}.\nstderr: {stderr}"
137+
)
138+
continue
139+
chunk = proc.stdout.read1(4096) # type: ignore[union-attr]
140+
if not chunk:
141+
# EOF — child closed stdout (likely exited)
142+
if proc.poll() is not None:
143+
stderr = proc.stderr.read().decode("utf-8") if proc.stderr else ""
144+
raise RuntimeError(
145+
f"{label} exited with code {proc.returncode}.\nstderr: {stderr}"
146+
)
147+
continue
148+
buf += chunk
149+
while b"\n" in buf:
150+
line_bytes, buf = buf.split(b"\n", 1)
151+
line = line_bytes.decode("utf-8").strip()
152+
m = re.match(r"RIVER_PORT=(\d+)", line)
153+
if m:
154+
port = int(m.group(1))
155+
break
156+
if port is not None:
157+
break
158+
finally:
159+
sel.unregister(proc.stdout)
160+
sel.close()
136161

137162
if port is None:
138163
proc.kill()

python-client/tests/test_e2e.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,3 +1751,52 @@ def test_enable_transparent_reconnects_option(self):
17511751
options=opts,
17521752
)
17531753
assert transport.reconnect_on_connection_drop is False
1754+
1755+
def test_literal_const_escaping(self):
1756+
"""String consts with quotes/backslashes/control chars are escaped."""
1757+
from river.codegen.schema import SchemaConverter
1758+
1759+
converter = SchemaConverter()
1760+
schema = {"const": 'a"b'}
1761+
ref = converter._schema_to_typeref(schema, "Test")
1762+
assert ref.annotation == 'Literal["a\\"b"]'
1763+
1764+
schema2 = {"const": "a\\b"}
1765+
ref2 = converter._schema_to_typeref(schema2, "Test")
1766+
assert ref2.annotation == 'Literal["a\\\\b"]'
1767+
1768+
# Control characters must be escaped
1769+
schema3 = {"const": "line1\nline2"}
1770+
ref3 = converter._schema_to_typeref(schema3, "Test")
1771+
assert ref3.annotation == 'Literal["line1\\nline2"]'
1772+
1773+
schema4 = {"const": "a\tb"}
1774+
ref4 = converter._schema_to_typeref(schema4, "Test")
1775+
assert ref4.annotation == 'Literal["a\\tb"]'
1776+
1777+
def test_is_closed_with_buffered_data(self):
1778+
"""is_closed() is False when closed but queue has data."""
1779+
from river.streams import Readable
1780+
1781+
r: Readable = Readable()
1782+
r._push_value({"val": 1})
1783+
r._trigger_close()
1784+
# Closed but not fully consumed
1785+
assert r.is_closed() is False
1786+
assert r._closed is True
1787+
1788+
@pytest.mark.asyncio
1789+
async def test_close_cancels_inflight_connect(self, server_url: str):
1790+
"""close() during handshake doesn't leak the websocket."""
1791+
transport = WebSocketClientTransport(
1792+
ws_url=server_url,
1793+
client_id=None,
1794+
server_id="SERVER",
1795+
codec=NaiveJsonCodec(),
1796+
)
1797+
transport.connect("SERVER")
1798+
# Let connection start but don't wait for completion
1799+
await asyncio.sleep(0)
1800+
await transport.close()
1801+
# No leaked sessions
1802+
assert len(transport.sessions) == 0

python-client/tests/test_session.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,112 @@ async def test_connect_on_invoke_false_no_reconnect(self, server_url: str):
390390
finally:
391391
# transport already closed above
392392
pass
393+
394+
395+
# =====================================================================
396+
# Regression: stale connect-task must not block fail-fast
397+
# =====================================================================
398+
399+
400+
class TestStaleConnectTask:
401+
@pytest.mark.asyncio
402+
async def test_done_connect_task_does_not_block_failfast(self):
403+
"""A completed (done) connect task in _connect_tasks must not
404+
prevent the fail-fast path from firing.
405+
406+
Regression: previously the check was `to not in _connect_tasks`,
407+
so a done task kept the entry alive and calls would hang instead
408+
of failing immediately when retries were exhausted.
409+
"""
410+
transport = WebSocketClientTransport(
411+
ws_url="ws://127.0.0.1:1", # unreachable
412+
client_id=None,
413+
server_id="STALE",
414+
codec=NaiveJsonCodec(),
415+
options=SessionOptions(
416+
connection_timeout_ms=100,
417+
handshake_timeout_ms=100,
418+
session_disconnect_grace_ms=200,
419+
),
420+
)
421+
transport.reconnect_on_connection_drop = False
422+
try:
423+
# Trigger a connect that will fail
424+
transport.connect("STALE")
425+
await wait_for_session_gone(transport, "STALE")
426+
427+
# The done task is still in _connect_tasks
428+
assert "STALE" in transport._connect_tasks
429+
assert transport._connect_tasks["STALE"].done()
430+
431+
# Exhaust the retry budget so connect() is a no-op
432+
transport._retry_budget.budget_consumed = (
433+
transport._retry_budget.attempt_budget_capacity
434+
)
435+
436+
# RPC must fail immediately, not hang
437+
client = RiverClient(
438+
transport, server_id="STALE", connect_on_invoke=True
439+
)
440+
result = await asyncio.wait_for(
441+
client.rpc("test", "add", {"n": 1}), timeout=1.0
442+
)
443+
assert result["ok"] is False
444+
assert result["payload"]["code"] == "UNEXPECTED_DISCONNECT"
445+
finally:
446+
await transport.close()
447+
448+
449+
# =====================================================================
450+
# Regression: grace period must not reset on each failed reconnect
451+
# =====================================================================
452+
453+
454+
class TestGracePeriodNotResetOnRetry:
455+
@pytest.mark.asyncio
456+
async def test_grace_period_not_extended_by_retries(self, server_url: str):
457+
"""Repeated connection failures must not restart the grace timer.
458+
459+
Regression: _on_connection_failed() unconditionally called
460+
start_grace_period(), which cancelled and restarted the timer
461+
on every retry, extending session lifetime far beyond
462+
session_disconnect_grace_ms.
463+
"""
464+
grace_ms = 400
465+
transport = WebSocketClientTransport(
466+
ws_url="ws://127.0.0.1:1", # unreachable
467+
client_id=None,
468+
server_id="GRACE",
469+
codec=NaiveJsonCodec(),
470+
options=SessionOptions(
471+
connection_timeout_ms=100,
472+
handshake_timeout_ms=100,
473+
session_disconnect_grace_ms=grace_ms,
474+
),
475+
)
476+
try:
477+
transport.connect("GRACE")
478+
479+
# Wait for at least one connection failure to set the grace period
480+
await wait_for(
481+
lambda: (
482+
(s := transport.sessions.get("GRACE")) is not None
483+
and s._grace_period_task is not None
484+
),
485+
timeout=2.0,
486+
)
487+
488+
session = transport.sessions["GRACE"]
489+
original_expiry = session._grace_expiry_time
490+
assert original_expiry is not None
491+
492+
# After further retries, the expiry time must not have moved forward
493+
await asyncio.sleep(0.2)
494+
session2 = transport.sessions.get("GRACE")
495+
if session2 is not None and session2._grace_expiry_time is not None:
496+
assert session2._grace_expiry_time <= original_expiry
497+
498+
# Session should be gone within grace_ms + generous margin
499+
await wait_for_session_gone(transport, "GRACE", timeout=3.0)
500+
finally:
501+
await transport.close()

0 commit comments

Comments
 (0)