Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 284 additions & 3 deletions crates/telegram/src/outbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ impl TelegramOutbound {
{
Ok(message) => Ok(message.id),
Err(e) => {
let plain_chunk = telegram_html_to_plain_text(chunk);
warn!(
account_id,
chat_id = to,
Expand All @@ -120,7 +121,7 @@ impl TelegramOutbound {
);
let message = self
.run_telegram_request_with_retry(account_id, to, "send message (plain)", || {
let mut plain_req = bot.send_message(chat_id, chunk);
let mut plain_req = bot.send_message(chat_id, &plain_chunk);
if silent {
plain_req = plain_req.disable_notification(true);
}
Expand Down Expand Up @@ -159,6 +160,7 @@ impl TelegramOutbound {
if is_message_not_modified_error(&e) {
return Ok(());
}
let plain_chunk = telegram_html_to_plain_text(chunk);
warn!(
account_id,
chat_id = to,
Expand All @@ -167,7 +169,7 @@ impl TelegramOutbound {
);
match self
.run_telegram_request_with_retry(account_id, to, "edit message (plain)", || {
let plain_req = bot.edit_message_text(chat_id, message_id, chunk);
let plain_req = bot.edit_message_text(chat_id, message_id, &plain_chunk);
async move { plain_req.await }
})
.await
Expand Down Expand Up @@ -255,6 +257,106 @@ fn is_message_not_modified_error(error: &RequestError) -> bool {
matches!(error, RequestError::Api(ApiError::MessageNotModified))
}

fn telegram_html_to_plain_text(html: &str) -> String {
let mut plain = String::with_capacity(html.len());
let mut chars = html.chars().peekable();

while let Some(ch) = chars.next() {
if ch == '<' {
let mut tag = String::from(ch);
let mut closed = false;
for next in chars.by_ref() {
tag.push(next);
if next == '>' {
closed = true;
break;
}
}

if !closed {
plain.push_str(&tag);
break;
}

let normalized = tag
.trim_start_matches('<')
.trim_end_matches('>')
.trim()
.to_ascii_lowercase();
if is_plain_text_line_break_tag(&normalized) && !plain.ends_with('\n') {
plain.push('\n');
}
continue;
}

if ch == '&' {
let mut entity = String::from(ch);
let mut terminated = false;
while let Some(&next) = chars.peek() {
entity.push(next);
chars.next();
if next == ';' {
terminated = true;
break;
}
if entity.len() > 12 {
break;
}
}

if terminated && let Some(decoded) = decode_html_entity(&entity) {
plain.push_str(&decoded);
continue;
}

plain.push_str(&entity);
continue;
}

plain.push(ch);
}

plain.trim().to_string()
}

fn is_plain_text_line_break_tag(tag: &str) -> bool {
let tag_name = tag
.trim_start_matches('/')
.trim_end_matches('/')
.split_whitespace()
.next()
.unwrap_or("");

matches!(tag_name, "blockquote" | "br" | "div" | "li" | "p" | "pre")
}

fn decode_html_entity(entity: &str) -> Option<String> {
match entity {
"&amp;" => Some("&".to_string()),
"&lt;" => Some("<".to_string()),
"&gt;" => Some(">".to_string()),
"&quot;" => Some("\"".to_string()),
"&apos;" | "&#39;" => Some("'".to_string()),
"&nbsp;" | "&#160;" => Some(" ".to_string()),
_ => decode_numeric_html_entity(entity),
}
}

fn decode_numeric_html_entity(entity: &str) -> Option<String> {
let value = entity
.strip_prefix("&#x")
.and_then(|hex| hex.strip_suffix(';'))
.and_then(|hex| u32::from_str_radix(hex, 16).ok())
.or_else(|| {
entity
.strip_prefix("&#")
.and_then(|decimal| decimal.strip_suffix(';'))
.and_then(|decimal| decimal.parse::<u32>().ok())
})?;

char::from_u32(value).map(|ch| ch.to_string())
}

trait RequestResultExt<T> {
fn channel_context(self, context: &'static str) -> Result<T>;
}
Expand Down Expand Up @@ -969,9 +1071,93 @@ impl ChannelStreamOutbound for TelegramOutbound {
mod tests {
use {
super::*,
std::{collections::HashMap, sync::Arc, time::Duration},
axum::{Json, Router, extract::State, http::StatusCode, routing::post},
moltis_channels::gating::DmPolicy,
secrecy::Secret,
serde::{Deserialize, Serialize},
std::{
collections::HashMap,
sync::{Arc, Mutex},
time::Duration,
},
tokio::sync::oneshot,
tokio_util::sync::CancellationToken,
};

use crate::{config::TelegramAccountConfig, otp::OtpState, state::AccountState};

#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
struct SendMessageRequest {
chat_id: i64,
text: String,
#[serde(default)]
parse_mode: Option<String>,
}

#[derive(Debug, Serialize)]
struct TelegramApiResponse {
ok: bool,
result: TelegramMessageResult,
}

#[derive(Debug, Serialize)]
struct TelegramMessageResult {
message_id: i64,
date: i64,
chat: TelegramChat,
text: String,
}

#[derive(Debug, Serialize)]
struct TelegramChat {
id: i64,
#[serde(rename = "type")]
chat_type: String,
}

#[derive(Clone)]
struct MockTelegramApi {
requests: Arc<Mutex<Vec<SendMessageRequest>>>,
}

async fn send_message_handler(
State(state): State<MockTelegramApi>,
Json(body): Json<SendMessageRequest>,
) -> (StatusCode, Json<serde_json::Value>) {
state
.requests
.lock()
.expect("lock requests")
.push(body.clone());

if body.parse_mode.as_deref() == Some("HTML") {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({
"ok": false,
"error_code": 400,
"description": "Bad Request: can't parse entities: unsupported start tag"
})),
);
}

(
StatusCode::OK,
Json(serde_json::json!(TelegramApiResponse {
ok: true,
result: TelegramMessageResult {
message_id: 1,
date: 0,
chat: TelegramChat {
id: body.chat_id,
chat_type: "private".to_string(),
},
text: body.text,
},
})),
)
}

#[tokio::test]
async fn send_location_unknown_account_returns_error() {
let accounts: AccountStateMap = Arc::new(std::sync::RwLock::new(HashMap::new()));
Expand Down Expand Up @@ -1004,6 +1190,22 @@ mod tests {
assert_eq!(retry_after_duration(&err), None);
}

#[test]
fn telegram_html_to_plain_text_strips_tags_and_decodes_entities() {
let plain = telegram_html_to_plain_text(
"<b>Hello</b> &amp; <i>world</i><br><code>&lt;ok&gt;</code>",
);

assert_eq!(plain, "Hello & world\n<ok>");
}

#[test]
fn telegram_html_to_plain_text_decodes_numeric_entities() {
let plain = telegram_html_to_plain_text("it&#39;s &#x1F642;");

assert_eq!(plain, "it's 🙂");
}

#[test]
fn is_message_not_modified_error_detects_variant() {
let err = RequestError::Api(ApiError::MessageNotModified);
Expand Down Expand Up @@ -1037,6 +1239,85 @@ mod tests {
));
}

#[tokio::test]
async fn send_html_fallback_sends_plain_text_without_raw_tags() {
let recorded_requests = Arc::new(Mutex::new(Vec::<SendMessageRequest>::new()));
let mock_api = MockTelegramApi {
requests: Arc::clone(&recorded_requests),
};
let app = Router::new()
.route("/{*path}", post(send_message_handler))
.with_state(mock_api);

let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
.await
.expect("bind test listener");
let addr = listener.local_addr().expect("local addr");
let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
let server = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async {
let _ = shutdown_rx.await;
})
.await
.expect("serve mock telegram api");
});
tokio::time::sleep(std::time::Duration::from_millis(50)).await;

let api_url = reqwest::Url::parse(&format!("http://{addr}/")).expect("parse api url");
let bot = Bot::new("test-token").set_api_url(api_url);

let accounts: AccountStateMap = Arc::new(std::sync::RwLock::new(HashMap::new()));
let outbound = Arc::new(TelegramOutbound {
accounts: Arc::clone(&accounts),
});
let account_id = "test-account";

{
let mut map = accounts.write().expect("accounts write lock");
map.insert(account_id.to_string(), AccountState {
bot: bot.clone(),
bot_username: Some("test_bot".to_string()),
account_id: account_id.to_string(),
config: TelegramAccountConfig {
token: Secret::new("test-token".to_string()),
dm_policy: DmPolicy::Open,
..Default::default()
},
outbound: Arc::clone(&outbound),
cancel: CancellationToken::new(),
message_log: None,
event_sink: None,
otp: Mutex::new(OtpState::new(300)),
});
}

outbound
.send_html(
account_id,
"42",
"<b>Hello</b> &amp; <i>world</i><br><code>&lt;ok&gt;</code>",
None,
)
.await
.expect("send html");

{
let requests = recorded_requests.lock().expect("requests lock");
assert_eq!(requests.len(), 2, "expected HTML send plus plain fallback");
assert_eq!(requests[0].parse_mode.as_deref(), Some("HTML"));
assert_eq!(
requests[0].text,
"<b>Hello</b> &amp; <i>world</i><br><code>&lt;ok&gt;</code>"
);
assert_eq!(requests[1].parse_mode, None);
assert_eq!(requests[1].text, "Hello & world\n<ok>");
}

let _ = shutdown_tx.send(());
server.await.expect("server join");
}

#[test]
fn stream_completion_notification_skips_when_already_notified_by_chunks() {
assert!(!should_send_stream_completion_notification(
Expand Down
Loading