feat: recurse relationships

This commit is contained in:
Matthias Ahouansou 2024-04-01 10:52:36 +01:00
parent 7b19618136
commit b46000fadc
No known key found for this signature in database
4 changed files with 138 additions and 140 deletions

27
Cargo.lock generated
View file

@ -2027,7 +2027,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma" name = "ruma"
version = "0.9.4" version = "0.9.4"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"assign", "assign",
"js_int", "js_int",
@ -2042,12 +2042,13 @@ dependencies = [
"ruma-server-util", "ruma-server-util",
"ruma-signatures", "ruma-signatures",
"ruma-state-res", "ruma-state-res",
"web-time",
] ]
[[package]] [[package]]
name = "ruma-appservice-api" name = "ruma-appservice-api"
version = "0.9.0" version = "0.9.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -2059,7 +2060,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-client-api" name = "ruma-client-api"
version = "0.17.4" version = "0.17.4"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"assign", "assign",
@ -2078,7 +2079,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-common" name = "ruma-common"
version = "0.12.1" version = "0.12.1"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"base64 0.21.7", "base64 0.21.7",
@ -2108,7 +2109,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-events" name = "ruma-events"
version = "0.27.11" version = "0.27.11"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"as_variant", "as_variant",
"indexmap 2.2.5", "indexmap 2.2.5",
@ -2130,7 +2131,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-federation-api" name = "ruma-federation-api"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -2142,7 +2143,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identifiers-validation" name = "ruma-identifiers-validation"
version = "0.9.3" version = "0.9.3"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"js_int", "js_int",
"thiserror", "thiserror",
@ -2151,7 +2152,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-identity-service-api" name = "ruma-identity-service-api"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -2161,7 +2162,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-macros" name = "ruma-macros"
version = "0.12.0" version = "0.12.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"once_cell", "once_cell",
"proc-macro-crate", "proc-macro-crate",
@ -2176,7 +2177,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-push-gateway-api" name = "ruma-push-gateway-api"
version = "0.8.0" version = "0.8.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"js_int", "js_int",
"ruma-common", "ruma-common",
@ -2188,7 +2189,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-server-util" name = "ruma-server-util"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"headers", "headers",
"ruma-common", "ruma-common",
@ -2199,7 +2200,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-signatures" name = "ruma-signatures"
version = "0.14.0" version = "0.14.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"ed25519-dalek", "ed25519-dalek",
@ -2215,7 +2216,7 @@ dependencies = [
[[package]] [[package]]
name = "ruma-state-res" name = "ruma-state-res"
version = "0.10.0" version = "0.10.0"
source = "git+https://github.com/ruma/ruma?rev=5495b85aa311c2805302edb0a7de40399e22b397#5495b85aa311c2805302edb0a7de40399e22b397" source = "git+https://github.com/ruma/ruma?rev=c5f8137ba9741b2317313256b57e6e14b61fb419#c5f8137ba9741b2317313256b57e6e14b61fb419"
dependencies = [ dependencies = [
"itertools 0.11.0", "itertools 0.11.0",
"js_int", "js_int",

View file

@ -48,7 +48,7 @@ tower-http = { version = "0.4.1", features = [
# Used for matrix spec type definitions and helpers # Used for matrix spec type definitions and helpers
#ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { version = "0.4.0", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] }
ruma = { git = "https://github.com/ruma/ruma", rev = "5495b85aa311c2805302edb0a7de40399e22b397", features = [ ruma = { git = "https://github.com/ruma/ruma", rev = "c5f8137ba9741b2317313256b57e6e14b61fb419", features = [
"appservice-api-c", "appservice-api-c",
"client-api", "client-api",
"compat", "compat",

View file

@ -3,7 +3,7 @@ use ruma::api::client::relations::{
get_relating_events_with_rel_type_and_event_type, get_relating_events_with_rel_type_and_event_type,
}; };
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma}; use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
pub async fn get_relating_events_with_rel_type_and_event_type_route( pub async fn get_relating_events_with_rel_type_and_event_type_route(
@ -11,27 +11,6 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -41,9 +20,11 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
&body.event_id, &body.event_id,
Some(body.event_type.clone()), Some(body.event_type.clone()),
Some(body.rel_type.clone()), Some(body.rel_type.clone()),
from, body.from.clone(),
to, body.to.clone(),
limit, body.limit,
body.recurse,
&body.dir,
)?; )?;
Ok( Ok(
@ -51,6 +32,7 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
chunk: res.chunk, chunk: res.chunk,
next_batch: res.next_batch, next_batch: res.next_batch,
prev_batch: res.prev_batch, prev_batch: res.prev_batch,
recursion_depth: res.recursion_depth,
}, },
) )
} }
@ -61,27 +43,6 @@ pub async fn get_relating_events_with_rel_type_route(
) -> Result<get_relating_events_with_rel_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -91,15 +52,18 @@ pub async fn get_relating_events_with_rel_type_route(
&body.event_id, &body.event_id,
None, None,
Some(body.rel_type.clone()), Some(body.rel_type.clone()),
from, body.from.clone(),
to, body.to.clone(),
limit, body.limit,
body.recurse,
&body.dir,
)?; )?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
next_batch: res.next_batch, next_batch: res.next_batch,
prev_batch: res.prev_batch, prev_batch: res.prev_batch,
recursion_depth: res.recursion_depth,
}) })
} }
@ -109,27 +73,6 @@ pub async fn get_relating_events_route(
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match ruma::api::Direction::Backward {
// TODO: fix ruma so `body.dir` exists
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
services() services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -139,8 +82,10 @@ pub async fn get_relating_events_route(
&body.event_id, &body.event_id,
None, None,
None, None,
from, body.from.clone(),
to, body.to.clone(),
limit, body.limit,
body.recurse,
&body.dir,
) )
} }

View file

@ -3,9 +3,9 @@ use std::sync::Arc;
pub use data::Data; pub use data::Data;
use ruma::{ use ruma::{
api::client::relations::get_relating_events, api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType}, events::{relation::RelationType, TimelineEventType},
EventId, RoomId, UserId, EventId, RoomId, UInt, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
@ -48,37 +48,57 @@ impl Service {
target: &EventId, target: &EventId,
filter_event_type: Option<TimelineEventType>, filter_event_type: Option<TimelineEventType>,
filter_rel_type: Option<RelationType>, filter_rel_type: Option<RelationType>,
from: PduCount, from: Option<String>,
to: Option<PduCount>, to: Option<String>,
limit: usize, limit: Option<UInt>,
recurse: bool,
dir: &Direction,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let from = match from {
Some(from) => PduCount::try_from_string(&from)?,
None => match dir {
Direction::Forward => PduCount::min(),
Direction::Backward => PduCount::max(),
},
};
let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let next_token; let next_token;
//TODO: Fix ruma: match body.dir { // Spec (v1.10) recommends depth of at least 3
match ruma::api::Direction::Backward { let depth: u8 = if recurse { 3 } else { 1 };
ruma::api::Direction::Forward => {
let events_after: Vec<_> = services() match dir {
.rooms Direction::Forward => {
.pdu_metadata let relations_until = &services().rooms.pdu_metadata.relations_until(
.relations_until(sender_user, room_id, target, from)? // TODO: should be relations_after sender_user,
.filter(|r| { room_id,
r.as_ref().map_or(true, |(_, pdu)| { target,
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) from,
&& if let Ok(content) = depth,
serde_json::from_str::<ExtractRelatesToEventId>( )?;
pdu.content.get(), let events_after: Vec<_> = relations_until // TODO: should be relations_after
) .iter()
{ .filter(|(_, pdu)| {
filter_rel_type filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
.as_ref() && if let Ok(content) =
.map_or(true, |r| &content.relates_to.rel_type == r) serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
} else { {
false filter_rel_type
} .as_ref()
}) .map_or(true, |r| &content.relates_to.rel_type == r)
} else {
false
}
}) })
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
@ -86,7 +106,7 @@ impl Service {
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect(); .collect();
next_token = events_after.last().map(|(count, _)| count).copied(); next_token = events_after.last().map(|(count, _)| count).copied();
@ -101,31 +121,32 @@ impl Service {
chunk: events_after, chunk: events_after,
next_batch: next_token.map(|t| t.stringify()), next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()), prev_batch: Some(from.stringify()),
recursion_depth: if recurse { Some(depth.into()) } else { None },
}) })
} }
ruma::api::Direction::Backward => { Direction::Backward => {
let events_before: Vec<_> = services() let relations_until = &services().rooms.pdu_metadata.relations_until(
.rooms sender_user,
.pdu_metadata room_id,
.relations_until(sender_user, room_id, target, from)? target,
.filter(|r| { from,
r.as_ref().map_or(true, |(_, pdu)| { depth,
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) )?;
&& if let Ok(content) = let events_before: Vec<_> = relations_until
serde_json::from_str::<ExtractRelatesToEventId>( .iter()
pdu.content.get(), .filter(|(_, pdu)| {
) filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
{ && if let Ok(content) =
filter_rel_type serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
.as_ref() {
.map_or(true, |r| &content.relates_to.rel_type == r) filter_rel_type
} else { .as_ref()
false .map_or(true, |r| &content.relates_to.rel_type == r)
} } else {
}) false
}
}) })
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
@ -133,7 +154,7 @@ impl Service {
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|&(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect(); .collect();
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
@ -147,6 +168,7 @@ impl Service {
chunk: events_before, chunk: events_before,
next_batch: next_token.map(|t| t.stringify()), next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()), prev_batch: Some(from.stringify()),
recursion_depth: if recurse { Some(depth.into()) } else { None },
}) })
} }
} }
@ -158,14 +180,44 @@ impl Service {
room_id: &'a RoomId, room_id: &'a RoomId,
target: &'a EventId, target: &'a EventId,
until: PduCount, until: PduCount,
) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
let target = match services().rooms.timeline.get_pdu_count(target)? { let target = match services().rooms.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c, Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations // TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator _ => 0, // This will result in an empty iterator
}; };
self.db.relations_until(user_id, room_id, target, until)
self.db
.relations_until(user_id, room_id, target, until)
.map(|mut relations| {
let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect();
let mut stack: Vec<_> =
pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect();
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until)
{
for relation in relations.flatten() {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1 + 1));
}
pdus.push(relation);
}
}
}
pdus.sort_by(|a, b| a.0.cmp(&b.0));
pdus
})
} }
#[tracing::instrument(skip(self, room_id, event_ids))] #[tracing::instrument(skip(self, room_id, event_ids))]