From ed960f41acdaa6c7123169629ed35389d26f0415 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 3 Apr 2024 14:10:00 -0400 Subject: [PATCH] feat: recurse relationships (and fix some lints) from https://gitlab.com/famedly/conduit/-/merge_requests/613 Signed-off-by: strawberry --- src/api/client_server/relations.rs | 87 ++++---------------- src/service/rooms/pdu_metadata/mod.rs | 113 ++++++++++++++++++++------ 2 files changed, 101 insertions(+), 99 deletions(-) diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index d80374f0..f2ddfecb 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -2,7 +2,7 @@ use ruma::api::client::relations::{ get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, }; -use crate::{service::rooms::timeline::PduCount, services, Result, Ruma}; +use crate::{services, Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub async fn get_relating_events_with_rel_type_and_event_type_route( @@ -10,26 +10,6 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match body.dir { - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; - - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); - - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - .map_or(10_usize, |u| u as usize) - .min(100); - let res = services() .rooms .pdu_metadata @@ -39,17 +19,18 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route( &body.event_id, &Some(body.event_type.clone()), &Some(body.rel_type.clone()), - from, + &body.from, + &body.to, + &body.limit, + body.recurse, body.dir, - to, - limit, )?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, prev_batch: res.prev_batch, - recursion_depth: None, // TODO + recursion_depth: res.recursion_depth, }) } @@ -59,26 +40,6 @@ pub async fn get_relating_events_with_rel_type_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match body.dir { - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; - - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); - - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - .map_or(10_usize, |u| u as usize) - .min(100); - let res = services() .rooms .pdu_metadata @@ -88,17 +49,18 @@ pub async fn get_relating_events_with_rel_type_route( &body.event_id, &None, &Some(body.rel_type.clone()), - from, + &body.from, + &body.to, + &body.limit, + body.recurse, body.dir, - to, - limit, )?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, prev_batch: res.prev_batch, - recursion_depth: None, // TODO + recursion_depth: res.recursion_depth, }) } @@ -108,26 +70,6 @@ pub async fn get_relating_events_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match body.dir { - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; - - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); - - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - .map_or(10_usize, |u| u as usize) - .min(100); - services() .rooms .pdu_metadata @@ -137,9 +79,10 @@ pub async fn get_relating_events_route( &body.event_id, &None, &None, - from, + &body.from, + &body.to, + &body.limit, + body.recurse, body.dir, - to, - limit, ) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index f44b8939..c2483475 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -5,7 +5,7 @@ pub use data::Data; use ruma::{ api::{client::relations::get_relating_events, Direction}, events::{relation::RelationType, TimelineEventType}, - EventId, RoomId, UserId, + EventId, RoomId, UInt, UserId, }; use serde::Deserialize; @@ -42,21 +42,47 @@ impl Service { #[allow(clippy::too_many_arguments)] pub fn paginate_relations_with_filter( &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option, - filter_rel_type: &Option, from: PduCount, dir: Direction, to: Option, limit: usize, + filter_rel_type: &Option, from: &Option, to: &Option, limit: &Option, + recurse: bool, dir: Direction, ) -> Result { + let from = match from { + Some(from) => PduCount::try_from_string(from)?, + None => match dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }, + }; + + let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + + // Use limit or else 10, with maximum 100 + let limit = limit + .and_then(|u| u32::try_from(u).ok()) + .map_or(10_usize, |u| u as usize) + .min(100); + let next_token; + // Spec (v1.10) recommends depth of at least 3 + let depth: u8 = if recurse { + 3 + } else { + 1 + }; + match dir { Direction::Forward => { - let events_after: Vec<_> = services() - .rooms - .pdu_metadata - .relations_until(sender_user, room_id, target, from)? // TODO: should be relations_after - .filter(|r| { - r.as_ref().map_or(true, |(_, pdu)| { + let relations_until = + &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth)?; + let events_after: Vec<_> = relations_until // TODO: should be relations_after + .iter() + .filter(|(_, pdu)| { filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) && if let Ok(content) = - serde_json::from_str::(pdu.content.get()) + serde_json::from_str::(pdu.content.get()) { filter_rel_type .as_ref() @@ -65,9 +91,7 @@ impl Service { false } }) - }) .take(limit) - .filter_map(Result::ok) // Filter out buggy events .filter(|(_, pdu)| { services() .rooms @@ -75,7 +99,7 @@ impl Service { .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to` .collect(); next_token = events_after.last().map(|(count, _)| count).copied(); @@ -90,19 +114,25 @@ impl Service { chunk: events_after, next_batch: next_token.map(|t| t.stringify()), prev_batch: Some(from.stringify()), - recursion_depth: None, // TODO + recursion_depth: if recurse { + Some(depth.into()) + } else { + None + }, }) }, Direction::Backward => { - let events_before: Vec<_> = services() - .rooms - .pdu_metadata - .relations_until(sender_user, room_id, target, from)? - .filter(|r| { - r.as_ref().map_or(true, |(_, pdu)| { + let relations_until = + &services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from, depth)?; + let events_before: Vec<_> = relations_until + .iter() + .filter(|(_, pdu)| { filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) && if let Ok(content) = - serde_json::from_str::(pdu.content.get()) + serde_json::from_str::(pdu.content.get()) { filter_rel_type .as_ref() @@ -111,9 +141,7 @@ impl Service { false } }) - }) .take(limit) - .filter_map(Result::ok) // Filter out buggy events .filter(|(_, pdu)| { services() .rooms @@ -121,7 +149,7 @@ impl Service { .user_can_see_event(sender_user, room_id, &pdu.event_id) .unwrap_or(false) }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .take_while(|&(k, _)| Some(k) != to.as_ref()) // Stop at `to` .collect(); next_token = events_before.last().map(|(count, _)| count).copied(); @@ -135,15 +163,19 @@ impl Service { chunk: events_before, next_batch: next_token.map(|t| t.stringify()), prev_batch: Some(from.stringify()), - recursion_depth: None, // TODO + recursion_depth: if recurse { + Some(depth.into()) + } else { + None + }, }) }, } } pub fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, - ) -> Result> + 'a> { + &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8, + ) -> Result> { let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; #[allow(unknown_lints)] #[allow(clippy::manual_unwrap_or_default)] @@ -152,7 +184,34 @@ impl Service { // TODO: Support backfilled relations _ => 0, // This will result in an empty iterator }; - self.db.relations_until(user_id, room_id, target, until) + + self.db + .relations_until(user_id, room_id, target, until) + .map(|mut relations| { + let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect(); + let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect(); + + while let Some(stack_pdu) = stack.pop() { + let target = match stack_pdu.0 .0 { + PduCount::Normal(c) => c, + // TODO: Support backfilled relations + PduCount::Backfilled(_) => 0, // This will result in an empty iterator + }; + + if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) { + for relation in relations.flatten() { + if stack_pdu.1 < max_depth { + stack.push((relation.clone(), stack_pdu.1 + 1)); + } + + pdus.push(relation); + } + } + } + + pdus.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("u64s can always be compared")); + pdus + }) } #[tracing::instrument(skip(self, room_id, event_ids))]