From 28ac3790c281bd164d124dfc3983a11b18ed4ab6 Mon Sep 17 00:00:00 2001 From: strawberry Date: Tue, 2 Jul 2024 16:42:07 -0400 Subject: [PATCH] sync upstream spaces/hierarchy federation MR also had to fix a million clippy lints fix(spaces): deal with hierarchy recursion fix(spaces): properly handle max_depth refactor(spaces): token scheme to prevent clients from modifying max_depth and suggested_only perf(spaces): use tokens to skip to room to start populating results at feat(spaces): request hierarchy from servers in via field of child event Co-authored-by: Matthias Ahouansou Signed-off-by: strawberry --- src/api/client/space.rs | 18 +- src/service/rooms/spaces/mod.rs | 1014 ++++++----------------- src/service/rooms/spaces/tests.rs | 145 ++++ src/service/rooms/state_accessor/mod.rs | 38 +- src/service/rooms/state_cache/mod.rs | 2 + 5 files changed, 433 insertions(+), 784 deletions(-) create mode 100644 src/service/rooms/spaces/tests.rs diff --git a/src/api/client/space.rs b/src/api/client/space.rs index 61431c01..e00171c3 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -2,10 +2,10 @@ use std::str::FromStr; use ruma::{ api::client::{error::ErrorKind, space::get_hierarchy}, - uint, UInt, + UInt, }; -use crate::{service::rooms::spaces::PagnationToken, services, Error, Result, Ruma}; +use crate::{service::rooms::spaces::PaginationToken, services, Error, Result, Ruma}; /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// @@ -14,12 +14,10 @@ use crate::{service::rooms::spaces::PagnationToken, services, Error, Result, Rum pub(crate) async fn get_hierarchy_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let limit: usize = body + let limit = body .limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) - .min(100); + .unwrap_or_else(|| UInt::from(10_u32)) + .min(UInt::from(100_u32)); let max_depth = body .max_depth @@ -29,7 +27,7 @@ pub(crate) async fn get_hierarchy_route(body: Ruma) let key = body .from .as_ref() - .and_then(|s| PagnationToken::from_str(s).ok()); + .and_then(|s| PaginationToken::from_str(s).ok()); // Should prevent unexpeded behaviour in (bad) clients if let Some(ref token) = key { @@ -47,8 +45,8 @@ pub(crate) async fn get_hierarchy_route(body: Ruma) .get_client_hierarchy( sender_user, &body.room_id, - limit, - key.map_or(0, |token| token.skip.try_into().unwrap_or(0)), + limit.try_into().unwrap_or(10), + key.map_or(vec![], |token| token.short_room_ids), max_depth.try_into().unwrap_or(3), body.suggested_only, ) diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index bc7703a3..3dfbcf3c 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,10 +1,13 @@ +mod tests; + use std::{ + collections::VecDeque, fmt::{Display, Formatter}, str::FromStr, sync::Arc, }; -use conduit::{debug_info, Error, Result, Server}; +use conduit::{debug_info, Server}; use database::Database; use lru_cache::LruCache; use ruma::{ @@ -18,8 +21,10 @@ use ruma::{ events::{ room::{ avatar::RoomAvatarEventContent, + canonical_alias::RoomCanonicalAliasEventContent, create::RoomCreateEventContent, - join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent, RoomMembership}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + topic::RoomTopicEventContent, }, space::child::{HierarchySpaceChildEvent, SpaceChildEventContent}, StateEventType, @@ -31,225 +36,43 @@ use ruma::{ use tokio::sync::Mutex; use tracing::{debug, error, warn}; -use crate::{server_is_ours, services}; +use crate::{services, Error, Result}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, } -enum SummaryAccessibility { +pub enum SummaryAccessibility { Accessible(Box), Inaccessible, } -struct Arena { - nodes: Vec, - max_depth: usize, - first_untraversed: Option, -} - -struct Node { - parent: Option, - // Next meaning: - // --> - // o o o o - next_sibling: Option, - // First meaning: - // | - // v - // o o o o - first_child: Option, - room_id: OwnedRoomId, - via: Vec, - traversed: bool, -} - -#[derive(Clone, Copy, PartialEq, Debug, PartialOrd)] -struct NodeId { - index: usize, -} - -impl Arena { - /// Checks if a given node is traversed - fn traversed(&self, id: NodeId) -> Option { Some(self.get(id)?.traversed) } - - /// Gets the previous sibling of a given node - fn next_sibling(&self, id: NodeId) -> Option { self.get(id)?.next_sibling } - - /// Gets the parent of a given node - fn parent(&self, id: NodeId) -> Option { self.get(id)?.parent } - - /// Gets the last child of a given node - fn first_child(&self, id: NodeId) -> Option { self.get(id)?.first_child } - - /// Sets traversed to true for a given node - fn traverse(&mut self, id: NodeId) { self.nodes[id.index].traversed = true; } - - /// Gets the node of a given id - fn get(&self, id: NodeId) -> Option<&Node> { self.nodes.get(id.index) } - - /// Gets a mutable reference of a node of a given id - fn get_mut(&mut self, id: NodeId) -> Option<&mut Node> { self.nodes.get_mut(id.index) } - - /// Returns the first untraversed node, marking it as traversed in the - /// process - fn first_untraversed(&mut self) -> Option { - if self.nodes.is_empty() { - None - } else if let Some(untraversed) = self.first_untraversed { - let mut current = untraversed; - - self.traverse(untraversed); - - // Possible paths: - // 1) Next child exists, and hence is not traversed - // 2) Next child does not exist, so go to the parent, then repeat - // 3) If both the parent and child do not exist, then we have just traversed the - // whole space tree. - // - // You should only ever encounter a traversed node when going up through parents - while self.traversed(current) == Some(true) { - if let Some(next) = self.next_sibling(current) { - current = next; - } else if let Some(parent) = self.parent(current) { - current = parent; - } else { - break; - } - } - - // Traverses down the children until it reaches one without children - while let Some(child) = self.first_child(current) { - current = child; - } - - if self.traversed(current)? { - self.first_untraversed = None; - } else { - self.first_untraversed = Some(current); - } - - Some(untraversed) - } else { - None - } - } - - /// Adds all the given nodes as children of the parent node - fn push(&mut self, parent: NodeId, mut children: Vec<(OwnedRoomId, Vec)>) { - if children.is_empty() { - self.traverse(parent); - } else if self.nodes.get(parent.index).is_some() { - let mut parents = vec![( - parent, - self.get(parent) - .expect("It is some, as above") - .room_id - // Cloning cause otherwise when iterating over the parents, below, there would - // be a mutable and immutable reference to self.nodes - .clone(), - )]; - - while let Some(parent) = self.parent(parents.last().expect("Has at least one value, as above").0) { - parents.push(( - parent, - self.get(parent) - .expect("It is some, as above") - .room_id - .clone(), - )); - } - - // If at max_depth, don't add new rooms - if self.max_depth < parents.len() { - return; - } - - children.reverse(); - - let mut next_id = None; - - for (child, via) in children { - // Prevent adding a child which is a parent (recursion) - if !parents.iter().any(|parent| parent.1 == child) { - self.nodes.push(Node { - parent: Some(parent), - next_sibling: next_id, - first_child: None, - room_id: child, - traversed: false, - via, - }); - - next_id = Some(NodeId { - index: self.nodes.len() - 1, - }); - } - } - - if self.first_untraversed.is_none() - || parent - >= self - .first_untraversed - .expect("Should have already continued if none") - { - self.first_untraversed = next_id; - } - - self.traverse(parent); - - // This is done as if we use an if-let above, we cannot reference self.nodes - // above as then we would have multiple mutable references - let node = self - .get_mut(parent) - .expect("Must be some, as inside this block"); - - node.first_child = next_id; - } - } - - fn new(root: OwnedRoomId, max_depth: usize) -> Self { - let zero_depth = max_depth == 0; - - Self { - nodes: vec![Node { - parent: None, - next_sibling: None, - first_child: None, - room_id: root, - traversed: zero_depth, - via: vec![], - }], - max_depth, - first_untraversed: if zero_depth { - None - } else { - Some(NodeId { - index: 0, - }) - }, - } - } -} - -// Note: perhaps use some better form of token rather than just room count +// TODO: perhaps use some better form of token rather than just room count #[derive(Debug, Eq, PartialEq)] -pub struct PagnationToken { - pub skip: UInt, +pub struct PaginationToken { + /// Path down the hierarchy of the room to start the response at, + /// excluding the root space. + pub short_room_ids: Vec, pub limit: UInt, pub max_depth: UInt, pub suggested_only: bool, } -impl FromStr for PagnationToken { +impl FromStr for PaginationToken { type Err = Error; fn from_str(value: &str) -> Result { let mut values = value.split('_'); let mut pag_tok = || { + let mut rooms = vec![]; + + for room in values.next()?.split(',') { + rooms.push(u64::from_str(room).ok()?); + } + Some(Self { - skip: UInt::from_str(values.next()?).ok()?, + short_room_ids: rooms, limit: UInt::from_str(values.next()?).ok()?, max_depth: UInt::from_str(values.next()?).ok()?, suggested_only: { @@ -278,9 +101,20 @@ impl FromStr for PagnationToken { } } -impl Display for PagnationToken { +impl Display for PaginationToken { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{}_{}_{}_{}", self.skip, self.limit, self.max_depth, self.suggested_only) + write!( + f, + "{}_{}_{}_{}", + self.short_room_ids + .iter() + .map(ToString::to_string) + .collect::>() + .join(","), + self.limit, + self.max_depth, + self.suggested_only + ) } } @@ -290,7 +124,6 @@ impl Display for PagnationToken { enum Identifier<'a> { UserId(&'a UserId), ServerName(&'a ServerName), - None, } pub struct Service { @@ -343,15 +176,15 @@ impl Service { }) } - ///Gets the response for the space hierarchy over federation request + /// Gets the response for the space hierarchy over federation request /// - ///Panics if the room does not exist, so a check if the room exists should + /// Errors if the room does not exist, so a check if the room exists should /// be done pub async fn get_federation_hierarchy( &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, ) -> Result { match self - .get_summary_and_children_local(&room_id.to_owned(), Identifier::None) + .get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name)) .await? { Some(SummaryAccessibility::Accessible(room)) => { @@ -386,6 +219,7 @@ impl Service { } } + /// Gets the summary of a space using solely local information async fn get_summary_and_children_local( &self, current_room: &OwnedRoomId, identifier: Identifier<'_>, ) -> Result> { @@ -402,7 +236,7 @@ impl Service { &cached.summary.join_rule, &identifier, &cached.summary.allowed_room_ids, - )? { + ) { Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) } else { Some(SummaryAccessibility::Inaccessible) @@ -433,11 +267,11 @@ impl Service { ) } + /// Gets the summary of a space using solely federation + #[tracing::instrument(skip(self))] async fn get_summary_and_children_federation( &self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName], ) -> Result> { - debug_info!("servers via for federation hierarchy: {via:?}"); - for server in via { debug_info!("Asking {server} for /hierarchy"); if let Ok(response) = services() @@ -506,21 +340,24 @@ impl Service { &response.room.join_rule, &Identifier::UserId(user_id), &response.room.allowed_room_ids, - )? { + ) { return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); } return Ok(Some(SummaryAccessibility::Inaccessible)); } - - self.roomid_spacehierarchy_cache - .lock() - .await - .insert(current_room.clone(), None); } + + self.roomid_spacehierarchy_cache + .lock() + .await + .insert(current_room.clone(), None); + Ok(None) } + /// Gets the summary of a space using either local or remote (federation) + /// sources async fn get_summary_and_children_client( &self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &[OwnedServerName], ) -> Result> { @@ -555,9 +392,12 @@ impl Service { .transpose()? .unwrap_or(JoinRule::Invite); - let allowed_room_ids = allowed_room_ids(join_rule.clone()); + let allowed_room_ids = services() + .rooms + .state_accessor + .allowed_room_ids(join_rule.clone()); - if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)? { + if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) { debug!("User is not allowed to see room {room_id}"); // This error will be caught later return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room")); @@ -569,7 +409,12 @@ impl Service { canonical_alias: services() .rooms .state_accessor - .get_canonical_alias(room_id)?, + .room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomCanonicalAliasEventContent| c.alias) + .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + })?, name: services().rooms.state_accessor.get_name(room_id)?, num_joined_members: services() .rooms @@ -585,8 +430,15 @@ impl Service { topic: services() .rooms .state_accessor - .get_room_topic(room_id) - .unwrap_or(None), + .room_state_get(room_id, &StateEventType::RoomTopic, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomTopicEventContent| Some(c.topic)) + .map_err(|_| { + error!("Invalid room topic event in database for room {}", room_id); + Error::bad_database("Invalid room topic event in database.") + }) + })?, world_readable: services().rooms.state_accessor.is_world_readable(room_id)?, guest_can_join: services().rooms.state_accessor.guest_can_join(room_id)?, avatar_url: services() @@ -619,116 +471,134 @@ impl Service { }) } - // TODO: make this a lot less messy pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, skip: usize, max_depth: usize, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: usize, suggested_only: bool, ) -> Result { - // try to find more servers to fetch hierachy from if the only - // choice is the room ID's server name (usually dead) - // - // all spaces are normal rooms, so they should always have at least - // 1 admin in it which has a far higher chance of their server still - // being alive - let power_levels: ruma::events::room::power_levels::RoomPowerLevelsEventContent = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); + let mut parents = VecDeque::new(); - // add server names of the list of admins in the room for backfill server - let mut via = power_levels - .users - .iter() - .filter(|(_, level)| **level > power_levels.users_default) - .map(|(user_id, _)| user_id.server_name()) - .filter(|server| !server_is_ours(server)) - .map(ToOwned::to_owned) - .collect::>(); + // Don't start populating the results if we have to start at a specific room. + let mut populate_results = short_room_ids.is_empty(); - if let Some(server_name) = room_id.server_name() { - via.push(server_name.to_owned()); - } + let mut stack = vec![vec![( + room_id.to_owned(), + match room_id.server_name() { + Some(server_name) => vec![server_name.into()], + None => vec![], + }, + )]]; - debug_info!("servers via for hierarchy: {via:?}"); + let mut results = Vec::new(); - match self - .get_summary_and_children_client(&room_id.to_owned(), suggested_only, sender_user, &via) - .await? - { - Some(SummaryAccessibility::Accessible(summary)) => { - let mut left_to_skip = skip; - let mut arena = Arena::new(summary.room_id.clone(), max_depth); + while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { + if limit > results.len() { + match ( + self.get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) + .await?, + current_room == room_id, + ) { + (Some(SummaryAccessibility::Accessible(summary)), _) => { + let mut children: Vec<(OwnedRoomId, Vec)> = + get_parent_children_via(&summary, suggested_only) + .into_iter() + .filter(|(room, _)| parents.iter().all(|parent| parent != room)) + .rev() + .collect(); - let mut results = Vec::new(); - let root = arena - .first_untraversed() - .expect("The node just added is not traversed"); + if populate_results { + results.push(summary_to_chunk(*summary.clone())); + } else { + children = children + .into_iter() + .rev() + .skip_while(|(room, _)| { + if let Ok(short) = services().rooms.short.get_shortroomid(room) + { + short.as_ref() != short_room_ids.get(parents.len()) + } else { + false + } + }) + .collect::>() + // skip_while doesn't implement DoubleEndedIterator, which is needed for rev + .into_iter() + .rev() + .collect(); - arena.push(root, get_parent_children_via(&summary, suggested_only)); - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(summary_to_chunk(*summary.clone())); - } + if children.is_empty() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room IDs in token were not found.", + )); + } - while let Some(current_room) = arena.first_untraversed() { - if limit > results.len() { - let node = arena - .get(current_room) - .expect("We added this node, it must exist"); - if let Some(SummaryAccessibility::Accessible(summary)) = self - .get_summary_and_children_client(&node.room_id, suggested_only, sender_user, &node.via) - .await? - { - let children = get_parent_children_via(&summary, suggested_only); - arena.push(current_room, children); - - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(summary_to_chunk(*summary.clone())); + // We have reached the room after where we last left off + if parents.len() + 1 == short_room_ids.len() { + populate_results = true; } } - } else { - break; - } + + if !children.is_empty() && parents.len() < max_depth { + parents.push_back(current_room.clone()); + stack.push(children); + } + // Root room in the space hierarchy, we return an error + // if this one fails. + }, + (Some(SummaryAccessibility::Inaccessible), true) => { + return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible")); + }, + (None, true) => { + return Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found")); + }, + // Just ignore other unavailable rooms + (None | Some(SummaryAccessibility::Inaccessible), false) => (), + } + } else { + break; + } + } + + Ok(client::space::get_hierarchy::v1::Response { + next_batch: if let Some((room, _)) = next_room_to_traverse(&mut stack, &mut parents) { + parents.pop_front(); + parents.push_back(room); + + let mut short_room_ids = vec![]; + + for room in parents { + short_room_ids.push(services().rooms.short.get_or_create_shortroomid(&room)?); } - Ok(client::space::get_hierarchy::v1::Response { - next_batch: if results.len() < limit { - None - } else { - let skip = UInt::new((skip + limit) as u64); - - skip.map(|skip| { - PagnationToken { - skip, - limit: UInt::new(max_depth as u64) - .expect("When sent in request it must have been valid UInt"), - max_depth: UInt::new(max_depth as u64) - .expect("When sent in request it must have been valid UInt"), - suggested_only, - } - .to_string() - }) - }, - rooms: results, - }) + Some( + PaginationToken { + short_room_ids, + limit: UInt::new(max_depth as u64).expect("When sent in request it must have been valid UInt"), + max_depth: UInt::new(max_depth as u64) + .expect("When sent in request it must have been valid UInt"), + suggested_only, + } + .to_string(), + ) + } else { + None }, - Some(SummaryAccessibility::Inaccessible) => { - Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room is inaccessible")) - }, - None => Err(Error::BadRequest(ErrorKind::forbidden(), "The requested room was not found")), - } + rooms: results, + }) } } +fn next_room_to_traverse( + stack: &mut Vec)>>, parents: &mut VecDeque, +) -> Option<(OwnedRoomId, Vec)> { + while stack.last().map_or(false, Vec::is_empty) { + stack.pop(); + parents.pop_back(); + } + + stack.last_mut().and_then(Vec::pop) +} + /// Simply returns the stripped m.space.child events of a room async fn get_stripped_space_child_events( room_id: &RoomId, @@ -774,96 +644,74 @@ async fn get_stripped_space_child_events( fn is_accessable_child( current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, allowed_room_ids: &Vec, -) -> Result { - is_accessable_child_recurse(current_room, join_rule, identifier, allowed_room_ids, 0) -} +) -> bool { + // Note: unwrap_or_default for bool means false + match identifier { + Identifier::ServerName(server_name) => { + let room_id: &RoomId = current_room; -fn is_accessable_child_recurse( - current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>, - allowed_room_ids: &Vec, recurse_num: usize, -) -> Result { - // Set limit at 10, as we cannot keep going up parents forever - // and it is very unlikely to have 10 space parents - if recurse_num < 10 { - match identifier { - Identifier::ServerName(server_name) => { - let room_id: &RoomId = current_room; - - // Checks if ACLs allow for the server to participate - if services() - .rooms - .event_handler - .acl_check(server_name, room_id) - .is_err() - { - return Ok(false); - } - }, - Identifier::UserId(user_id) => { - if services() + // Checks if ACLs allow for the server to participate + if services() + .rooms + .event_handler + .acl_check(server_name, room_id) + .is_err() + { + return false; + } + }, + Identifier::UserId(user_id) => { + if services() + .rooms + .state_cache + .is_joined(user_id, current_room) + .unwrap_or_default() + || services() .rooms .state_cache - .is_joined(user_id, current_room)? - || services() - .rooms - .state_cache - .is_invited(user_id, current_room)? - { - return Ok(true); - } - }, - Identifier::None => (), - } // Takes care of joinrules - Ok(match join_rule { - SpaceRoomJoinRule::Restricted => { - for room in allowed_room_ids { - if let Ok((join_rule, allowed_room_ids)) = get_join_rule(room) { - if matches!( - is_accessable_child_recurse( - room, - &join_rule, - identifier, - &allowed_room_ids, - recurse_num + 1, - ), - Ok(true) - ) { - return Ok(true); + .is_invited(user_id, current_room) + .unwrap_or_default() + { + return true; + } + }, + } // Takes care of join rules + match join_rule { + SpaceRoomJoinRule::Restricted => { + for room in allowed_room_ids { + match identifier { + Identifier::UserId(user) => { + if services() + .rooms + .state_cache + .is_joined(user, room) + .unwrap_or_default() + { + return true; } - } + }, + Identifier::ServerName(server) => { + if services() + .rooms + .state_cache + .server_in_room(server, room) + .unwrap_or_default() + { + return true; + } + }, } - false - }, - SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, - // Custom join rules, Invite, or Private - _ => false, - }) - } else { - // If you need to go up 10 parents, we just assume it is inaccessable - Ok(false) + } + false + }, + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, + // Invite only, Private, or Custom join rule + _ => false, } } -/// Returns the join rule for a given room -fn get_join_rule(current_room: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { - Ok(services() - .rooms - .state_accessor - .room_state_get(current_room, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), allowed_room_ids(c.join_rule))) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) - }) - .transpose()? - .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) -} - -// Here because cannot implement `From` across ruma-federation-api and -// ruma-client-api types +/// Here because cannot implement `From` across ruma-federation-api and +/// ruma-client-api types fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRoomsChunk { let SpaceHierarchyParentSummary { canonical_alias, @@ -895,22 +743,6 @@ fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRooms } } -/// Returns an empty vec if not a restricted room -fn allowed_room_ids(join_rule: JoinRule) -> Vec { - let mut room_ids = vec![]; - if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule { - for rule in r.allow { - if let AllowRule::RoomMembership(RoomMembership { - room_id: membership, - }) = rule - { - room_ids.push(membership.clone()); - } - } - } - room_ids -} - /// Returns the children of a SpaceHierarchyParentSummary, making use of the /// children_state field fn get_parent_children_via( @@ -930,367 +762,3 @@ fn get_parent_children_via( }) .collect() } - -#[cfg(test)] -mod tests { - use ruma::{ - api::federation::space::SpaceHierarchyParentSummaryInit, events::room::join_rules::Restricted, owned_room_id, - owned_server_name, - }; - - use super::*; - - fn first(arena: &mut Arena, room_id: &OwnedRoomId) { - let first_untrav = arena.first_untraversed().unwrap(); - - assert_eq!(arena.get(first_untrav).unwrap().room_id, *room_id); - } - - #[test] - fn zero_depth() { - let mut arena = Arena::new(owned_room_id!("!foo:example.org"), 0); - - assert_eq!(arena.first_untraversed(), None); - } - - #[test] - fn two_depth() { - let mut arena = Arena::new(owned_room_id!("!root:example.org"), 2); - - let root = arena.first_untraversed().unwrap(); - arena.push( - root, - vec![ - (owned_room_id!("!subspace1:example.org"), vec![]), - (owned_room_id!("!subspace2:example.org"), vec![]), - (owned_room_id!("!foo:example.org"), vec![]), - ], - ); - - let subspace1 = arena.first_untraversed().unwrap(); - let subspace2 = arena.first_untraversed().unwrap(); - - arena.push( - subspace1, - vec![ - (owned_room_id!("!room1:example.org"), vec![]), - (owned_room_id!("!room2:example.org"), vec![]), - ], - ); - - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room2:example.org")); - - arena.push( - subspace2, - vec![ - (owned_room_id!("!room3:example.org"), vec![]), - (owned_room_id!("!room4:example.org"), vec![]), - ], - ); - first(&mut arena, &owned_room_id!("!room3:example.org")); - first(&mut arena, &owned_room_id!("!room4:example.org")); - - let foo_node = NodeId { - index: 1, - }; - - assert_eq!(arena.first_untraversed(), Some(foo_node)); - assert_eq!( - arena.get(foo_node).map(|node| node.room_id.clone()), - Some(owned_room_id!("!foo:example.org")) - ); - } - - #[test] - fn empty_push() { - let mut arena = Arena::new(owned_room_id!("!root:example.org"), 5); - - let root = arena.first_untraversed().unwrap(); - arena.push( - root, - vec![ - (owned_room_id!("!room1:example.org"), vec![]), - (owned_room_id!("!room2:example.org"), vec![]), - ], - ); - - let room1 = arena.first_untraversed().unwrap(); - arena.push(room1, vec![]); - - first(&mut arena, &owned_room_id!("!room2:example.org")); - assert!(arena.first_untraversed().is_none()); - } - - #[test] - fn beyond_max_depth() { - let mut arena = Arena::new(owned_room_id!("!root:example.org"), 0); - - let root = NodeId { - index: 0, - }; - - arena.push(root, vec![(owned_room_id!("!too_deep:example.org"), vec![])]); - - assert_eq!(arena.first_child(root), None); - assert_eq!(arena.nodes.len(), 1); - } - - #[test] - fn order_check() { - let mut arena = Arena::new(owned_room_id!("!root:example.org"), 3); - - let root = arena.first_untraversed().unwrap(); - arena.push( - root, - vec![ - (owned_room_id!("!subspace1:example.org"), vec![]), - (owned_room_id!("!subspace2:example.org"), vec![]), - (owned_room_id!("!foo:example.org"), vec![]), - ], - ); - - let subspace1 = arena.first_untraversed().unwrap(); - arena.push( - subspace1, - vec![ - (owned_room_id!("!room1:example.org"), vec![]), - (owned_room_id!("!room3:example.org"), vec![]), - (owned_room_id!("!room5:example.org"), vec![]), - ], - ); - - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room3:example.org")); - first(&mut arena, &owned_room_id!("!room5:example.org")); - - let subspace2 = arena.first_untraversed().unwrap(); - - assert_eq!(arena.get(subspace2).unwrap().room_id, owned_room_id!("!subspace2:example.org")); - - arena.push( - subspace2, - vec![ - (owned_room_id!("!room1:example.org"), vec![]), - (owned_room_id!("!room2:example.org"), vec![]), - ], - ); - - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room2:example.org")); - first(&mut arena, &owned_room_id!("!foo:example.org")); - - assert_eq!(arena.first_untraversed(), None); - } - - #[test] - fn get_summary_children() { - let summary: SpaceHierarchyParentSummary = SpaceHierarchyParentSummaryInit { - num_joined_members: UInt::from(1_u32), - room_id: owned_room_id!("!root:example.org"), - world_readable: true, - guest_can_join: true, - join_rule: SpaceRoomJoinRule::Public, - children_state: vec![ - serde_json::from_str( - r#"{ - "content": { - "via": [ - "example.org" - ], - "suggested": false - }, - "origin_server_ts": 1629413349153, - "sender": "@alice:example.org", - "state_key": "!foo:example.org", - "type": "m.space.child" - }"#, - ) - .unwrap(), - serde_json::from_str( - r#"{ - "content": { - "via": [ - "example.org" - ], - "suggested": true - }, - "origin_server_ts": 1629413349157, - "sender": "@alice:example.org", - "state_key": "!bar:example.org", - "type": "m.space.child" - }"#, - ) - .unwrap(), - serde_json::from_str( - r#"{ - "content": { - "via": [ - "example.org" - ] - }, - "origin_server_ts": 1629413349160, - "sender": "@alice:example.org", - "state_key": "!baz:example.org", - "type": "m.space.child" - }"#, - ) - .unwrap(), - ], - allowed_room_ids: vec![], - } - .into(); - - assert_eq!( - get_parent_children_via(&summary, false), - vec![ - (owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]), - (owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]), - (owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")]) - ] - ); - assert_eq!( - get_parent_children_via(&summary, true), - vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])] - ); - } - - #[test] - fn allowed_room_ids_from_join_rule() { - let restricted_join_rule = JoinRule::Restricted(Restricted { - allow: vec![ - AllowRule::RoomMembership(RoomMembership { - room_id: owned_room_id!("!foo:example.org"), - }), - AllowRule::RoomMembership(RoomMembership { - room_id: owned_room_id!("!bar:example.org"), - }), - AllowRule::RoomMembership(RoomMembership { - room_id: owned_room_id!("!baz:example.org"), - }), - ], - }); - - let invite_join_rule = JoinRule::Invite; - - assert_eq!( - allowed_room_ids(restricted_join_rule), - vec![ - owned_room_id!("!foo:example.org"), - owned_room_id!("!bar:example.org"), - owned_room_id!("!baz:example.org") - ] - ); - - let empty_vec: Vec = vec![]; - - assert_eq!(allowed_room_ids(invite_join_rule), empty_vec); - } - - #[test] - fn invalid_pagnation_tokens() { - fn token_is_err(token: &str) { - let token: Result = PagnationToken::from_str(token); - token.unwrap_err(); - } - - token_is_err("231_2_noabool"); - token_is_err(""); - token_is_err("111_3_"); - token_is_err("foo_not_int"); - token_is_err("11_4_true_"); - token_is_err("___"); - token_is_err("__false"); - } - - #[test] - fn valid_pagnation_tokens() { - assert_eq!( - PagnationToken { - skip: UInt::from(40_u32), - limit: UInt::from(20_u32), - max_depth: UInt::from(1_u32), - suggested_only: true - }, - PagnationToken::from_str("40_20_1_true").unwrap() - ); - - assert_eq!( - PagnationToken { - skip: UInt::from(27645_u32), - limit: UInt::from(97_u32), - max_depth: UInt::from(10539_u32), - suggested_only: false - }, - PagnationToken::from_str("27645_97_10539_false").unwrap() - ); - } - - #[test] - fn pagnation_token_to_string() { - assert_eq!( - PagnationToken { - skip: UInt::from(27645_u32), - limit: UInt::from(97_u32), - max_depth: UInt::from(9420_u32), - suggested_only: false - } - .to_string(), - "27645_97_9420_false" - ); - - assert_eq!( - PagnationToken { - skip: UInt::from(12_u32), - limit: UInt::from(3_u32), - max_depth: UInt::from(1_u32), - suggested_only: true - } - .to_string(), - "12_3_1_true" - ); - } - - #[test] - fn forbid_recursion() { - let mut arena = Arena::new(owned_room_id!("!root:example.org"), 5); - let root_node_id = arena.first_untraversed().unwrap(); - - arena.push( - root_node_id, - vec![ - (owned_room_id!("!subspace1:example.org"), vec![]), - (owned_room_id!("!room1:example.org"), vec![]), - (owned_room_id!("!subspace2:example.org"), vec![]), - ], - ); - - let subspace1_node_id = arena.first_untraversed().unwrap(); - arena.push( - subspace1_node_id, - vec![ - (owned_room_id!("!subspace2:example.org"), vec![]), - (owned_room_id!("!room1:example.org"), vec![]), - ], - ); - - let subspace2_node_id = arena.first_untraversed().unwrap(); - // Here, both subspaces should be ignored and not added, as they are both - // parents of subspace2 - arena.push( - subspace2_node_id, - vec![ - (owned_room_id!("!subspace1:example.org"), vec![]), - (owned_room_id!("!subspace2:example.org"), vec![]), - (owned_room_id!("!room1:example.org"), vec![]), - ], - ); - - assert_eq!(arena.nodes.len(), 7); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!subspace2:example.org")); - assert!(arena.first_untraversed().is_none()); - } -} diff --git a/src/service/rooms/spaces/tests.rs b/src/service/rooms/spaces/tests.rs new file mode 100644 index 00000000..71640035 --- /dev/null +++ b/src/service/rooms/spaces/tests.rs @@ -0,0 +1,145 @@ +#![cfg(test)] + +use std::str::FromStr; + +use ruma::{ + api::federation::space::{SpaceHierarchyParentSummary, SpaceHierarchyParentSummaryInit}, + owned_room_id, owned_server_name, + space::SpaceRoomJoinRule, + UInt, +}; + +use crate::rooms::spaces::{get_parent_children_via, PaginationToken}; + +#[test] +fn get_summary_children() { + let summary: SpaceHierarchyParentSummary = SpaceHierarchyParentSummaryInit { + num_joined_members: UInt::from(1_u32), + room_id: owned_room_id!("!root:example.org"), + world_readable: true, + guest_can_join: true, + join_rule: SpaceRoomJoinRule::Public, + children_state: vec![ + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ], + "suggested": false + }, + "origin_server_ts": 1629413349153, + "sender": "@alice:example.org", + "state_key": "!foo:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ], + "suggested": true + }, + "origin_server_ts": 1629413349157, + "sender": "@alice:example.org", + "state_key": "!bar:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + serde_json::from_str( + r#"{ + "content": { + "via": [ + "example.org" + ] + }, + "origin_server_ts": 1629413349160, + "sender": "@alice:example.org", + "state_key": "!baz:example.org", + "type": "m.space.child" + }"#, + ) + .unwrap(), + ], + allowed_room_ids: vec![], + } + .into(); + + assert_eq!( + get_parent_children_via(&summary, false), + vec![ + (owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")]) + ] + ); + assert_eq!( + get_parent_children_via(&summary, true), + vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])] + ); +} + +#[test] +fn invalid_pagination_tokens() { + fn token_is_err(token: &str) { PaginationToken::from_str(token).unwrap_err(); } + + token_is_err("231_2_noabool"); + token_is_err(""); + token_is_err("111_3_"); + token_is_err("foo_not_int"); + token_is_err("11_4_true_"); + token_is_err("___"); + token_is_err("__false"); +} + +#[test] +fn valid_pagination_tokens() { + assert_eq!( + PaginationToken { + short_room_ids: vec![5383, 42934, 283, 423], + limit: UInt::from(20_u32), + max_depth: UInt::from(1_u32), + suggested_only: true + }, + PaginationToken::from_str("5383,42934,283,423_20_1_true").unwrap() + ); + + assert_eq!( + PaginationToken { + short_room_ids: vec![740], + limit: UInt::from(97_u32), + max_depth: UInt::from(10539_u32), + suggested_only: false + }, + PaginationToken::from_str("740_97_10539_false").unwrap() + ); +} + +#[test] +fn pagination_token_to_string() { + assert_eq!( + PaginationToken { + short_room_ids: vec![740], + limit: UInt::from(97_u32), + max_depth: UInt::from(10539_u32), + suggested_only: false + } + .to_string(), + "740_97_10539_false" + ); + + assert_eq!( + PaginationToken { + short_room_ids: vec![9, 34], + limit: UInt::from(3_u32), + max_depth: UInt::from(1_u32), + suggested_only: true + } + .to_string(), + "9,34_3_1_true" + ); +} diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index ce884fe5..65c43905 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -16,6 +16,7 @@ use ruma::{ canonical_alias::RoomCanonicalAliasEventContent, guest_access::{GuestAccess, RoomGuestAccessEventContent}, history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent, RoomMembership}, member::{MembershipState, RoomMemberEventContent}, name::RoomNameEventContent, power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent}, @@ -23,7 +24,8 @@ use ruma::{ }, StateEventType, }, - EventId, OwnedRoomAliasId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + space::SpaceRoomJoinRule, + EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use serde_json::value::to_raw_value; @@ -415,4 +417,38 @@ impl Service { }, ) } + + /// Returns the join rule for a given room + pub fn get_join_rule(&self, current_room: &RoomId) -> Result<(SpaceRoomJoinRule, Vec), Error> { + Ok(self + .room_state_get(current_room, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| { + (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)) + }) + .map_err(|e| { + error!("Invalid room join rule event in database: {e}"); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .unwrap_or((SpaceRoomJoinRule::Invite, vec![]))) + } + + /// Returns an empty vec if not a restricted room + pub fn allowed_room_ids(&self, join_rule: JoinRule) -> Vec { + let mut room_ids = vec![]; + if let JoinRule::Restricted(r) | JoinRule::KnockRestricted(r) = join_rule { + for rule in r.allow { + if let AllowRule::RoomMembership(RoomMembership { + room_id: membership, + }) = rule + { + room_ids.push(membership.clone()); + } + } + } + room_ids + } } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index eeecc9c7..6572fcba 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -293,6 +293,7 @@ impl Service { self.db.room_members(room_id) } + /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self))] pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } @@ -310,6 +311,7 @@ impl Service { self.db.active_local_users_in_room(room_id) } + /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self))] pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) }