diff --git a/rocketchat_API/APISections/chat.py b/rocketchat_API/APISections/chat.py index 1df28c0..9544f7a 100644 --- a/rocketchat_API/APISections/chat.py +++ b/rocketchat_API/APISections/chat.py @@ -5,21 +5,56 @@ class RocketChatChat(RocketChatBase): def chat_post_message(self, text, room_id=None, channel=None, **kwargs): """Posts a new chat message.""" - if room_id: - if text: - return self.call_api_post( - "chat.postMessage", roomId=room_id, text=text, kwargs=kwargs - ) - return self.call_api_post("chat.postMessage", roomId=room_id, kwargs=kwargs) + if not (channel or room_id): + raise RocketMissingParamException("roomId or channel required") + + text = "" if text is None else text + + payload = {"text": text, **kwargs} if channel: - if text: - return self.call_api_post( - "chat.postMessage", channel=channel, text=text, kwargs=kwargs - ) - return self.call_api_post( - "chat.postMessage", channel=channel, kwargs=kwargs - ) - raise RocketMissingParamException("roomId or channel required") + payload["channel"] = channel + if room_id: + payload["roomId"] = room_id + + self._sanitize_payload_text_fields(payload) + + return self.call_api_post("chat.postMessage", kwargs=payload) + + @staticmethod + def _sanitize_text(value: str) -> str: + """Return value with common double-escaped control sequences normalized.""" + if not isinstance(value, str): + return value + return ( + value.replace("\\n", "\n") + .replace("\\r", "\r") + .replace("\\t", "\t") + .replace("\\b", "\b") + .replace("\\f", "\f") + ) + + @staticmethod + def _sanitize_payload_text_fields(payload: dict) -> dict: + """ + In-place sanitize of typical text fields in chat payloads: + - payload["text"] + - payload["attachments"][i]["text"] + + Returns the mutated payload for convenience. + """ + if not isinstance(payload, dict): + return payload + + if "text" in payload and isinstance(payload["text"], str): + payload["text"] = RocketChatChat._sanitize_text(payload["text"]) + + attachments = payload.get("attachments") + if isinstance(attachments, list): + for att in attachments: + if isinstance(att, dict) and isinstance(att.get("text"), str): + att["text"] = RocketChatChat._sanitize_text(att["text"]) + + return payload def chat_send_message(self, message): if "rid" in message: diff --git a/tests/test_chat.py b/tests/test_chat.py index 1dc3684..eb39082 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -227,3 +227,41 @@ def test_chat_get_mentioned_messages(logged_rocket): assert "messages" in chat_get_mentioned_messages assert len(chat_get_mentioned_messages.get("messages")) > 0 assert chat_get_mentioned_messages.get("messages")[0].get("msg") == "hello @user1" + + +def test_chat_post_sanitizes_text_and_attachments(logged_rocket): + # message text contains double-escaped sequences + msg_text = "hello\\nworld\\t!" + att_text = "line1\\nline2" + + chat_post_message = logged_rocket.chat_post_message( + msg_text, + channel="GENERAL", + attachments=[{"color": "#00ff00", "text": att_text}], + ).json() + + assert chat_post_message.get("success") + mid = chat_post_message["message"]["_id"] + + # retrieve and verify that sequences became real control chars + got = logged_rocket.chat_get_message(msg_id=mid).json() + assert got.get("success") + + posted = got["message"] + assert posted["msg"] == "hello\nworld\t!" + assert posted["attachments"][0]["text"] == "line1\nline2" + + +def test_chat_update_sanitizes_text(logged_rocket): + # create a message first + mid = ( + logged_rocket.chat_post_message("seed", channel="GENERAL") + .json()["message"]["_id"] + ) + # update with escaped content + upd = logged_rocket.chat_update( + room_id="GENERAL", msg_id=mid, text="foo\\nbar" + ).json() + assert upd.get("success") + assert upd["message"]["msg"] == "foo\nbar" +