fix broken federated room invites/joins

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-06-10 14:53:26 -04:00
parent f0557e3303
commit cb03654dc1
3 changed files with 97 additions and 79 deletions

View file

@ -152,8 +152,11 @@ pub(crate) async fn join_room_by_id_route(
let mut servers = services()
.rooms
.state_cache
.servers_invite_via(&body.room_id)?
.unwrap_or(
.servers_invite_via(&body.room_id)
.filter_map(Result::ok)
.collect::<Vec<_>>();
servers.extend(
services()
.rooms
.state_cache
@ -164,8 +167,7 @@ pub(crate) async fn join_room_by_id_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned())
.collect::<Vec<_>>(),
.map(|user| user.server_name().to_owned()),
);
if let Some(server) = body.room_id.server_name() {
@ -206,8 +208,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
services()
.rooms
.state_cache
.servers_invite_via(&room_id)?
.unwrap_or(
.servers_invite_via(&room_id)
.filter_map(Result::ok),
);
servers.extend(
services()
.rooms
.state_cache
@ -218,9 +223,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned())
.collect(),
),
.map(|user| user.server_name().to_owned()),
);
if let Some(server) = room_id.server_name() {
@ -240,8 +243,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
services()
.rooms
.state_cache
.servers_invite_via(&response.room_id)?
.unwrap_or(
.servers_invite_via(&response.room_id)
.filter_map(Result::ok),
);
servers.extend(
services()
.rooms
.state_cache
@ -252,9 +258,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned())
.collect(),
),
.map(|user| user.server_name().to_owned()),
);
(servers, response.room_id)
@ -1680,11 +1684,14 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
.invite_state(user_id, room_id)?
.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
let servers: HashSet<OwnedServerName> = services()
let mut servers: HashSet<OwnedServerName> = services()
.rooms
.state_cache
.servers_invite_via(room_id)?
.map_or(
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect();
servers.extend(
invite_state
.iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -1693,7 +1700,6 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
.filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned())
.collect::<HashSet<OwnedServerName>>(),
HashSet::from_iter,
);
debug!("servers in remote_leave_room: {servers:?}");

View file

@ -6,7 +6,6 @@ use ruma::{
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{
appservice::RegistrationInfo,
@ -94,7 +93,7 @@ pub trait Data: Send + Sync {
/// Gets the servers to either accept or decline invites via for a given
/// room.
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>;
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>;
/// Add the given servers the list to accept or decline invites via for a
/// given room.
@ -159,7 +158,10 @@ impl Data for KeyValueDatabase {
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
@ -639,30 +641,40 @@ impl Data for KeyValueDatabase {
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let key = room_id.as_bytes().to_vec();
Box::new(
self.roomid_inviteviaservers
.get(&key)?
.map(|servers| {
let state = serde_json::from_slice(&servers).map_err(|e| {
error!("Invalid state in userroomid_leftstate: {e}");
Error::bad_database("Invalid state in userroomid_leftstate.")
})?;
Ok(state)
})
.transpose()
.scan_prefix(key)
.map(|(_, servers)| {
ServerName::parse(
utils::string_from_bytes(
servers
.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()

View file

@ -377,7 +377,7 @@ impl Service {
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
#[tracing::instrument(skip(self))]
pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.servers_invite_via(room_id)
}