Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions chatsky/conditions/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"""

import asyncio
from typing import Pattern, Union, List, cast
from typing import Literal, Pattern, Sequence, Union, List, cast
import logging
import re
from functools import cached_property

from pydantic import Field, computed_field, field_validator
from pydantic import Field, computed_field, field_validator, model_validator

from chatsky.core import BaseCondition, Context
from chatsky.core.message import Message, MessageInitTypes, CallbackQuery
Expand All @@ -27,7 +27,7 @@ class ExactMatch(BaseCondition):
"""
Check if :py:attr:`~.Context.last_request` matches :py:attr:`.match`.

If :py:attr:`.skip_none`, will not compare ``None`` fields of :py:attr:`.match`.
If :py:attr:`.skip_fields`, will allow skip matching the fields of :py:attr:`.match`.
"""

match: MessageInitTypes
Expand All @@ -36,26 +36,37 @@ class ExactMatch(BaseCondition):

Is initialized according to :py:data:`~.MessageInitTypes`.
"""
skip_none: bool = True
skip_fields: Sequence[Union[Literal["text", "attachments", "annotations", "misc", "origin"], str]] = Field(
default=["origin"]
)
"""
Whether fields set to ``None`` in :py:attr:`.match` should not be compared.
Listed fields should not be compared in :py:attr:`.match`.
"""

@field_validator("match", mode="before")
@classmethod
def validate_match(cls, value):
return Message.model_validate(value)

@model_validator(mode="after")
def skip_fields_validator(self):
extra_fields = set(self.skip_fields) - set(self.match.__dict__.keys())
if extra_fields:
raise ValueError(extra_fields)
else:
return self

async def call(self, ctx: Context) -> bool:
match: Message = cast(Message, self.match)

request = ctx.last_request
for field in match.model_fields:
match_value = match.__getattribute__(field)
if self.skip_none and match_value is None:
for field in match.__dict__:
if field in self.skip_fields:
continue
if field in request.model_fields.keys():
if request.__getattribute__(field) != match.__getattribute__(field):
match_value = match.__getattribute__(field)
if field in request.__dict__:
if request.__getattribute__(field) != match_value:
logger.debug(f"Request and match don't match in {field}")
return False
else:
return False
Expand Down
29 changes: 16 additions & 13 deletions chatsky/messengers/telegram/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
"""

message = Message()
message.attachments = list()
temp_attachments = list()

message.text = update.text or update.caption
if update.location is not None:
message.attachments += [Location(latitude=update.location.latitude, longitude=update.location.longitude)]
temp_attachments += [Location(latitude=update.location.latitude, longitude=update.location.longitude)]
if update.contact is not None:
message.attachments += [
temp_attachments += [
Contact(
phone_number=update.contact.phone_number,
first_name=update.contact.first_name,
Expand All @@ -193,7 +193,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
)
]
if update.invoice is not None:
message.attachments += [
temp_attachments += [
Invoice(
title=update.invoice.title,
description=update.invoice.description,
Expand All @@ -202,7 +202,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
)
]
if update.poll is not None:
message.attachments += [
temp_attachments += [
Poll(
question=update.poll.question,
options=[PollOption(text=option.text, votes=option.voter_count) for option in update.poll.options],
Expand All @@ -216,7 +216,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
)
]
if update.sticker is not None:
message.attachments += [
temp_attachments += [
Sticker(
id=update.sticker.file_id,
is_animated=update.sticker.is_animated,
Expand All @@ -230,7 +230,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
if update.audio.thumbnail is not None
else None
)
message.attachments += [
temp_attachments += [
Audio(
id=update.audio.file_id,
file_unique_id=update.audio.file_unique_id,
Expand All @@ -247,7 +247,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
if update.video.thumbnail is not None
else None
)
message.attachments += [
temp_attachments += [
Video(
id=update.video.file_id,
file_unique_id=update.video.file_unique_id,
Expand All @@ -265,7 +265,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
if update.animation.thumbnail is not None
else None
)
message.attachments += [
temp_attachments += [
Animation(
id=update.animation.file_id,
file_unique_id=update.animation.file_unique_id,
Expand All @@ -278,7 +278,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
)
]
if len(update.photo) > 0:
message.attachments += [
temp_attachments += [
Image(
id=picture.file_id,
file_unique_id=picture.file_unique_id,
Expand All @@ -293,7 +293,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
if update.document.thumbnail is not None
else None
)
message.attachments += [
temp_attachments += [
Document(
id=update.document.file_id,
file_unique_id=update.document.file_unique_id,
Expand All @@ -303,7 +303,7 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
)
]
if update.voice is not None:
message.attachments += [
temp_attachments += [
VoiceMessage(
id=update.voice.file_id,
file_unique_id=update.voice.file_unique_id,
Expand All @@ -316,14 +316,17 @@ def extract_message_from_telegram(self, update: TelegramMessage) -> Message:
if update.video_note.thumbnail is not None
else None
)
message.attachments += [
temp_attachments += [
VideoMessage(
id=update.video_note.file_id,
file_unique_id=update.video_note.file_unique_id,
thumbnail=thumbnail,
)
]

if temp_attachments:
message.attachments = temp_attachments

return message

async def cast_message_to_telegram_and_send(self, bot: ExtBot, chat_id: int, message: Message) -> None:
Expand Down
9 changes: 3 additions & 6 deletions tests/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,10 @@ def request_based_ctx(context_factory):
"condition,result",
[
(cnd.ExactMatch(match=Message(text="text", misc={"key": "value"})), True),
(cnd.ExactMatch(match=Message(text="text"), skip_none=True), True),
(cnd.ExactMatch(match=Message(text="text"), skip_none=False), False),
(cnd.ExactMatch(match="text", skip_none=True), True),
(cnd.ExactMatch(match=Message(text="smth"), skip_fields=["text", "misc"]), True),
(cnd.ExactMatch(match=Message(text="")), False),
(cnd.ExactMatch(match=Message(text="text", misc={"key": None})), False),
(cnd.ExactMatch(match=Message(), skip_none=True), True),
(cnd.ExactMatch(match={}, skip_none=True), True),
(cnd.ExactMatch(match=Message(text="text", misc={"key": None}), skip_fields=["misc"]), True),
(cnd.ExactMatch(match={}), False),
(cnd.ExactMatch(match=SubclassMessage(text="text", misc={"key": "value"}, additional_field="")), False),
],
)
Expand Down
Loading
Loading