From 402c9a046741562650a2ed421b777ff3cace02f2 Mon Sep 17 00:00:00 2001 From: chocy Date: Sat, 21 Mar 2026 18:30:30 +0100 Subject: [PATCH] fix(sync): stop prior long-poll on reconnection --- src/api/client/sync/v5.rs | 11 +++++----- src/service/sync/mod.rs | 42 ++++++++++++++++++++++++++------------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index e1d46d26d..634c11c6b 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -101,7 +101,7 @@ pub(crate) async fn sync_events_v5_route( .unwrap_or(0); let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref()); - let conn_val = services + let (conn_val, superseded) = services .sync .load_or_init_connection(&conn_key) .await; @@ -196,11 +196,10 @@ pub(crate) async fn sync_events_v5_route( if timeout == 0 || services.server.is_stopping() - || timeout_at(stop_at, watchers) - .boxed() - .await - .is_err() - { + || tokio::select! { + () = superseded.notified() => true, + result = timeout_at(stop_at, watchers).boxed() => result.is_err(), + } { response.pos = conn.next_batch.to_string().into(); trace!(conn.globalsince, conn.next_batch, "timeout; empty response {response:?}"); conn.store(&services.sync, &conn_key); diff --git a/src/service/sync/mod.rs b/src/service/sync/mod.rs index 13264a6b1..33c6fa9c9 100644 --- a/src/service/sync/mod.rs +++ b/src/service/sync/mod.rs @@ -14,7 +14,7 @@ use ruma::{ }, }; use serde::{Deserialize, Serialize}; -use tokio::sync::Mutex as TokioMutex; +use tokio::sync::{Mutex as TokioMutex, Notify}; use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore}; use tuwunel_database::{Cbor, Deserialized, Map}; @@ -57,9 +57,10 @@ pub struct Room { pub roomsince: u64, } -type Connections = TokioMutex>; pub type ConnectionVal = Arc>; +pub type ConnectionEntry = (ConnectionVal, Arc); pub type ConnectionKey = (OwnedUserId, Option, Option); +type Connections = TokioMutex>; pub type Subscriptions = BTreeMap; pub type Lists = BTreeMap; @@ -102,22 +103,22 @@ pub async fn clear_connections( device_id: Option<&DeviceId>, conn_id: Option<&ConnectionId>, ) { - self.connections - .lock() - .await - .retain(|(conn_user_id, conn_device_id, conn_conn_id), _| { + self.connections.lock().await.retain( + |(conn_user_id, conn_device_id, conn_conn_id), (_, notify)| { let retain = user_id.is_none_or(is_equal_to!(conn_user_id)) && (device_id.is_none() || device_id == conn_device_id.as_deref()) && (conn_id.is_none() || conn_id == conn_conn_id.as_ref()); if !retain { + notify.notify_waiters(); self.db .userdeviceconnid_conn .del((conn_user_id, conn_device_id, conn_conn_id)); } retain - }); + }, + ); } #[implement(Service)] @@ -126,16 +127,24 @@ pub async fn drop_connection(&self, key: &ConnectionKey) { let mut cache = self.connections.lock().await; self.db.userdeviceconnid_conn.del(key); - cache.remove(key); + if let Some((_, notify)) = cache.remove(key) { + notify.notify_waiters(); + } } #[implement(Service)] #[tracing::instrument(level = "debug", skip(self))] -pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal { +pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionEntry { let mut cache = self.connections.lock().await; match cache.entry(key.clone()) { - | Entry::Occupied(val) => val.get().clone(), + | Entry::Occupied(mut entry) => { + entry.get().1.notify_waiters(); + let conn = entry.get().0.clone(); + let notify = Arc::new(Notify::new()); + entry.insert((conn.clone(), notify.clone())); + (conn, notify) + }, | Entry::Vacant(val) => { let conn = self .db @@ -149,7 +158,9 @@ pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVa .map(Arc::new) .unwrap_or_default(); - val.insert(conn).clone() + let notify = Arc::new(Notify::new()); + let (conn, _) = val.insert((conn, notify.clone())); + (conn.clone(), notify) }, } } @@ -160,7 +171,7 @@ pub async fn load_connection(&self, key: &ConnectionKey) -> Result Ok(val.get().clone()), + | Entry::Occupied(val) => Ok(val.get().0.clone()), | Entry::Vacant(val) => self .db .userdeviceconnid_conn @@ -170,7 +181,10 @@ pub async fn load_connection(&self, key: &ConnectionKey) -> Result Result