Skip to content

Commit 630f0e8

Browse files
committed
more fixes
1 parent 41938fc commit 630f0e8

File tree

19 files changed

+369
-380
lines changed

19 files changed

+369
-380
lines changed

.github/workflows/ci.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,24 @@ jobs:
6060

6161
- name: Install Python dependencies
6262
working-directory: python-client
63-
run: pip install -e ".[dev]"
63+
run: pip install -e ".[dev]" ty
6464

6565
- name: Python lint
6666
working-directory: python-client
6767
run: |
6868
ruff check .
6969
ruff format --check .
7070
71+
- name: Python type check
72+
working-directory: python-client
73+
run: ty check river/
74+
7175
- name: Python tests
7276
working-directory: python-client
7377
run: python -m pytest tests/ -v
78+
79+
- name: Python type check generated clients
80+
working-directory: python-client
81+
run: |
82+
ty check tests/generated/
83+
ty check tests/test_codegen.py

python-client/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ testpaths = ["tests"]
2929

3030
[tool.ruff]
3131
target-version = "py310"
32+
exclude = ["tests/generated"]
3233

3334
[tool.ruff.lint]
3435
select = ["E", "F", "I", "W"]
3536

37+
[tool.ty.environment]
38+
extra-paths = ["tests"]
39+
3640
[tool.setuptools.packages.find]
3741
include = ["river*"]
3842

python-client/river/__init__.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
1-
"""River protocol v2.0 Python client implementation."""
1+
"""River protocol v2.0 Python client implementation.
22
3-
from river.client import RiverClient
3+
This client was generated with the assistance of AI (Claude).
4+
"""
5+
6+
from river.client import (
7+
ErrResult,
8+
OkResult,
9+
RiverClient,
10+
StreamResult,
11+
SubscriptionResult,
12+
UploadResult,
13+
)
414
from river.codec import BinaryCodec, NaiveJsonCodec
515
from river.streams import Readable, Writable
616
from river.transport import WebSocketClientTransport
717
from river.types import Err, Ok, TransportMessage
818

919
__all__ = [
1020
"RiverClient",
21+
"OkResult",
22+
"ErrResult",
23+
"StreamResult",
24+
"UploadResult",
25+
"SubscriptionResult",
1126
"WebSocketClientTransport",
1227
"NaiveJsonCodec",
1328
"BinaryCodec",

python-client/river/client.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import asyncio
1010
import logging
1111
from dataclasses import dataclass
12-
from typing import Any, Callable
12+
from typing import Any, Callable, Generic, Literal, TypeVar
13+
14+
from typing_extensions import TypedDict
1315

1416
from river.streams import Readable, Writable
1517
from river.transport import WebSocketClientTransport
@@ -29,28 +31,43 @@
2931

3032
logger = logging.getLogger(__name__)
3133

34+
T = TypeVar("T")
35+
TPayload = TypeVar("TPayload")
3236

33-
@dataclass
34-
class RpcResult:
35-
"""Result of an RPC call."""
3637

37-
ok: bool
38-
payload: Any
38+
class OkResult(TypedDict, Generic[TPayload]):
39+
"""Successful result from a procedure call."""
40+
41+
ok: Literal[True]
42+
payload: TPayload
43+
44+
45+
class ErrResult(TypedDict, Generic[TPayload]):
46+
"""Error result from a procedure call."""
47+
48+
ok: Literal[False]
49+
payload: TPayload
3950

4051

4152
@dataclass
42-
class StreamResult:
43-
"""Result of opening a stream procedure."""
53+
class StreamResult(Generic[T]):
54+
"""Result of opening a stream procedure.
4455
45-
req_writable: Writable
56+
Generic over the input type ``T`` written to ``req_writable``.
57+
"""
58+
59+
req_writable: Writable[T]
4660
res_readable: Readable
4761

4862

4963
@dataclass
50-
class UploadResult:
51-
"""Result of opening an upload procedure."""
64+
class UploadResult(Generic[T]):
65+
"""Result of opening an upload procedure.
5266
53-
req_writable: Writable
67+
Generic over the input type ``T`` written to ``req_writable``.
68+
"""
69+
70+
req_writable: Writable[T]
5471
finalize: Callable[[], Any] # async callable returning RpcResult
5572

5673

@@ -113,7 +130,7 @@ async def rpc(
113130
procedure_name: str,
114131
init: Any,
115132
abort_signal: asyncio.Event | None = None,
116-
) -> dict[str, Any]:
133+
) -> Any:
117134
"""Invoke an RPC procedure.
118135
119136
Returns the result dict: {"ok": True/False, "payload": ...}

python-client/river/codec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,18 @@ class BinaryCodec(Codec):
7878
name = "binary"
7979

8080
def to_buffer(self, obj: dict[str, Any]) -> bytes:
81-
import msgpack # type: ignore[import-untyped]
81+
import msgpack
8282

8383
return msgpack.packb(obj, use_bin_type=True, default=self._ext_encode)
8484

8585
def from_buffer(self, buf: bytes) -> dict[str, Any]:
86-
import msgpack # type: ignore[import-untyped]
86+
import msgpack
8787

8888
return msgpack.unpackb(buf, raw=False, ext_hook=self._ext_decode)
8989

9090
@staticmethod
9191
def _ext_encode(obj: Any) -> Any:
92-
import msgpack # type: ignore[import-untyped]
92+
import msgpack
9393

9494
if isinstance(obj, int) and (obj > _MSGPACK_INT_MAX or obj < _MSGPACK_INT_MIN):
9595
# Encode as string in extension type 0 (matches TS BigInt ext)
@@ -99,7 +99,7 @@ def _ext_encode(obj: Any) -> Any:
9999

100100
@staticmethod
101101
def _ext_decode(code: int, data: bytes) -> Any:
102-
import msgpack # type: ignore[import-untyped]
102+
import msgpack
103103

104104
if code == _BIGINT_EXT_TYPE:
105105
val = msgpack.unpackb(data, raw=False)

python-client/river/codegen/emitter.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@
2929
_env.filters["pascal"] = _to_pascal_case
3030

3131

32+
def _result_type(proc) -> str: # noqa: ANN001
33+
"""Build the typed result annotation for a procedure."""
34+
ok = f"OkResult[{proc.output_type.annotation}]"
35+
if proc.error_type:
36+
err = f"ErrResult[{proc.error_type.annotation} | ProtocolError]"
37+
else:
38+
err = "ErrResult[ProtocolError]"
39+
return f"{ok} | {err}"
40+
41+
42+
_env.filters["result_type"] = _result_type
43+
44+
3245
# ---------------------------------------------------------------------------
3346
# Helpers
3447
# ---------------------------------------------------------------------------
@@ -51,6 +64,9 @@ def _collect_used_type_names(svc: ServiceDef, ir: SchemaIR) -> list[str]:
5164
_extract_names(proc.init_type.annotation, td_names, names)
5265
if proc.input_type:
5366
_extract_names(proc.input_type.annotation, td_names, names)
67+
_extract_names(proc.output_type.annotation, td_names, names)
68+
if proc.error_type:
69+
_extract_names(proc.error_type.annotation, td_names, names)
5470

5571
return sorted(names)
5672

@@ -110,22 +126,20 @@ def render_service_client(svc: ServiceDef, ir: SchemaIR, import_prefix: str) ->
110126
type_names = _collect_used_type_names(svc, ir)
111127
types_module = "._types" if import_prefix == "." else f"{import_prefix}_types"
112128

113-
needs_readable = any(
114-
p.proc_type in ("stream", "subscription") for p in svc.procedures
115-
)
116-
needs_writable = any(p.proc_type in ("stream", "upload") for p in svc.procedures)
117-
118-
wrappers = [
119-
p for p in svc.procedures if p.proc_type in ("stream", "upload", "subscription")
120-
]
129+
proc_types = {p.proc_type for p in svc.procedures}
130+
has_rpc = "rpc" in proc_types
131+
has_stream = "stream" in proc_types
132+
has_upload = "upload" in proc_types
133+
has_subscription = "subscription" in proc_types
121134

122135
return _env.get_template("service_client.py.j2").render(
123136
service=svc,
124137
type_names=type_names,
125138
types_module=types_module,
126-
needs_readable=needs_readable,
127-
needs_writable=needs_writable,
128-
wrappers=wrappers,
139+
has_rpc=has_rpc,
140+
has_stream=has_stream,
141+
has_upload=has_upload,
142+
has_subscription=has_subscription,
129143
)
130144

131145

python-client/river/codegen/schema.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,23 @@ def _to_snake_case(s: str) -> str:
118118

119119

120120
def _safe_field_name(name: str) -> str:
121-
"""Ensure a field name is a valid Python identifier."""
122-
name = _sanitize_identifier(name)
121+
"""Ensure a field name is a valid Python identifier.
122+
123+
Raises ValueError if the name requires sanitization that would
124+
change it from its wire representation, since TypedDict keys must
125+
match the dict keys sent on the wire.
126+
"""
127+
sanitized = _sanitize_identifier(name)
128+
if sanitized != name:
129+
raise ValueError(
130+
f"schema property {name!r} is not a valid Python identifier "
131+
f"and cannot be represented in a TypedDict"
132+
)
123133
if keyword.iskeyword(name):
124-
return name + "_"
134+
raise ValueError(
135+
f"schema property {name!r} is a Python keyword "
136+
f"and cannot be used as a TypedDict field"
137+
)
125138
return name
126139

127140

0 commit comments

Comments
 (0)