feat: unstable support for msc4125

This commit is contained in:
Matthias Ahouansou 2024-04-10 21:37:55 +01:00
parent f16bff2466
commit 270b32e3db
No known key found for this signature in database
10 changed files with 222 additions and 29 deletions

1
Cargo.lock generated
View file

@ -392,6 +392,7 @@ dependencies = [
"http", "http",
"hyper", "hyper",
"image", "image",
"itertools 0.12.1",
"jsonwebtoken", "jsonwebtoken",
"lazy_static", "lazy_static",
"lru-cache", "lru-cache",

View file

@ -112,6 +112,9 @@ tikv-jemallocator = { version = "0.5.0", features = ["unprefixed_malloc_on_suppo
lazy_static = "1.4.0" lazy_static = "1.4.0"
async-trait = "0.1.68" async-trait = "0.1.68"
# Used to make working with iterators easier, was already a transitive depdendency
itertools = "0.12"
sd-notify = { version = "0.4.1", optional = true } sd-notify = { version = "0.4.1", optional = true }
[dependencies.rocksdb] [dependencies.rocksdb]

View file

@ -52,6 +52,11 @@ pub async fn join_room_by_id_route(
let mut servers = Vec::new(); // There is no body.server_name for /roomId/join let mut servers = Vec::new(); // There is no body.server_name for /roomId/join
servers.extend( servers.extend(
services()
.rooms
.state_cache
.servers_invite_via(&body.room_id)?
.unwrap_or(
services() services()
.rooms .rooms
.state_cache .state_cache
@ -62,7 +67,9 @@ pub async fn join_room_by_id_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned())) .filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()), .map(|user| user.server_name().to_owned())
.collect(),
),
); );
servers.push( servers.push(
@ -98,6 +105,11 @@ pub async fn join_room_by_id_or_alias_route(
Ok(room_id) => { Ok(room_id) => {
let mut servers = body.server_name.clone(); let mut servers = body.server_name.clone();
servers.extend( servers.extend(
services()
.rooms
.state_cache
.servers_invite_via(&room_id)?
.unwrap_or(
services() services()
.rooms .rooms
.state_cache .state_cache
@ -108,7 +120,9 @@ pub async fn join_room_by_id_or_alias_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned())) .filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()), .map(|user| user.server_name().to_owned())
.collect(),
),
); );
servers.push( servers.push(
@ -1261,6 +1275,7 @@ pub(crate) async fn invite_helper<'a>(
room_version: room_version_id.clone(), room_version: room_version_id.clone(),
event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()),
invite_room_state, invite_room_state,
via: services().rooms.state_cache.servers_route_via(room_id).ok(),
}, },
) )
.await?; .await?;
@ -1423,6 +1438,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
MembershipState::Leave, MembershipState::Leave,
user_id, user_id,
last_state, last_state,
None,
true, true,
)?; )?;
} else { } else {
@ -1454,6 +1470,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
MembershipState::Leave, MembershipState::Leave,
user_id, user_id,
None, None,
None,
true, true,
)?; )?;
return Ok(()); return Ok(());
@ -1503,14 +1520,21 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
"User is not invited.", "User is not invited.",
))?; ))?;
let servers: HashSet<_> = invite_state let servers: HashSet<_> = services()
.rooms
.state_cache
.servers_invite_via(&room_id)?
.map(|servers| HashSet::from_iter(servers))
.unwrap_or(
invite_state
.iter() .iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok()) .filter_map(|event| serde_json::from_str(event.json().get()).ok())
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned())) .filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()) .map(|user| user.server_name().to_owned())
.collect(); .collect(),
);
for remote_server in servers { for remote_server in servers {
let make_leave_response = services() let make_leave_response = services()

View file

@ -1669,6 +1669,15 @@ pub async fn create_invite_route(
)); ));
} }
if let Some(via) = &body.via {
if via.is_empty() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"via field must not be empty.",
));
}
}
let mut signed_event = utils::to_canonical_object(&body.event) let mut signed_event = utils::to_canonical_object(&body.event)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?;
@ -1744,6 +1753,7 @@ pub async fn create_invite_route(
MembershipState::Invite, MembershipState::Invite,
&sender, &sender,
Some(invite_state), Some(invite_state),
body.via,
true, true,
)?; )?;
} }

View file

@ -1,5 +1,6 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, sync::Arc};
use itertools::Itertools;
use ruma::{ use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent}, events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw, serde::Raw,
@ -21,7 +22,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
} }
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let roomid = room_id.as_bytes().to_vec();
let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xff);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.push(0xff); roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
@ -36,6 +41,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?;
if self.roomuserid_joined.scan_prefix(roomid_prefix).count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count()
== 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(()) Ok(())
} }
@ -44,6 +59,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> { ) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); roomuser_id.push(0xff);
@ -67,12 +83,30 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(&mut servers);
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xff][..]);
self.roomid_inviteviaservers
.insert(&room_id.as_bytes().to_vec(), &servers)?;
}
Ok(()) Ok(())
} }
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec(); let roomid = room_id.as_bytes().to_vec();
roomuser_id.push(0xff); let mut roomid_prefix = room_id.as_bytes().to_vec();
roomid_prefix.push(0xff);
let mut roomuser_id = roomid_prefix.clone();
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
@ -92,6 +126,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.userroomid_invitestate.remove(&userroom_id)?; self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?;
if self.roomuserid_joined.scan_prefix(roomid_prefix).count() == 0
&& self
.roomuserid_invitecount
.scan_prefix(roomid_prefix)
.count()
== 0
{
self.roomid_inviteviaservers.remove(&roomid)?;
}
Ok(()) Ok(())
} }
@ -604,4 +648,41 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
} }
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
let room_id = room_id.as_bytes().to_vec();
self.roomid_inviteviaservers
.get(&room_id)?
.map(|servers| {
let state = serde_json::from_slice(&servers)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(
&self,
room_id: &RoomId,
servers: &Vec<OwnedServerName>,
) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(&mut servers);
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xff][..]);
self.roomid_inviteviaservers
.insert(&room_id.as_bytes().to_vec(), &servers)?;
Ok(())
}
} }

View file

@ -161,6 +161,8 @@ pub struct KeyValueDatabase {
//pub pusher: pusher::PushData, //pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>, pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) roomid_inviteviaservers: Arc<dyn KvTree>,
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>, pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
@ -368,6 +370,8 @@ impl KeyValueDatabase {
global: builder.open_tree("global")?, global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?,
pdu_cache: Mutex::new(LruCache::new( pdu_cache: Mutex::new(LruCache::new(
config config
.pdu_cache_capacity .pdu_cache_capacity

View file

@ -86,6 +86,7 @@ impl Service {
membership, membership,
&pdu.sender, &pdu.sender,
None, None,
None,
false, false,
)?; )?;
} }

View file

@ -15,6 +15,7 @@ pub trait Data: Send + Sync {
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()>; ) -> Result<()>;
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
@ -106,4 +107,14 @@ pub trait Data: Send + Sync {
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
/// Gets the servers to either accept or decline invites via for a given room.
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>;
/// Add the given servers the list to accept or decline invites via for a given room.
fn add_servers_invite_via(
&self,
room_id: &RoomId,
servers: &Vec<OwnedServerName>,
) -> Result<()>;
} }

View file

@ -3,14 +3,19 @@ use std::{collections::HashSet, sync::Arc};
pub use data::Data; pub use data::Data;
use itertools::Itertools;
use ruma::{ use ruma::{
events::{ events::{
direct::DirectEvent, direct::DirectEvent,
ignored_user_list::IgnoredUserListEvent, ignored_user_list::IgnoredUserListEvent,
room::{create::RoomCreateEventContent, member::MembershipState}, room::{
create::RoomCreateEventContent, member::MembershipState,
power_levels::RoomPowerLevelsEventContent,
},
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, StateEventType, RoomAccountDataEventType, StateEventType,
}, },
int,
serde::Raw, serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
@ -32,6 +37,7 @@ impl Service {
membership: MembershipState, membership: MembershipState,
sender: &UserId, sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool, update_joined_count: bool,
) -> Result<()> { ) -> Result<()> {
// Keep track what remote users exist by adding them as "deactivated" users // Keep track what remote users exist by adding them as "deactivated" users
@ -176,7 +182,8 @@ impl Service {
return Ok(()); return Ok(());
} }
self.db.mark_as_invited(user_id, room_id, last_state)?; self.db
.mark_as_invited(user_id, room_id, last_state, invite_via)?;
} }
MembershipState::Leave | MembershipState::Ban => { MembershipState::Leave | MembershipState::Ban => {
self.db.mark_as_left(user_id, room_id)?; self.db.mark_as_left(user_id, room_id)?;
@ -350,4 +357,54 @@ impl Service {
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.is_left(user_id, room_id) self.db.is_left(user_id, room_id)
} }
#[tracing::instrument(skip(self))]
pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
self.db.servers_invite_via(room_id)
}
/// Gets up to three servers that are likely to be in the room in the distant future.
///
/// See https://spec.matrix.org/v1.10/appendices/#routing
#[tracing::instrument(skip(self))]
pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|pdu| {
serde_json::from_str(pdu.content.get()).map(
|conent: RoomPowerLevelsEventContent| {
conent
.users
.iter()
.max_by_key(|(_, power)| power)
.and_then(|x| if x.1 >= &int!(50) { Some(x) } else { None })
.map(|(user, power)| user.server_name().to_owned())
},
)
})
.transpose()
.map_err(|e| Error::bad_database("Invalid power levels event content in database"))?
.flatten();
let mut servers = services()
.rooms
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.counts_by(|user| user.server_name())
.iter()
.sorted_by_key(|(_, users)| users)
.map(|(server, _)| server.to_owned().to_owned())
.rev()
.take(3)
.collect_vec();
if let Some(server) = most_powerful_user_server {
servers.insert(0, server);
servers.truncate(3);
}
Ok(servers)
}
} }

View file

@ -427,6 +427,7 @@ impl Service {
content.membership, content.membership,
&pdu.sender, &pdu.sender,
invite_state, invite_state,
None,
true, true,
)?; )?;
} }