Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 165 additions & 8 deletions crates/openfang-channels/src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, watch, RwLock};
use tracing::{info, warn};
use tracing::{debug, info, warn};
use zeroize::Zeroizing;

const SYNC_TIMEOUT_MS: u64 = 30000;
Expand All @@ -35,6 +35,8 @@ pub struct MatrixAdapter {
shutdown_rx: watch::Receiver<bool>,
/// Sync token for resuming /sync.
since_token: Arc<RwLock<Option<String>>>,
/// Whether to auto-accept room invites.
auto_accept_invites: bool,
}

impl MatrixAdapter {
Expand All @@ -44,6 +46,7 @@ impl MatrixAdapter {
user_id: String,
access_token: String,
allowed_rooms: Vec<String>,
auto_accept_invites: bool,
) -> Self {
let (shutdown_tx, shutdown_rx) = watch::channel(false);
Self {
Expand All @@ -55,6 +58,7 @@ impl MatrixAdapter {
shutdown_tx: Arc::new(shutdown_tx),
shutdown_rx,
since_token: Arc::new(RwLock::new(None)),
auto_accept_invites,
}
}

Expand Down Expand Up @@ -116,12 +120,84 @@ impl MatrixAdapter {
Ok(user_id)
}

#[allow(dead_code)]
#[cfg(test)]
fn is_allowed_room(&self, room_id: &str) -> bool {
self.allowed_rooms.is_empty() || self.allowed_rooms.iter().any(|r| r == room_id)
}
}

/// Accept a room invite by calling POST /_matrix/client/v3/rooms/{room_id}/join.
async fn accept_invite(
client: &reqwest::Client,
homeserver: &str,
access_token: &str,
room_id: &str,
) {
let url = format!("{homeserver}/_matrix/client/v3/rooms/{room_id}/join");
match client
.post(&url)
.bearer_auth(access_token)
.json(&serde_json::json!({}))
.send()
.await
{
Ok(resp) if resp.status().is_success() => {
info!("Matrix: auto-accepted invite to {room_id}");
}
Ok(resp) => {
let status = resp.status();
warn!("Matrix: failed to accept invite to {room_id}: {status}");
}
Err(e) => {
warn!("Matrix: error accepting invite to {room_id}: {e}");
}
}
}

/// Get the number of joined members in a room.
async fn get_room_member_count(
client: &reqwest::Client,
homeserver: &str,
access_token: &str,
room_id: &str,
) -> Option<usize> {
let url = format!("{homeserver}/_matrix/client/v3/rooms/{room_id}/joined_members");
let resp = client
.get(&url)
.bearer_auth(access_token)
.send()
.await
.ok()?;
if !resp.status().is_success() {
return None;
}
let body: serde_json::Value = resp.json().await.ok()?;
body["joined"].as_object().map(|m| m.len())
}

/// Do an initial /sync with timeout=0 to get the since token without processing events.
/// This prevents replaying old messages when the adapter first connects.
async fn initial_sync(
client: &reqwest::Client,
homeserver: &str,
access_token: &str,
) -> Option<String> {
let url = format!(
"{homeserver}/_matrix/client/v3/sync?timeout=0&filter={{\"room\":{{\"timeline\":{{\"limit\":0}}}}}}"
);
let resp = client
.get(&url)
.bearer_auth(access_token)
.send()
.await
.ok()?;
if !resp.status().is_success() {
return None;
}
let body: serde_json::Value = resp.json().await.ok()?;
body["next_batch"].as_str().map(String::from)
}

#[async_trait]
impl ChannelAdapter for MatrixAdapter {
fn name(&self) -> &str {
Expand All @@ -143,14 +219,32 @@ impl ChannelAdapter for MatrixAdapter {
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
let homeserver = self.homeserver_url.clone();
let access_token = self.access_token.clone();
let user_id = self.user_id.clone();
// Use the validated user ID from /whoami instead of the config value.
// Matrix server delegation or casing differences can cause self.user_id
// to not match the sender field in timeline events, making the bot
// process its own replies in an infinite loop (see #757).
let user_id = validated_user;
let allowed_rooms = self.allowed_rooms.clone();
let client = self.client.clone();
let since_token = Arc::clone(&self.since_token);
let mut shutdown_rx = self.shutdown_rx.clone();
let auto_accept = self.auto_accept_invites;

// FIX #4: Do an initial sync to get the since token, skipping old messages.
if since_token.read().await.is_none() {
if let Some(token) = initial_sync(&client, &homeserver, access_token.as_str()).await {
info!("Matrix: initial sync complete, skipping old messages");
*since_token.write().await = Some(token);
}
}

tokio::spawn(async move {
let mut backoff = Duration::from_secs(1);
// Track recently seen event IDs to prevent duplicate processing
// on sync token races or reconnects.
let mut seen_events: std::collections::HashSet<String> =
std::collections::HashSet::new();
const MAX_SEEN: usize = 500;

loop {
// Build /sync URL
Expand All @@ -168,7 +262,7 @@ impl ChannelAdapter for MatrixAdapter {
info!("Matrix adapter shutting down");
break;
}
result = client.get(&url).bearer_auth(&*access_token).send() => {
result = client.get(&url).bearer_auth(access_token.as_str()).send() => {
match result {
Ok(r) => r,
Err(e) => {
Expand Down Expand Up @@ -203,6 +297,24 @@ impl ChannelAdapter for MatrixAdapter {
*since_token.write().await = Some(next.to_string());
}

// FIX #1: Auto-accept room invites.
if auto_accept {
if let Some(invites) = body["rooms"]["invite"].as_object() {
for (room_id, _invite_data) in invites {
if !allowed_rooms.is_empty()
&& !allowed_rooms.iter().any(|r| r == room_id)
{
debug!(
"Matrix: ignoring invite to {room_id} (not in allowed_rooms)"
);
continue;
}
accept_invite(&client, &homeserver, access_token.as_str(), room_id)
.await;
}
}
}

// Process room events
if let Some(rooms) = body["rooms"]["join"].as_object() {
for (room_id, room_data) in rooms {
Expand All @@ -223,6 +335,21 @@ impl ChannelAdapter for MatrixAdapter {
continue; // Skip own messages
}

// Dedup: skip events we've already processed.
let event_id_str =
event["event_id"].as_str().unwrap_or("").to_string();
if !event_id_str.is_empty() {
if seen_events.contains(&event_id_str) {
debug!("Matrix: skipping duplicate event {event_id_str}");
continue;
}
seen_events.insert(event_id_str.clone());
// Prevent unbounded growth
if seen_events.len() > MAX_SEEN {
seen_events.clear();
}
}

let content = event["content"]["body"].as_str().unwrap_or("");
if content.is_empty() {
continue;
Expand All @@ -243,11 +370,38 @@ impl ChannelAdapter for MatrixAdapter {
ChannelContent::Text(content.to_string())
};

let event_id = event["event_id"].as_str().unwrap_or("").to_string();
// FIX #2: Detect @mentions in message text.
let mut metadata = HashMap::new();
if content.contains(&user_id) {
metadata.insert(
"was_mentioned".to_string(),
serde_json::json!(true),
);
}

// FIX #3: Determine if room is a DM (2 members) or group.
let is_group = get_room_member_count(
&client,
&homeserver,
access_token.as_str(),
room_id,
)
.await
.map(|count| count > 2)
.unwrap_or(true);

// For DMs, auto-set was_mentioned so dm_policy works.
if !is_group {
metadata.insert(
"was_mentioned".to_string(),
serde_json::json!(true),
);
metadata.insert("is_dm".to_string(), serde_json::json!(true));
}

let channel_msg = ChannelMessage {
channel: ChannelType::Matrix,
platform_message_id: event_id,
platform_message_id: event_id_str,
sender: ChannelUser {
platform_id: room_id.clone(),
display_name: sender.to_string(),
Expand All @@ -256,9 +410,9 @@ impl ChannelAdapter for MatrixAdapter {
content: msg_content,
target_agent: None,
timestamp: Utc::now(),
is_group: true,
is_group,
thread_id: None,
metadata: HashMap::new(),
metadata,
};

if tx.send(channel_msg).await.is_err() {
Expand Down Expand Up @@ -330,6 +484,7 @@ mod tests {
"@bot:matrix.org".to_string(),
"access_token".to_string(),
vec![],
false,
);
assert_eq!(adapter.name(), "matrix");
}
Expand All @@ -341,6 +496,7 @@ mod tests {
"@bot:matrix.org".to_string(),
"token".to_string(),
vec!["!room1:matrix.org".to_string()],
false,
);
assert!(adapter.is_allowed_room("!room1:matrix.org"));
assert!(!adapter.is_allowed_room("!room2:matrix.org"));
Expand All @@ -350,6 +506,7 @@ mod tests {
"@bot:matrix.org".to_string(),
"token".to_string(),
vec![],
false,
);
assert!(open.is_allowed_room("!any:matrix.org"));
}
Expand Down