diff --git a/aiocqhttp/__init__.py b/aiocqhttp/__init__.py index 94827d9..1a02801 100644 --- a/aiocqhttp/__init__.py +++ b/aiocqhttp/__init__.py @@ -130,6 +130,7 @@ async def handler(event): self._sync_api = None self._bus = EventBus() self._before_sending_funcs = set() + self._on_wsr_receive_func = None self._loop = None self._server_app = Quart(import_name, **(server_app_kwargs or {})) @@ -454,7 +455,8 @@ async def startup(): def on_websocket_connection(self, func: Callable) -> Callable: """ - 注册 WebSocket 连接元事件处理函数,等价于 ``on_meta_event('lifecycle.connect')``,例如: + 注册 WebSocket 连接元事件处理函数,等价于 ``on_meta_event('lifecycle.connect')`` + 注意:若 OneBot 端不可信,请使用 bot.server_app.before_websocket ```py @bot.on_websocket_connection @@ -465,6 +467,25 @@ async def handler(event): """ return self.on_meta_event('lifecycle.connect')(func) + def on_websocket_receive(self, func: Callable) -> Callable: + """ + 注册 WebSocket 调用 receive() 后对 payload 的后处理函数, + 可用于对 WebSocket 每次上报内容的鉴权/验证 + 暂时仅支持注册1个 + + 注:对于 HTTP 上报,你可以使用 bot.server_app.before_request; + 对于 WebSocket 的首次连接,你可以使用 bot.server_app.before_websocket 验证其 Header + + ```py + @bot.on_websocket_receive + async def handler(payload: bytes): + return payload + """ + if self._on_wsr_receive_func: + raise RuntimeError("`on_websocket_receive` can only register once.") + self._on_wsr_receive_func = func + return func + async def _handle_http_event(self) -> Response: if self._secret: if 'X-Signature' not in request.headers: @@ -517,8 +538,11 @@ async def _handle_wsr_event(self) -> None: self._add_wsr_event_client() try: while True: + payload = await websocket.receive() + if self._on_wsr_receive_func: + payload = await self._on_wsr_receive_func(payload) try: - payload = json.loads(await websocket.receive()) + payload = json.loads(payload) except ValueError: payload = None @@ -534,10 +558,14 @@ async def _handle_wsr_api(self) -> None: self._add_wsr_api_client() try: while True: + payload = await websocket.receive() + if self._on_wsr_receive_func: + payload = await self._on_wsr_receive_func(payload) try: - ResultStore.add(json.loads(await websocket.receive())) + payload = json.loads(payload) except ValueError: pass + ResultStore.add(payload) finally: self._remove_wsr_api_client() @@ -546,8 +574,11 @@ async def _handle_wsr_universal(self) -> None: self._add_wsr_event_client() try: while True: + payload = await websocket.receive() + if self._on_wsr_receive_func: + payload = await self._on_wsr_receive_func(payload) try: - payload = json.loads(await websocket.receive()) + payload = json.loads(payload) except ValueError: payload = None