Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions aiocqhttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ async def handler(event):
self._sync_api = None
self._bus = EventBus()
self._before_sending_funcs = set()
self._on_wsr_receive_func = None
self._loop = None

self._server_app = Quart(import_name, **(server_app_kwargs or {}))
Expand Down Expand Up @@ -454,7 +455,8 @@ async def startup():

def on_websocket_connection(self, func: Callable) -> Callable:
"""
注册 WebSocket 连接元事件处理函数,等价于 ``on_meta_event('lifecycle.connect')``,例如:
注册 WebSocket 连接元事件处理函数,等价于 ``on_meta_event('lifecycle.connect')``
注意:若 OneBot 端不可信,请使用 bot.server_app.before_websocket

```py
@bot.on_websocket_connection
Expand All @@ -465,6 +467,25 @@ async def handler(event):
"""
return self.on_meta_event('lifecycle.connect')(func)

def on_websocket_receive(self, func: Callable) -> Callable:
"""
注册 WebSocket 调用 receive() 后对 payload 的后处理函数,
可用于对 WebSocket 每次上报内容的鉴权/验证
暂时仅支持注册1个

注:对于 HTTP 上报,你可以使用 bot.server_app.before_request;
对于 WebSocket 的首次连接,你可以使用 bot.server_app.before_websocket 验证其 Header

```py
@bot.on_websocket_receive
async def handler(payload: bytes):
return payload
"""
if self._on_wsr_receive_func:
raise RuntimeError("`on_websocket_receive` can only register once.")
self._on_wsr_receive_func = func
return func

async def _handle_http_event(self) -> Response:
if self._secret:
if 'X-Signature' not in request.headers:
Expand Down Expand Up @@ -517,8 +538,11 @@ async def _handle_wsr_event(self) -> None:
self._add_wsr_event_client()
try:
while True:
payload = await websocket.receive()
if self._on_wsr_receive_func:
payload = await self._on_wsr_receive_func(payload)
try:
payload = json.loads(await websocket.receive())
payload = json.loads(payload)
except ValueError:
payload = None

Expand All @@ -534,10 +558,14 @@ async def _handle_wsr_api(self) -> None:
self._add_wsr_api_client()
try:
while True:
payload = await websocket.receive()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

突然一下有点迷惑了,这里返回的是 str 还是 bytes,下面用的是 json.loads,但 _on_wsr_receive_func 的参数却是 bytes

另外就是,_on_wsr_receive_func 如何表达“payload 不合法”呢?是不是 563 行调用的时候接个异常、或检查是否为 None 会比较好;或者干脆让 _on_wsr_receive_func 不能修改 payload,只需要返回 bool 就行了(为什么会需要修改 payload 呢)。

on_websocket_receive 注释里应该写 _on_wsr_receive_func 中判断到不合法 payload 报错的例子,毕竟这才是它真正的用处,而不是直接原样返回的例子。

Copy link
Member

@stdrc stdrc Sep 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对了,docs/changelog.md 也需要改下,小标题先用 master

if self._on_wsr_receive_func:
payload = await self._on_wsr_receive_func(payload)
try:
ResultStore.add(json.loads(await websocket.receive()))
payload = json.loads(payload)
except ValueError:
pass
ResultStore.add(payload)
finally:
self._remove_wsr_api_client()

Expand All @@ -546,8 +574,11 @@ async def _handle_wsr_universal(self) -> None:
self._add_wsr_event_client()
try:
while True:
payload = await websocket.receive()
if self._on_wsr_receive_func:
payload = await self._on_wsr_receive_func(payload)
try:
payload = json.loads(await websocket.receive())
payload = json.loads(payload)
except ValueError:
payload = None

Expand Down