Skip to content

Commit 50d8f45

Browse files
committed
fix
1 parent 208c047 commit 50d8f45

File tree

9 files changed

+258
-40
lines changed

9 files changed

+258
-40
lines changed

.github/release-drafter-python.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
name-template: 'river-client-py/v$RESOLVED_VERSION'
2+
tag-template: 'river-client-py/v$RESOLVED_VERSION'
3+
filter-by-commitish: true
4+
include-paths:
5+
- 'python-client/'
6+
categories:
7+
- title: '🚀 Features'
8+
labels:
9+
- 'feature'
10+
- 'enhancement'
11+
- 'python'
12+
- title: '🐛 Bug Fixes'
13+
labels:
14+
- 'fix'
15+
- 'bugfix'
16+
- 'bug'
17+
- title: '🧰 Maintenance'
18+
label: 'chore'
19+
- title: '🤖 Dependencies'
20+
label: 'dependencies'
21+
change-template: '- $TITLE @$AUTHOR (#$NUMBER)'
22+
change-title-escapes: '\<*_&'
23+
version-resolver:
24+
major:
25+
labels:
26+
- 'major'
27+
minor:
28+
labels:
29+
- 'minor'
30+
patch:
31+
labels:
32+
- 'patch'
33+
default: patch
34+
template: |
35+
## Changes
36+
37+
$CHANGES
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
name: Build and Publish Python Package
2+
3+
on:
4+
release:
5+
types: [published]
6+
7+
jobs:
8+
build-and-publish:
9+
# Only run for Python releases
10+
if: startsWith(github.event.release.tag_name, 'river-client-py/')
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Install uv
16+
uses: astral-sh/setup-uv@v3
17+
with:
18+
enable-cache: true
19+
20+
- name: Set up Python
21+
uses: actions/setup-python@v5
22+
with:
23+
python-version: '3.12'
24+
25+
- name: Check if version already published
26+
working-directory: python-client
27+
id: check
28+
run: |
29+
version=$(python -c "
30+
import tomllib
31+
with open('pyproject.toml', 'rb') as f:
32+
print(tomllib.load(f)['project']['version'])
33+
")
34+
echo "version=$version" >> "$GITHUB_OUTPUT"
35+
if uv pip install --dry-run "river-client==$version" 2>/dev/null; then
36+
echo "skip=true" >> "$GITHUB_OUTPUT"
37+
else
38+
echo "skip=false" >> "$GITHUB_OUTPUT"
39+
fi
40+
41+
- name: Build and publish
42+
if: steps.check.outputs.skip == 'false'
43+
working-directory: python-client
44+
run: |
45+
uv build
46+
UV_PUBLISH_TOKEN="${{ secrets.PYPI_TOKEN }}" \
47+
uv publish
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
name: Release Drafter (Python)
2+
3+
on:
4+
workflow_dispatch: {}
5+
push:
6+
branches:
7+
- main
8+
paths:
9+
- 'python-client/**'
10+
pull_request:
11+
types: [opened, reopened, synchronize]
12+
paths:
13+
- 'python-client/**'
14+
pull_request_target:
15+
types: [opened, reopened, synchronize]
16+
paths:
17+
- 'python-client/**'
18+
19+
permissions:
20+
contents: read
21+
22+
jobs:
23+
update_release_draft:
24+
permissions:
25+
contents: write
26+
pull-requests: write
27+
runs-on: ubuntu-latest
28+
steps:
29+
- uses: release-drafter/release-drafter@v5
30+
with:
31+
config-name: release-drafter-python.yml
32+
env:
33+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

python-client/river/codec.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,35 +84,16 @@ def from_buffer(self, buf: bytes) -> tuple[bool, TransportMessage | str]:
8484
"""Deserialize bytes to a TransportMessage.
8585
8686
Returns (True, TransportMessage) on success, (False, error_reason) on failure.
87+
Validation of required fields and types is handled by
88+
:meth:`TransportMessage.from_dict`.
8789
"""
8890
try:
8991
raw = self._codec.from_buffer(buf)
9092
if not isinstance(raw, dict):
9193
return False, f"Expected dict, got {type(raw).__name__}"
92-
# Validate required fields
93-
required = ("id", "from", "to", "seq", "ack", "payload", "streamId")
94-
for f in required:
95-
if f not in raw:
96-
return False, f"Missing required field: {f}"
97-
# Validate field types to prevent downstream crashes
98-
if not isinstance(raw["seq"], int):
99-
return False, (
100-
f"Field 'seq' must be int, got {type(raw['seq']).__name__}"
101-
)
102-
if not isinstance(raw["ack"], int):
103-
return False, (
104-
f"Field 'ack' must be int, got {type(raw['ack']).__name__}"
105-
)
106-
if not isinstance(raw["id"], str):
107-
return False, (
108-
f"Field 'id' must be str, got {type(raw['id']).__name__}"
109-
)
110-
if not isinstance(raw["streamId"], str):
111-
return False, (
112-
f"Field 'streamId' must be str, "
113-
f"got {type(raw['streamId']).__name__}"
114-
)
11594
msg = TransportMessage.from_dict(raw)
11695
return True, msg
96+
except (KeyError, TypeError) as e:
97+
return False, str(e)
11798
except Exception as e:
11899
return False, f"Failed to deserialize message: {e}"

python-client/river/codegen/schema.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,16 @@ def convert(self, raw: dict) -> SchemaIR:
163163
self._typedicts = []
164164
self._id_to_name = {}
165165
services: list[ServiceDef] = []
166+
seen_modules: dict[str, str] = {} # sanitized name → wire name
166167
for svc_name, svc_data in raw.get("services", {}).items():
168+
module_name = _sanitize_identifier(svc_name)
169+
if module_name in seen_modules:
170+
raise ValueError(
171+
f"services {seen_modules[module_name]!r} and "
172+
f"{svc_name!r} both map to Python module "
173+
f"{module_name!r}_client.py"
174+
)
175+
seen_modules[module_name] = svc_name
167176
svc_def = self._convert_service(svc_name, svc_data)
168177
services.append(svc_def)
169178

@@ -172,8 +181,17 @@ def convert(self, raw: dict) -> SchemaIR:
172181
def _convert_service(self, name: str, data: dict) -> ServiceDef:
173182
class_name = _to_pascal_case(name)
174183
procedures: list[ProcedureDef] = []
184+
seen_py_names: dict[str, str] = {} # py_name → wire name
175185
for proc_name, proc_data in data.get("procedures", {}).items():
176186
proc_def = self._convert_procedure(class_name, proc_name, proc_data)
187+
if proc_def.py_name in seen_py_names:
188+
raise ValueError(
189+
f"service {name!r}: procedures "
190+
f"{seen_py_names[proc_def.py_name]!r} and "
191+
f"{proc_name!r} both map to Python method "
192+
f"{proc_def.py_name!r}"
193+
)
194+
seen_py_names[proc_def.py_name] = proc_name
177195
procedures.append(proc_def)
178196
return ServiceDef(
179197
name=name,

python-client/river/transport.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -157,17 +157,7 @@ def get_status(self) -> str:
157157

158158
def _get_loop(self) -> asyncio.AbstractEventLoop:
159159
if self._loop is None:
160-
try:
161-
self._loop = asyncio.get_running_loop()
162-
except RuntimeError:
163-
try:
164-
loop = asyncio.get_event_loop()
165-
if loop.is_closed():
166-
raise RuntimeError("closed")
167-
self._loop = loop
168-
except RuntimeError:
169-
self._loop = asyncio.new_event_loop()
170-
asyncio.set_event_loop(self._loop)
160+
self._loop = asyncio.get_running_loop()
171161
return self._loop
172162

173163
# --- Event API ---

python-client/river/types.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,42 @@ def to_dict(self) -> dict[str, Any]:
8888

8989
@classmethod
9090
def from_dict(cls, d: dict[str, Any]) -> TransportMessage:
91-
"""Deserialize from a wire format dict."""
91+
"""Deserialize from a wire format dict.
92+
93+
Raises ``TypeError`` if required fields have wrong types.
94+
"""
95+
required_str = {"id": "id", "from": "from", "streamId": "streamId"}
96+
for wire_key, label in required_str.items():
97+
if wire_key not in d:
98+
raise KeyError(f"Missing required field: {label}")
99+
if not isinstance(d[wire_key], str):
100+
raise TypeError(
101+
f"Field '{label}' must be str, "
102+
f"got {type(d[wire_key]).__name__}"
103+
)
104+
105+
required_int = {"seq": "seq", "ack": "ack"}
106+
for wire_key, label in required_int.items():
107+
if wire_key not in d:
108+
raise KeyError(f"Missing required field: {label}")
109+
if not isinstance(d[wire_key], int):
110+
raise TypeError(
111+
f"Field '{label}' must be int, "
112+
f"got {type(d[wire_key]).__name__}"
113+
)
114+
115+
if "to" not in d:
116+
raise KeyError("Missing required field: to")
117+
if "payload" not in d:
118+
raise KeyError("Missing required field: payload")
119+
120+
control_flags = d.get("controlFlags", 0)
121+
if not isinstance(control_flags, int):
122+
raise TypeError(
123+
f"Field 'controlFlags' must be int, "
124+
f"got {type(control_flags).__name__}"
125+
)
126+
92127
return cls(
93128
id=d["id"],
94129
from_=d["from"],
@@ -97,7 +132,7 @@ def from_dict(cls, d: dict[str, Any]) -> TransportMessage:
97132
ack=d["ack"],
98133
payload=d["payload"],
99134
stream_id=d["streamId"],
100-
control_flags=d.get("controlFlags", 0),
135+
control_flags=control_flags,
101136
service_name=d.get("serviceName"),
102137
procedure_name=d.get("procedureName"),
103138
tracing=d.get("tracing"),

python-client/tests/test_codegen.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,83 @@ def test_valid_schema_passes(self):
508508
assert [f.name for f in td.fields] == ["userId", "count"]
509509

510510

511+
class TestNameCollisions:
512+
"""Codegen detects and rejects name collisions."""
513+
514+
def test_procedure_name_collision_raises(self):
515+
"""Two procedures that map to the same snake_case name are rejected."""
516+
from river.codegen.schema import SchemaConverter
517+
518+
raw = {
519+
"services": {
520+
"svc": {
521+
"procedures": {
522+
"fooBar": {
523+
"type": "rpc",
524+
"init": {"type": "object", "properties": {}},
525+
"output": {"type": "object", "properties": {}},
526+
},
527+
"foo_bar": {
528+
"type": "rpc",
529+
"init": {"type": "object", "properties": {}},
530+
"output": {"type": "object", "properties": {}},
531+
},
532+
}
533+
}
534+
}
535+
}
536+
converter = SchemaConverter()
537+
with pytest.raises(ValueError, match="foo_bar"):
538+
converter.convert(raw)
539+
540+
def test_service_module_collision_raises(self):
541+
"""Two services that map to the same module name are rejected."""
542+
from river.codegen.schema import SchemaConverter
543+
544+
raw = {
545+
"services": {
546+
"foo-bar": {
547+
"procedures": {},
548+
},
549+
"foo_bar": {
550+
"procedures": {},
551+
},
552+
}
553+
}
554+
converter = SchemaConverter()
555+
with pytest.raises(ValueError, match="foo_bar"):
556+
converter.convert(raw)
557+
558+
def test_no_collision_passes(self):
559+
"""Distinct names that don't collide work fine."""
560+
from river.codegen.schema import SchemaConverter
561+
562+
raw = {
563+
"services": {
564+
"alpha": {
565+
"procedures": {
566+
"doX": {
567+
"type": "rpc",
568+
"init": {"type": "object", "properties": {}},
569+
"output": {"type": "object", "properties": {}},
570+
},
571+
"doY": {
572+
"type": "rpc",
573+
"init": {"type": "object", "properties": {}},
574+
"output": {"type": "object", "properties": {}},
575+
},
576+
}
577+
},
578+
"beta": {
579+
"procedures": {},
580+
},
581+
}
582+
}
583+
converter = SchemaConverter()
584+
ir = converter.convert(raw)
585+
assert len(ir.services) == 2
586+
587+
511588
# ---------------------------------------------------------------------------
512589
# Complex type tests
513590
# ---------------------------------------------------------------------------

python-client/tests/test_e2e.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1860,16 +1860,16 @@ def fake_inject(carrier):
18601860

18611861

18621862
class TestEagerConnectSync:
1863-
def test_eager_connect_does_not_raise_outside_loop(self):
1863+
def test_eager_connect_raises_outside_loop(self):
18641864
"""Constructing with eagerly_connect=True outside an event loop
1865-
should not raise RuntimeError."""
1865+
raises RuntimeError rather than silently binding to a dead loop."""
18661866
transport = WebSocketClientTransport(
18671867
ws_url="ws://127.0.0.1:1",
18681868
server_id="SERVER",
18691869
codec=BinaryCodec(),
18701870
)
1871-
# This used to raise "no running event loop"
1872-
RiverClient(transport, server_id="SERVER", eagerly_connect=True)
1871+
with pytest.raises(RuntimeError, match="no running event loop"):
1872+
RiverClient(transport, server_id="SERVER", eagerly_connect=True)
18731873

18741874

18751875
# =====================================================================

0 commit comments

Comments
 (0)