diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index e83aaab0..6488a4d2 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,4 +1,7 @@ -use std::str::FromStr; +use std::{ + fmt::{Display, Formatter}, + str::FromStr, +}; use lru_cache::LruCache; use ruma::{ @@ -24,7 +27,7 @@ use ruma::{ }, serde::Raw, space::SpaceRoomJoinRule, - OwnedRoomId, RoomId, ServerName, UInt, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UInt, UserId, }; use tokio::sync::Mutex; use tracing::{debug, error, warn}; @@ -58,6 +61,7 @@ pub struct Node { // o o o o first_child: Option, pub room_id: OwnedRoomId, + pub via: Vec, traversed: bool, } @@ -133,7 +137,7 @@ impl Arena { } /// Adds all the given nodes as children of the parent node - fn push(&mut self, parent: NodeId, mut children: Vec) { + fn push(&mut self, parent: NodeId, mut children: Vec<(OwnedRoomId, Vec)>) { if children.is_empty() { self.traverse(parent); } else if self.nodes.get(parent.index).is_some() { @@ -166,7 +170,7 @@ impl Arena { let mut next_id = None; - for child in children { + for (child, via) in children { // Prevent adding a child which is a parent (recursion) if !parents.iter().any(|parent| parent.1 == child) { self.nodes.push(Node { @@ -175,6 +179,7 @@ impl Arena { first_child: None, room_id: child, traversed: false, + via, }); next_id = Some(NodeId { @@ -214,6 +219,7 @@ impl Arena { first_child: None, room_id: root, traversed: zero_depth, + via: vec![], }], max_depth, first_untraversed: if zero_depth { @@ -273,8 +279,8 @@ impl FromStr for PagnationToken { } } -impl std::fmt::Display for PagnationToken { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for PagnationToken { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}_{}_{}_{}", self.skip, self.limit, self.max_depth, self.suggested_only) } } @@ -336,16 +342,16 @@ impl Service { &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, ) -> Result { match self - .get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None) + .get_summary_and_children_local(&room_id.to_owned(), Identifier::None) .await? { Some(SummaryAccessibility::Accessible(room)) => { let mut children = Vec::new(); let mut inaccessible_children = Vec::new(); - for child in get_parent_children(&room.clone(), suggested_only) { + for (child, _via) in get_parent_children_via(&room, suggested_only) { match self - .get_summary_and_children(&child, suggested_only, Identifier::ServerName(server_name)) + .get_summary_and_children_local(&child, Identifier::ServerName(server_name)) .await? { Some(SummaryAccessibility::Accessible(summary)) => { @@ -371,8 +377,8 @@ impl Service { } } - async fn get_summary_and_children( - &self, current_room: &OwnedRoomId, suggested_only: bool, identifier: Identifier<'_>, + async fn get_summary_and_children_local( + &self, current_room: &OwnedRoomId, identifier: Identifier<'_>, ) -> Result> { if let Some(cached) = self .roomid_spacehierarchy_cache @@ -399,7 +405,7 @@ impl Service { Ok( if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? { - let summary = self.get_room_summary(current_room, children_pdus, &identifier); + let summary = Self::get_room_summary(current_room, children_pdus, &identifier); if let Ok(summary) = summary { self.roomid_spacehierarchy_cache.lock().await.insert( current_room.clone(), @@ -410,96 +416,6 @@ impl Service { Some(SummaryAccessibility::Accessible(Box::new(summary))) } else { - None - } - // Federation requests should not request information from other - // servers - } else if let Identifier::UserId(_) = identifier { - let server = current_room - .server_name() - .expect("Room IDs should always have a server name"); - if server == services().globals.server_name() { - return Ok(None); - } - - debug!("Asking {server} for /hierarchy"); - if let Ok(response) = services() - .sending - .send_federation_request( - server, - federation::space::get_hierarchy::v1::Request { - room_id: current_room.to_owned(), - suggested_only, - }, - ) - .await - { - debug!("Got response from {server} for /hierarchy\n{response:?}"); - let summary = response.room.clone(); - - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: summary.clone(), - }), - ); - - for child in response.children { - let mut guard = self.roomid_spacehierarchy_cache.lock().await; - if !guard.contains_key(current_room) { - guard.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: { - let SpaceHierarchyChildSummary { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - allowed_room_ids, - } = child; - - SpaceHierarchyParentSummary { - canonical_alias, - name, - num_joined_members, - room_id: room_id.clone(), - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), - allowed_room_ids, - } - }, - }), - ); - } - } - if is_accessable_child( - current_room, - &response.room.join_rule, - &identifier, - &response.room.allowed_room_ids, - )? { - Some(SummaryAccessibility::Accessible(Box::new(summary.clone()))) - } else { - Some(SummaryAccessibility::Inaccessible) - } - } else { - self.roomid_spacehierarchy_cache - .lock() - .await - .insert(current_room.clone(), None); - None } } else { @@ -508,9 +424,108 @@ impl Service { ) } + async fn get_summary_and_children_federation( + &self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &Vec, + ) -> Result> { + for server in via { + debug!("Asking {server} for /hierarchy"); + if let Ok(response) = services() + .sending + .send_federation_request( + server, + federation::space::get_hierarchy::v1::Request { + room_id: current_room.to_owned(), + suggested_only, + }, + ) + .await + { + debug!("Got response from {server} for /hierarchy\n{response:?}"); + let summary = response.room.clone(); + + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: summary.clone(), + }), + ); + + for child in response.children { + let mut guard = self.roomid_spacehierarchy_cache.lock().await; + if !guard.contains_key(current_room) { + guard.insert( + current_room.clone(), + Some(CachedSpaceHierarchySummary { + summary: { + let SpaceHierarchyChildSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + allowed_room_ids, + } = child; + + SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + room_id: room_id.clone(), + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + children_state: get_stripped_space_child_events(&room_id).await?.unwrap(), + allowed_room_ids, + } + }, + }), + ); + } + } + if is_accessable_child( + current_room, + &response.room.join_rule, + &Identifier::UserId(user_id), + &response.room.allowed_room_ids, + )? { + return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); + } + + return Ok(Some(SummaryAccessibility::Inaccessible)); + } + + self.roomid_spacehierarchy_cache + .lock() + .await + .insert(current_room.clone(), None); + } + Ok(None) + } + + async fn get_summary_and_children_client( + &self, current_room: &OwnedRoomId, suggested_only: bool, user_id: &UserId, via: &Vec, + ) -> Result> { + if let Ok(Some(response)) = self + .get_summary_and_children_local(current_room, Identifier::UserId(user_id)) + .await + { + Ok(Some(response)) + } else { + self.get_summary_and_children_federation(current_room, suggested_only, user_id, via) + .await + } + } + fn get_room_summary( - &self, current_room: &OwnedRoomId, children_state: Vec>, - identifier: &Identifier<'_>, + current_room: &OwnedRoomId, children_state: Vec>, identifier: &Identifier<'_>, ) -> Result { let room_id: &RoomId = current_room; @@ -611,7 +626,15 @@ impl Service { suggested_only: bool, ) -> Result { match self - .get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::UserId(sender_user)) + .get_summary_and_children_client( + &room_id.to_owned(), + suggested_only, + sender_user, + &match room_id.server_name() { + Some(server_name) => vec![server_name.into()], + None => vec![], + }, + ) .await? { Some(SummaryAccessibility::Accessible(summary)) => { @@ -623,23 +646,23 @@ impl Service { .first_untraversed() .expect("The node just added is not traversed"); - arena.push(root, get_parent_children(&summary.clone(), suggested_only)); - results.push(summary_to_chunk(*summary.clone())); + arena.push(root, get_parent_children_via(&summary, suggested_only)); + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(summary_to_chunk(*summary.clone())); + } while let Some(current_room) = arena.first_untraversed() { if limit > results.len() { + let node = arena + .get(current_room) + .expect("We added this node, it must exist"); if let Some(SummaryAccessibility::Accessible(summary)) = self - .get_summary_and_children( - &arena - .get(current_room) - .expect("We added this node, it must exist") - .room_id, - suggested_only, - Identifier::UserId(sender_user), - ) + .get_summary_and_children_client(&node.room_id, suggested_only, sender_user, &node.via) .await? { - let children = get_parent_children(&summary.clone(), suggested_only); + let children = get_parent_children_via(&summary, suggested_only); arena.push(current_room, children); if left_to_skip > 0 { @@ -736,6 +759,7 @@ fn is_accessable_child_recurse( allowed_room_ids: &Vec, recurse_num: usize, ) -> Result { // Set limit at 10, as we cannot keep going up parents forever + // and it is very unlikely to have 10 space parents if recurse_num < 10 { match identifier { Identifier::ServerName(server_name) => { @@ -767,10 +791,9 @@ fn is_accessable_child_recurse( Identifier::None => (), } // Takes care of joinrules Ok(match join_rule { - SpaceRoomJoinRule::KnockRestricted | SpaceRoomJoinRule::Restricted => { + SpaceRoomJoinRule::Restricted => { for room in allowed_room_ids { if let Ok((join_rule, allowed_room_ids)) = get_join_rule(room) { - // Recursive, get rid of if possible if let Ok(true) = is_accessable_child_recurse( room, &join_rule, @@ -784,9 +807,8 @@ fn is_accessable_child_recurse( } false }, - SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock => true, - // SpaceRoomJoinRule::Invite | SpaceRoomJoinRule::Private => false, - // Custom join rules, Invites, or Private + SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true, + // Custom join rules, Invite, or Private _ => false, }) } else { @@ -890,7 +912,9 @@ fn allowed_room_ids(join_rule: JoinRule) -> Vec { /// Returns the children of a SpaceHierarchyParentSummary, making use of the /// children_state field -fn get_parent_children(parent: &SpaceHierarchyParentSummary, suggested_only: bool) -> Vec { +fn get_parent_children_via( + parent: &SpaceHierarchyParentSummary, suggested_only: bool, +) -> Vec<(OwnedRoomId, Vec)> { parent .children_state .iter() @@ -899,7 +923,7 @@ fn get_parent_children(parent: &SpaceHierarchyParentSummary, suggested_only: boo if suggested_only && !ce.content.suggested { None } else { - Some(ce.state_key) + Some((ce.state_key, ce.content.via)) } }) }) @@ -910,14 +934,15 @@ fn get_parent_children(parent: &SpaceHierarchyParentSummary, suggested_only: boo mod tests { use ruma::{ api::federation::space::SpaceHierarchyParentSummaryInit, events::room::join_rules::Restricted, owned_room_id, + owned_server_name, }; use super::*; - fn first(arena: &mut Arena, room_id: &OwnedRoomId) { + fn first(arena: &mut Arena, room_id: OwnedRoomId) { let first_untrav = arena.first_untraversed().unwrap(); - assert_eq!(&arena.get(first_untrav).unwrap().room_id, room_id); + assert_eq!(arena.get(first_untrav).unwrap().room_id, room_id); } #[test] @@ -935,9 +960,9 @@ mod tests { arena.push( root, vec![ - owned_room_id!("!subspace1:example.org"), - owned_room_id!("!subspace2:example.org"), - owned_room_id!("!foo:example.org"), + (owned_room_id!("!subspace1:example.org"), vec![]), + (owned_room_id!("!subspace2:example.org"), vec![]), + (owned_room_id!("!foo:example.org"), vec![]), ], ); @@ -946,19 +971,24 @@ mod tests { arena.push( subspace1, - vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + vec![ + (owned_room_id!("!room1:example.org"), vec![]), + (owned_room_id!("!room2:example.org"), vec![]), + ], ); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room2:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room2:example.org")); arena.push( subspace2, - vec![owned_room_id!("!room3:example.org"), owned_room_id!("!room4:example.org")], + vec![ + (owned_room_id!("!room3:example.org"), vec![]), + (owned_room_id!("!room4:example.org"), vec![]), + ], ); - - first(&mut arena, &owned_room_id!("!room3:example.org")); - first(&mut arena, &owned_room_id!("!room4:example.org")); + first(&mut arena, owned_room_id!("!room3:example.org")); + first(&mut arena, owned_room_id!("!room4:example.org")); let foo_node = NodeId { index: 1, @@ -978,13 +1008,16 @@ mod tests { let root = arena.first_untraversed().unwrap(); arena.push( root, - vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + vec![ + (owned_room_id!("!room1:example.org"), vec![]), + (owned_room_id!("!room2:example.org"), vec![]), + ], ); let room1 = arena.first_untraversed().unwrap(); arena.push(room1, vec![]); - first(&mut arena, &owned_room_id!("!room2:example.org")); + first(&mut arena, owned_room_id!("!room2:example.org")); assert!(arena.first_untraversed().is_none()); } @@ -996,7 +1029,7 @@ mod tests { index: 0, }; - arena.push(root, vec![owned_room_id!("!too_deep:example.org")]); + arena.push(root, vec![(owned_room_id!("!too_deep:example.org"), vec![])]); assert_eq!(arena.first_child(root), None); assert_eq!(arena.nodes.len(), 1); @@ -1010,9 +1043,9 @@ mod tests { arena.push( root, vec![ - owned_room_id!("!subspace1:example.org"), - owned_room_id!("!subspace2:example.org"), - owned_room_id!("!foo:example.org"), + (owned_room_id!("!subspace1:example.org"), vec![]), + (owned_room_id!("!subspace2:example.org"), vec![]), + (owned_room_id!("!foo:example.org"), vec![]), ], ); @@ -1020,15 +1053,15 @@ mod tests { arena.push( subspace1, vec![ - owned_room_id!("!room1:example.org"), - owned_room_id!("!room3:example.org"), - owned_room_id!("!room5:example.org"), + (owned_room_id!("!room1:example.org"), vec![]), + (owned_room_id!("!room3:example.org"), vec![]), + (owned_room_id!("!room5:example.org"), vec![]), ], ); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room3:example.org")); - first(&mut arena, &owned_room_id!("!room5:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room3:example.org")); + first(&mut arena, owned_room_id!("!room5:example.org")); let subspace2 = arena.first_untraversed().unwrap(); @@ -1036,12 +1069,15 @@ mod tests { arena.push( subspace2, - vec![owned_room_id!("!room1:example.org"), owned_room_id!("!room2:example.org")], + vec![ + (owned_room_id!("!room1:example.org"), vec![]), + (owned_room_id!("!room2:example.org"), vec![]), + ], ); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room2:example.org")); - first(&mut arena, &owned_room_id!("!foo:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room2:example.org")); + first(&mut arena, owned_room_id!("!foo:example.org")); assert_eq!(arena.first_untraversed(), None); } @@ -1105,18 +1141,21 @@ mod tests { .into(); assert_eq!( - get_parent_children(&summary, false), + get_parent_children_via(&summary, false), vec![ - owned_room_id!("!foo:example.org"), - owned_room_id!("!bar:example.org"), - owned_room_id!("!baz:example.org") + (owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")]) ] ); - assert_eq!(get_parent_children(&summary, true), vec![owned_room_id!("!bar:example.org")]); + assert_eq!( + get_parent_children_via(&summary, true), + vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])] + ); } #[test] - fn allowed_room_ids_rom_join_rule() { + fn allowed_room_ids_from_join_rule() { let restricted_join_rule = JoinRule::Restricted(Restricted { allow: vec![ AllowRule::RoomMembership(RoomMembership { @@ -1151,7 +1190,7 @@ mod tests { fn invalid_pagnation_tokens() { fn token_is_err(token: &str) { let token: Result = PagnationToken::from_str(token); - token.unwrap_err(); + assert!(token.is_err()); } token_is_err("231_2_noabool"); @@ -1219,16 +1258,19 @@ mod tests { arena.push( root_node_id, vec![ - owned_room_id!("!subspace1:example.org"), - owned_room_id!("!room1:example.org"), - owned_room_id!("!subspace2:example.org"), + (owned_room_id!("!subspace1:example.org"), vec![]), + (owned_room_id!("!room1:example.org"), vec![]), + (owned_room_id!("!subspace2:example.org"), vec![]), ], ); let subspace1_node_id = arena.first_untraversed().unwrap(); arena.push( subspace1_node_id, - vec![owned_room_id!("!subspace2:example.org"), owned_room_id!("!room1:example.org")], + vec![ + (owned_room_id!("!subspace2:example.org"), vec![]), + (owned_room_id!("!room1:example.org"), vec![]), + ], ); let subspace2_node_id = arena.first_untraversed().unwrap(); @@ -1237,17 +1279,17 @@ mod tests { arena.push( subspace2_node_id, vec![ - owned_room_id!("!subspace1:example.org"), - owned_room_id!("!subspace2:example.org"), - owned_room_id!("!room1:example.org"), + (owned_room_id!("!subspace1:example.org"), vec![]), + (owned_room_id!("!subspace2:example.org"), vec![]), + (owned_room_id!("!room1:example.org"), vec![]), ], ); assert_eq!(arena.nodes.len(), 7); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!room1:example.org")); - first(&mut arena, &owned_room_id!("!subspace2:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!room1:example.org")); + first(&mut arena, owned_room_id!("!subspace2:example.org")); assert!(arena.first_untraversed().is_none()); } }