Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions src/api/client/sync/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub(crate) async fn sync_events_v5_route(
.unwrap_or(0);

let conn_key = into_connection_key(sender_user, sender_device, request.conn_id.as_deref());
let conn_val = services
let (conn_val, superseded) = services
.sync
.load_or_init_connection(&conn_key)
.await;
Expand Down Expand Up @@ -196,11 +196,10 @@ pub(crate) async fn sync_events_v5_route(

if timeout == 0
|| services.server.is_stopping()
|| timeout_at(stop_at, watchers)
.boxed()
.await
.is_err()
{
|| tokio::select! {
() = superseded.notified() => true,
result = timeout_at(stop_at, watchers).boxed() => result.is_err(),
} {
response.pos = conn.next_batch.to_string().into();
trace!(conn.globalsince, conn.next_batch, "timeout; empty response {response:?}");
conn.store(&services.sync, &conn_key);
Expand Down
42 changes: 28 additions & 14 deletions src/service/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use ruma::{
},
};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::{Mutex as TokioMutex, Notify};
use tuwunel_core::{Result, at, debug, err, implement, is_equal_to, utils::stream::TryIgnore};
use tuwunel_database::{Cbor, Deserialized, Map};

Expand Down Expand Up @@ -57,9 +57,10 @@ pub struct Room {
pub roomsince: u64,
}

type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionVal>>;
pub type ConnectionVal = Arc<TokioMutex<Connection>>;
pub type ConnectionEntry = (ConnectionVal, Arc<Notify>);
pub type ConnectionKey = (OwnedUserId, Option<OwnedDeviceId>, Option<ConnectionId>);
type Connections = TokioMutex<BTreeMap<ConnectionKey, ConnectionEntry>>;

pub type Subscriptions = BTreeMap<OwnedRoomId, request::ListConfig>;
pub type Lists = BTreeMap<ListId, request::List>;
Expand Down Expand Up @@ -102,22 +103,22 @@ pub async fn clear_connections(
device_id: Option<&DeviceId>,
conn_id: Option<&ConnectionId>,
) {
self.connections
.lock()
.await
.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
self.connections.lock().await.retain(
|(conn_user_id, conn_device_id, conn_conn_id), (_, notify)| {
let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
&& (device_id.is_none() || device_id == conn_device_id.as_deref())
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());

if !retain {
notify.notify_waiters();
self.db
.userdeviceconnid_conn
.del((conn_user_id, conn_device_id, conn_conn_id));
}

retain
});
},
);
}

#[implement(Service)]
Expand All @@ -126,16 +127,24 @@ pub async fn drop_connection(&self, key: &ConnectionKey) {
let mut cache = self.connections.lock().await;

self.db.userdeviceconnid_conn.del(key);
cache.remove(key);
if let Some((_, notify)) = cache.remove(key) {
notify.notify_waiters();
}
}

#[implement(Service)]
#[tracing::instrument(level = "debug", skip(self))]
pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionEntry {
let mut cache = self.connections.lock().await;

match cache.entry(key.clone()) {
| Entry::Occupied(val) => val.get().clone(),
| Entry::Occupied(mut entry) => {
entry.get().1.notify_waiters();
let conn = entry.get().0.clone();
let notify = Arc::new(Notify::new());
entry.insert((conn.clone(), notify.clone()));
(conn, notify)
},
| Entry::Vacant(val) => {
let conn = self
.db
Expand All @@ -149,7 +158,9 @@ pub async fn load_or_init_connection(&self, key: &ConnectionKey) -> ConnectionVa
.map(Arc::new)
.unwrap_or_default();

val.insert(conn).clone()
let notify = Arc::new(Notify::new());
let (conn, _) = val.insert((conn, notify.clone()));
(conn.clone(), notify)
},
}
}
Expand All @@ -160,7 +171,7 @@ pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal
let mut cache = self.connections.lock().await;

match cache.entry(key.clone()) {
| Entry::Occupied(val) => Ok(val.get().clone()),
| Entry::Occupied(val) => Ok(val.get().0.clone()),
| Entry::Vacant(val) => self
.db
.userdeviceconnid_conn
Expand All @@ -170,7 +181,10 @@ pub async fn load_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal
.map(at!(0))
.map(TokioMutex::new)
.map(Arc::new)
.map(|conn| val.insert(conn).clone()),
.map(|conn| {
let notify = Arc::new(Notify::new());
val.insert((conn, notify)).0.clone()
}),
}
}

Expand All @@ -181,7 +195,7 @@ pub async fn get_loaded_connection(&self, key: &ConnectionKey) -> Result<Connect
.lock()
.await
.get(key)
.cloned()
.map(|(conn, _)| conn.clone())
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
}

Expand Down
Loading