Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 8c6066e

Browse files
authored
feat: validate server message (#345)
* chore: introduce `ServerMessage` and `ClientMessage` types `ServerMessage` represents all the different message types that a server might send to a client `ClientMessage` represents all the different message types that a client might send to a server these are there to validate the payload types, enforcing the correctness of the payloads * fix: improve messages schema now all of messages should be valid * feat: introduce `ServerMessage` parsing now, instead of parsing and dispatching the messages inside the `AsyncRealtimeChannel._trigger` function, first parse the message using pydantic's `ServerMessageAdapter` type adapter, and then dispatch the callbacks inside `_handle_message`, based on which message was parsed. this required the removal of the `_on` function, which was used to both define callbacks to messages, but also to define callbacks to event refs, which was very confusing. this behavior was used precisely to define callbacks that are called on acks. now, AsyncPush'es are saved in AsyncRealtimeChannel.messages_waiting_for_ack, and later the callbacks registered by `AsyncPush.receive` (now saved directly in the push object) are called when a `phx_reply` type message is received. some logic was introduce to avoid this message buffer from growing forever indefinetly, by removing these messages from the dictionary once the callback is called, or once the timeout is reached.
1 parent 6e4154e commit 8c6066e

File tree

8 files changed

+574
-443
lines changed

8 files changed

+574
-443
lines changed

realtime/_async/channel.py

Lines changed: 172 additions & 223 deletions
Large diffs are not rendered by default.

realtime/_async/client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from urllib.parse import urlencode, urlparse, urlunparse
88

99
import websockets
10+
from pydantic import ValidationError
1011
from websockets import connect
1112
from websockets.asyncio.client import ClientConnection
1213

1314
from ..exceptions import NotConnectedError
14-
from ..message import Message
15+
from ..message import Message, ServerMessageAdapter
1516
from ..transformers import http_endpoint_url
1617
from ..types import (
1718
DEFAULT_HEARTBEAT_INTERVAL,
@@ -99,13 +100,14 @@ async def _listen(self) -> None:
99100
async for msg in self._ws_connection:
100101
logger.info(f"receive: {msg!r}")
101102

102-
message = Message.model_validate_json(msg)
103-
channel = self.channels.get(message.topic)
104-
105-
if channel:
106-
channel._trigger(
107-
message.event, dict(**message.payload), message.ref
108-
)
103+
try:
104+
message = ServerMessageAdapter.validate_json(msg)
105+
except ValidationError as e:
106+
logger.error(f"Unrecognized message format {msg!r}\n{e}")
107+
continue
108+
logger.info(f"parsed message as {message}")
109+
if channel := self.channels.get(message.topic):
110+
channel._handle_message(message)
109111
except websockets.exceptions.ConnectionClosedError as e:
110112
await self._on_connect_error(e)
111113

realtime/_async/presence.py

Lines changed: 37 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,11 @@
2121

2222

2323
class AsyncRealtimePresence:
24-
def __init__(self, channel, opts: Optional[PresenceOpts] = None):
25-
self.channel = channel
24+
def __init__(self):
2625
self.state: RealtimePresenceState = {}
27-
self.pending_diffs: List[RawPresenceDiff] = []
28-
self.join_ref: Optional[str] = None
2926
self.on_join_callback: Optional[PresenceOnJoinCallback] = None
3027
self.on_leave_callback: Optional[PresenceOnLeaveCallback] = None
3128
self.on_sync_callback: Optional[Callable[[], None]] = None
32-
self.on_auth_success_callback: Optional[Callable[[], None]] = None
33-
self.on_auth_failure_callback: Optional[Callable[[], None]] = None
34-
# Initialize with default events if not provided
35-
events = (
36-
opts.events
37-
if opts
38-
else PresenceEvents(state="presence_state", diff="presence_diff")
39-
)
40-
# Set up event listeners for presence state and diff
41-
self.channel._on(events.state, callback=self._on_state_event)
42-
self.channel._on(events.diff, callback=self._on_diff_event)
43-
self.channel._on("phx_auth", callback=self._on_auth_event)
4429

4530
def on_join(self, callback: PresenceOnJoinCallback):
4631
self.on_join_callback = callback
@@ -51,67 +36,43 @@ def on_leave(self, callback: PresenceOnLeaveCallback):
5136
def on_sync(self, callback: Callable[[], None]):
5237
self.on_sync_callback = callback
5338

54-
def on_auth_success(self, callback: Callable[[], None]):
55-
self.on_auth_success_callback = callback
39+
def _on_state_event(self, payload: RawPresenceState):
40+
state = AsyncRealtimePresence._transform_state(payload)
41+
self.state = self._sync_state(state)
5642

57-
def on_auth_failure(self, callback: Callable[[], None]):
58-
self.on_auth_failure_callback = callback
59-
60-
def _on_state_event(self, payload: RawPresenceState, *args):
61-
self.join_ref = self.channel.join_ref
62-
self.state = self._sync_state(self.state, payload)
63-
64-
for diff in self.pending_diffs:
65-
self.state = self._sync_diff(self.state, diff)
66-
self.pending_diffs = []
6743
if self.on_sync_callback:
6844
self.on_sync_callback()
6945

70-
def _on_diff_event(self, payload: RawPresenceDiff, *args):
71-
if self.in_pending_sync_state():
72-
self.pending_diffs.append(payload)
73-
else:
74-
self.state = self._sync_diff(self.state, payload)
75-
if self.on_sync_callback:
76-
self.on_sync_callback()
77-
78-
def _on_auth_event(self, payload: Dict[str, Any], *args):
79-
if payload.get("status") == "ok":
80-
if self.on_auth_success_callback:
81-
self.on_auth_success_callback()
82-
else:
83-
if self.on_auth_failure_callback:
84-
self.on_auth_failure_callback()
46+
def _on_diff_event(self, payload: RawPresenceDiff):
47+
joins = AsyncRealtimePresence._transform_state(payload["joins"])
48+
leaves = AsyncRealtimePresence._transform_state(payload["leaves"])
49+
self.state = self._sync_diff(joins, leaves)
50+
if self.on_sync_callback:
51+
self.on_sync_callback()
8552

8653
def _sync_state(
8754
self,
88-
current_state: RealtimePresenceState,
89-
new_state: Union[RawPresenceState, RealtimePresenceState],
55+
new_state: RealtimePresenceState,
9056
) -> RealtimePresenceState:
91-
state = {key: list(value) for key, value in current_state.items()}
92-
transformed_state = AsyncRealtimePresence._transform_state(new_state)
93-
94-
joins: Dict[str, Any] = {}
95-
leaves: Dict[str, Any] = {
96-
k: v for k, v in state.items() if k not in transformed_state
97-
}
57+
joins = {}
58+
leaves = {k: v for k, v in self.state.items() if k not in new_state}
9859

99-
for key, value in transformed_state.items():
100-
current_presences = state.get(key, [])
60+
for key, value in new_state.items():
61+
current_presences = self.state.get(key, [])
10162

10263
if len(current_presences) > 0:
103-
new_presence_refs = {presence.get("presence_ref") for presence in value}
64+
new_presence_refs = {presence["presence_ref"] for presence in value}
10465
cur_presence_refs = {
105-
presence.get("presence_ref") for presence in current_presences
66+
presence["presence_ref"] for presence in current_presences
10667
}
10768

10869
joined_presences = [
109-
p for p in value if p.get("presence_ref") not in cur_presence_refs
70+
p for p in value if p["presence_ref"] not in cur_presence_refs
11071
]
11172
left_presences = [
11273
p
11374
for p in current_presences
114-
if p.get("presence_ref") not in new_presence_refs
75+
if p["presence_ref"] not in new_presence_refs
11576
]
11677

11778
if joined_presences:
@@ -121,19 +82,14 @@ def _sync_state(
12182
else:
12283
joins[key] = value
12384

124-
return self._sync_diff(state, {"joins": joins, "leaves": leaves})
85+
return self._sync_diff(joins, leaves)
12586

12687
def _sync_diff(
127-
self,
128-
state: RealtimePresenceState,
129-
diff: Union[RawPresenceDiff, PresenceDiff],
88+
self, joins: RealtimePresenceState, leaves: RealtimePresenceState
13089
) -> RealtimePresenceState:
131-
joins = AsyncRealtimePresence._transform_state(diff.get("joins", {}))
132-
leaves = AsyncRealtimePresence._transform_state(diff.get("leaves", {}))
133-
13490
for key, new_presences in joins.items():
135-
current_presences = state.get(key, [])
136-
state[key] = new_presences
91+
current_presences = self.state.get(key, [])
92+
self.state[key] = new_presences
13793

13894
if len(current_presences) > 0:
13995
joined_presence_refs = {
@@ -144,13 +100,13 @@ def _sync_diff(
144100
for presence in current_presences
145101
if presence.get("presence_ref") not in joined_presence_refs
146102
)
147-
state[key] = cur_presences + state[key]
103+
self.state[key] = cur_presences + self.state[key]
148104

149105
if self.on_join_callback:
150106
self.on_join_callback(key, current_presences, new_presences)
151107

152108
for key, left_presences in leaves.items():
153-
current_presences = state.get(key, [])
109+
current_presences = self.state.get(key, [])
154110

155111
if len(current_presences) == 0:
156112
break
@@ -163,22 +119,19 @@ def _sync_diff(
163119
for presence in current_presences
164120
if presence.get("presence_ref") not in presence_refs_to_remove
165121
]
166-
state[key] = current_presences
122+
self.state[key] = current_presences
167123

168124
if self.on_leave_callback:
169125
self.on_leave_callback(key, current_presences, left_presences)
170126

171127
if len(current_presences) == 0:
172-
del state[key]
173-
174-
return state
128+
del self.state[key]
175129

176-
def in_pending_sync_state(self) -> bool:
177-
return self.join_ref is None or self.join_ref != self.channel.join_ref
130+
return self.state
178131

179132
@staticmethod
180133
def _transform_state(
181-
state: Union[RawPresenceState, RealtimePresenceState],
134+
state: RawPresenceState,
182135
) -> RealtimePresenceState:
183136
"""
184137
Transform the raw presence state into a standardized RealtimePresenceState format.
@@ -216,16 +169,12 @@ def _transform_state(
216169
"""
217170
new_state: RealtimePresenceState = {}
218171
for key, presences in state.items():
219-
if isinstance(presences, dict) and "metas" in presences:
220-
new_state[key] = []
221-
222-
for presence in presences["metas"]:
223-
if "phx_ref_prev" in presence:
224-
del presence["phx_ref_prev"]
225-
new_presence: Presence = {"presence_ref": presence.pop("phx_ref")}
226-
new_presence.update(presence)
227-
new_state[key].append(new_presence)
228-
229-
else:
230-
new_state[key] = presences
172+
new_state[key] = []
173+
174+
for presence in presences["metas"]:
175+
if "phx_ref_prev" in presence:
176+
del presence["phx_ref_prev"]
177+
new_presence: Presence = {"presence_ref": presence.pop("phx_ref")}
178+
new_presence.update(presence)
179+
new_state[key].append(new_presence)
231180
return new_state

0 commit comments

Comments
 (0)