diff --git a/Cargo.lock b/Cargo.lock index cc614a77..e91644a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -392,6 +392,7 @@ dependencies = [ "http", "hyper", "image", + "itertools 0.12.1", "jsonwebtoken", "lazy_static", "lru-cache", diff --git a/Cargo.toml b/Cargo.toml index 636f32e7..3a847583 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,6 +112,9 @@ tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_suppo lazy_static = "1.4.0" async-trait = "0.1.68" +# Used to make working with iterators easier, was already a transitive depdendency +itertools = "0.12" + sd-notify = { version = "0.4.1", optional = true } [dependencies.rocksdb] diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 86ac5950..8b570b9d 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -55,14 +55,21 @@ pub async fn join_room_by_id_route( services() .rooms .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), + .servers_invite_via(&body.room_id)? + .unwrap_or( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(), + ), ); servers.push( @@ -101,14 +108,21 @@ pub async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), + .servers_invite_via(&room_id)? + .unwrap_or( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(), + ), ); servers.push( @@ -1261,6 +1275,7 @@ pub(crate) async fn invite_helper<'a>( room_version: room_version_id.clone(), event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, + via: services().rooms.state_cache.servers_route_via(room_id).ok(), }, ) .await?; @@ -1423,6 +1438,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option Result<()> { "User is not invited.", ))?; - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let servers: HashSet<_> = services() + .rooms + .state_cache + .servers_invite_via(&room_id)? + .map(|servers| HashSet::from_iter(servers)) + .unwrap_or( + invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(), + ); for remote_server in servers { let make_leave_response = services() diff --git a/src/api/server_server.rs b/src/api/server_server.rs index fa7f1315..ea27880d 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1669,6 +1669,15 @@ pub async fn create_invite_route( )); } + if let Some(via) = &body.via { + if via.is_empty() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "via field must not be empty.", + )); + } + } + let mut signed_event = utils::to_canonical_object(&body.event) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; @@ -1744,6 +1753,7 @@ pub async fn create_invite_route( MembershipState::Invite, &sender, Some(invite_state), + body.via, true, )?; } diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 49e3842b..8f884e64 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,5 +1,6 @@ use std::{collections::HashSet, sync::Arc}; +use itertools::Itertools; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, @@ -21,7 +22,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { } fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xff); + + let mut roomuser_id = roomid_prefix.clone(); roomuser_id.push(0xff); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -36,6 +41,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; + if self.roomuserid_joined.scan_prefix(roomid_prefix).count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() + == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + Ok(()) } @@ -44,6 +59,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, ) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xff); @@ -67,12 +83,30 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; + if let Some(servers) = invite_via { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + prev_servers.append(&mut servers); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xff][..]); + + self.roomid_inviteviaservers + .insert(&room_id.as_bytes().to_vec(), &servers)?; + } + Ok(()) } fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xff); + + let mut roomuser_id = roomid_prefix.clone(); roomuser_id.extend_from_slice(user_id.as_bytes()); let mut userroom_id = user_id.as_bytes().to_vec(); @@ -92,6 +126,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + if self.roomuserid_joined.scan_prefix(roomid_prefix).count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() + == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + Ok(()) } @@ -604,4 +648,41 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } + + #[tracing::instrument(skip(self))] + fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + let room_id = room_id.as_bytes().to_vec(); + + self.roomid_inviteviaservers + .get(&room_id)? + .map(|servers| { + let state = serde_json::from_slice(&servers) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn add_servers_invite_via( + &self, + room_id: &RoomId, + servers: &Vec, + ) -> Result<()> { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + prev_servers.append(&mut servers); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xff][..]); + + self.roomid_inviteviaservers + .insert(&room_id.as_bytes().to_vec(), &servers)?; + + Ok(()) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 41da857c..30b3cc1a 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -161,6 +161,8 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, + pub(super) roomid_inviteviaservers: Arc, + pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, @@ -368,6 +370,8 @@ impl KeyValueDatabase { global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, + roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?, + pdu_cache: Mutex::new(LruCache::new( config .pdu_cache_capacity diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f6581bb5..351675df 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -86,6 +86,7 @@ impl Service { membership, &pdu.sender, None, + None, false, )?; } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index b511919a..f3c0bb74 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -15,6 +15,7 @@ pub trait Data: Send + Sync { user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, ) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; @@ -106,4 +107,14 @@ pub trait Data: Send + Sync { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + /// Gets the servers to either accept or decline invites via for a given room. + fn servers_invite_via(&self, room_id: &RoomId) -> Result>>; + + /// Add the given servers the list to accept or decline invites via for a given room. + fn add_servers_invite_via( + &self, + room_id: &RoomId, + servers: &Vec, + ) -> Result<()>; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c108695d..78e2e616 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -3,14 +3,19 @@ use std::{collections::HashSet, sync::Arc}; pub use data::Data; +use itertools::Itertools; use ruma::{ events::{ direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, - room::{create::RoomCreateEventContent, member::MembershipState}, + room::{ + create::RoomCreateEventContent, member::MembershipState, + power_levels::RoomPowerLevelsEventContent, + }, AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, }, + int, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; @@ -32,6 +37,7 @@ impl Service { membership: MembershipState, sender: &UserId, last_state: Option>>, + invite_via: Option>, update_joined_count: bool, ) -> Result<()> { // Keep track what remote users exist by adding them as "deactivated" users @@ -176,7 +182,8 @@ impl Service { return Ok(()); } - self.db.mark_as_invited(user_id, room_id, last_state)?; + self.db + .mark_as_invited(user_id, room_id, last_state, invite_via)?; } MembershipState::Leave | MembershipState::Ban => { self.db.mark_as_left(user_id, room_id)?; @@ -350,4 +357,54 @@ impl Service { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + + #[tracing::instrument(skip(self))] + pub fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + self.db.servers_invite_via(room_id) + } + + /// Gets up to three servers that are likely to be in the room in the distant future. + /// + /// See https://spec.matrix.org/v1.10/appendices/#routing + #[tracing::instrument(skip(self))] + pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { + let most_powerful_user_server = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .map(|pdu| { + serde_json::from_str(pdu.content.get()).map( + |conent: RoomPowerLevelsEventContent| { + conent + .users + .iter() + .max_by_key(|(_, power)| power) + .and_then(|x| if x.1 >= &int!(50) { Some(x) } else { None }) + .map(|(user, power)| user.server_name().to_owned()) + }, + ) + }) + .transpose() + .map_err(|e| Error::bad_database("Invalid power levels event content in database"))? + .flatten(); + + let mut servers = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .counts_by(|user| user.server_name()) + .iter() + .sorted_by_key(|(_, users)| users) + .map(|(server, _)| server.to_owned().to_owned()) + .rev() + .take(3) + .collect_vec(); + + if let Some(server) = most_powerful_user_server { + servers.insert(0, server); + servers.truncate(3); + } + Ok(servers) + } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 379d97fe..002b2b3f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -427,6 +427,7 @@ impl Service { content.membership, &pdu.sender, invite_state, + None, true, )?; }