Merge branch 'better-joining' into 'master'

improvement: use invite state as hints to what servers to ask for joining

See merge request famedly/conduit!56
This commit is contained in:
Timo Kösters 2021-04-14 13:06:14 +00:00
commit f6c4da829e
2 changed files with 48 additions and 6 deletions

View file

@ -22,7 +22,11 @@ use ruma::{
serde::{to_canonical_value, CanonicalJsonObject, Raw}, serde::{to_canonical_value, CanonicalJsonObject, Raw},
EventId, RoomId, RoomVersionId, ServerName, UserId, EventId, RoomId, RoomVersionId, ServerName, UserId,
}; };
use std::{collections::BTreeMap, convert::TryFrom, sync::RwLock}; use std::{
collections::{BTreeMap, HashSet},
convert::TryFrom,
sync::RwLock,
};
#[cfg(feature = "conduit_bin")] #[cfg(feature = "conduit_bin")]
use rocket::{get, post}; use rocket::{get, post};
@ -36,11 +40,29 @@ pub async fn join_room_by_id_route(
db: State<'_, Database>, db: State<'_, Database>,
body: Ruma<join_room_by_id::Request<'_>>, body: Ruma<join_room_by_id::Request<'_>>,
) -> ConduitResult<join_room_by_id::Response> { ) -> ConduitResult<join_room_by_id::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut servers = db
.rooms
.invite_state(&sender_user, &body.room_id)?
.unwrap_or_default()
.iter()
.filter_map(|event| {
serde_json::from_str::<serde_json::Value>(&event.json().to_string()).ok()
})
.filter_map(|event| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::try_from(sender).ok())
.map(|user| user.server_name().to_owned())
.collect::<HashSet<_>>();
servers.insert(body.room_id.server_name().to_owned());
join_room_by_id_helper( join_room_by_id_helper(
&db, &db,
body.sender_user.as_ref(), body.sender_user.as_ref(),
&body.room_id, &body.room_id,
&[body.room_id.server_name().to_owned()], &servers,
body.third_party_signed.as_ref(), body.third_party_signed.as_ref(),
) )
.await .await
@ -55,12 +77,31 @@ pub async fn join_room_by_id_or_alias_route(
db: State<'_, Database>, db: State<'_, Database>,
body: Ruma<join_room_by_id_or_alias::Request<'_>>, body: Ruma<join_room_by_id_or_alias::Request<'_>>,
) -> ConduitResult<join_room_by_id_or_alias::Response> { ) -> ConduitResult<join_room_by_id_or_alias::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let (servers, room_id) = match RoomId::try_from(body.room_id_or_alias.clone()) { let (servers, room_id) = match RoomId::try_from(body.room_id_or_alias.clone()) {
Ok(room_id) => (vec![room_id.server_name().to_owned()], room_id), Ok(room_id) => {
let mut servers = db
.rooms
.invite_state(&sender_user, &room_id)?
.unwrap_or_default()
.iter()
.filter_map(|event| {
serde_json::from_str::<serde_json::Value>(&event.json().to_string()).ok()
})
.filter_map(|event| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::try_from(sender).ok())
.map(|user| user.server_name().to_owned())
.collect::<HashSet<_>>();
servers.insert(room_id.server_name().to_owned());
(servers, room_id)
}
Err(room_alias) => { Err(room_alias) => {
let response = client_server::get_alias_helper(&db, &room_alias).await?; let response = client_server::get_alias_helper(&db, &room_alias).await?;
(response.0.servers, response.0.room_id) (response.0.servers.into_iter().collect(), response.0.room_id)
} }
}; };
@ -406,7 +447,7 @@ async fn join_room_by_id_helper(
db: &Database, db: &Database,
sender_user: Option<&UserId>, sender_user: Option<&UserId>,
room_id: &RoomId, room_id: &RoomId,
servers: &[Box<ServerName>], servers: &HashSet<Box<ServerName>>,
_third_party_signed: Option<&IncomingThirdPartySigned>, _third_party_signed: Option<&IncomingThirdPartySigned>,
) -> ConduitResult<join_room_by_id::Response> { ) -> ConduitResult<join_room_by_id::Response> {
let sender_user = sender_user.expect("user is authenticated"); let sender_user = sender_user.expect("user is authenticated");

View file

@ -1795,7 +1795,8 @@ impl Rooms {
.filter_map(|event| event.get("sender").cloned()) .filter_map(|event| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(|s| s.to_owned())) .filter_map(|sender| sender.as_str().map(|s| s.to_owned()))
.filter_map(|sender| UserId::try_from(sender).ok()) .filter_map(|sender| UserId::try_from(sender).ok())
.map(|user| user.server_name().to_owned()); .map(|user| user.server_name().to_owned())
.collect::<HashSet<_>>();
for remote_server in servers { for remote_server in servers {
let make_leave_response = db let make_leave_response = db