fix broken federated room invites/joins
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
f0557e3303
commit
cb03654dc1
3 changed files with 97 additions and 79 deletions
|
@ -152,8 +152,11 @@ pub(crate) async fn join_room_by_id_route(
|
|||
let mut servers = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.servers_invite_via(&body.room_id)?
|
||||
.unwrap_or(
|
||||
.servers_invite_via(&body.room_id)
|
||||
.filter_map(Result::ok)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
servers.extend(
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
|
@ -164,8 +167,7 @@ pub(crate) async fn join_room_by_id_route(
|
|||
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
|
||||
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
|
||||
.filter_map(|sender| UserId::parse(sender).ok())
|
||||
.map(|user| user.server_name().to_owned())
|
||||
.collect::<Vec<_>>(),
|
||||
.map(|user| user.server_name().to_owned()),
|
||||
);
|
||||
|
||||
if let Some(server) = body.room_id.server_name() {
|
||||
|
@ -206,8 +208,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
|
|||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.servers_invite_via(&room_id)?
|
||||
.unwrap_or(
|
||||
.servers_invite_via(&room_id)
|
||||
.filter_map(Result::ok),
|
||||
);
|
||||
|
||||
servers.extend(
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
|
@ -218,9 +223,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
|
|||
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
|
||||
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
|
||||
.filter_map(|sender| UserId::parse(sender).ok())
|
||||
.map(|user| user.server_name().to_owned())
|
||||
.collect(),
|
||||
),
|
||||
.map(|user| user.server_name().to_owned()),
|
||||
);
|
||||
|
||||
if let Some(server) = room_id.server_name() {
|
||||
|
@ -240,8 +243,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
|
|||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.servers_invite_via(&response.room_id)?
|
||||
.unwrap_or(
|
||||
.servers_invite_via(&response.room_id)
|
||||
.filter_map(Result::ok),
|
||||
);
|
||||
|
||||
servers.extend(
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
|
@ -252,9 +258,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
|
|||
.filter_map(|event: serde_json::Value| event.get("sender").cloned())
|
||||
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
|
||||
.filter_map(|sender| UserId::parse(sender).ok())
|
||||
.map(|user| user.server_name().to_owned())
|
||||
.collect(),
|
||||
),
|
||||
.map(|user| user.server_name().to_owned()),
|
||||
);
|
||||
|
||||
(servers, response.room_id)
|
||||
|
@ -1680,11 +1684,14 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
|||
.invite_state(user_id, room_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
|
||||
|
||||
let servers: HashSet<OwnedServerName> = services()
|
||||
let mut servers: HashSet<OwnedServerName> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.servers_invite_via(room_id)?
|
||||
.map_or(
|
||||
.servers_invite_via(room_id)
|
||||
.filter_map(Result::ok)
|
||||
.collect();
|
||||
|
||||
servers.extend(
|
||||
invite_state
|
||||
.iter()
|
||||
.filter_map(|event| serde_json::from_str(event.json().get()).ok())
|
||||
|
@ -1693,7 +1700,6 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
|||
.filter_map(|sender| UserId::parse(sender).ok())
|
||||
.map(|user| user.server_name().to_owned())
|
||||
.collect::<HashSet<OwnedServerName>>(),
|
||||
HashSet::from_iter,
|
||||
);
|
||||
|
||||
debug!("servers in remote_leave_room: {servers:?}");
|
||||
|
|
|
@ -6,7 +6,6 @@ use ruma::{
|
|||
serde::Raw,
|
||||
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{
|
||||
appservice::RegistrationInfo,
|
||||
|
@ -94,7 +93,7 @@ pub trait Data: Send + Sync {
|
|||
|
||||
/// Gets the servers to either accept or decline invites via for a given
|
||||
/// room.
|
||||
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>;
|
||||
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>;
|
||||
|
||||
/// Add the given servers the list to accept or decline invites via for a
|
||||
/// given room.
|
||||
|
@ -159,7 +158,10 @@ impl Data for KeyValueDatabase {
|
|||
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||
|
||||
if let Some(servers) = invite_via {
|
||||
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
|
||||
let mut prev_servers = self
|
||||
.servers_invite_via(room_id)
|
||||
.filter_map(Result::ok)
|
||||
.collect_vec();
|
||||
#[allow(clippy::redundant_clone)] // this is a necessary clone?
|
||||
prev_servers.append(servers.clone().as_mut());
|
||||
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
|
||||
|
@ -639,30 +641,40 @@ impl Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
|
||||
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
|
||||
let key = room_id.as_bytes().to_vec();
|
||||
|
||||
Box::new(
|
||||
self.roomid_inviteviaservers
|
||||
.get(&key)?
|
||||
.map(|servers| {
|
||||
let state = serde_json::from_slice(&servers).map_err(|e| {
|
||||
error!("Invalid state in userroomid_leftstate: {e}");
|
||||
Error::bad_database("Invalid state in userroomid_leftstate.")
|
||||
})?;
|
||||
|
||||
Ok(state)
|
||||
})
|
||||
.transpose()
|
||||
.scan_prefix(key)
|
||||
.map(|(_, servers)| {
|
||||
ServerName::parse(
|
||||
utils::string_from_bytes(
|
||||
servers
|
||||
.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
|
||||
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
|
||||
prev_servers.append(servers.to_owned().as_mut());
|
||||
let mut prev_servers = self
|
||||
.servers_invite_via(room_id)
|
||||
.filter_map(Result::ok)
|
||||
.collect_vec();
|
||||
prev_servers.extend(servers.to_owned());
|
||||
prev_servers.sort_unstable();
|
||||
prev_servers.dedup();
|
||||
|
||||
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
|
||||
|
||||
let servers = servers
|
||||
let servers = prev_servers
|
||||
.iter()
|
||||
.map(|server| server.as_bytes())
|
||||
.collect_vec()
|
||||
|
|
|
@ -377,7 +377,7 @@ impl Service {
|
|||
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
|
||||
pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
|
||||
self.db.servers_invite_via(room_id)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue