Compare commits

...

12 commits

Author SHA1 Message Date
strawberry
4a6805a517
fixup! feat(federation): support /make_join and /send_join for restricted rooms
check event is join membership

Co-Authored-By: Matthias Ahouansou <matthias@ahouansou.cz>
2024-07-01 20:01:37 +01:00
Matthias Ahouansou
3a8ed99246
fixup! feat(federation): support /make_join and /send_join for restricted rooms
move out checking restricted room rules into separate function
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
a6da294c55
refactor: remove unecessery async 2024-07-01 17:29:32 +01:00
Matthias Ahouansou
98c7b89fca
fixup! feat(federation): support /make_join and /send_join for restricted rooms
Only sign join event if restricted join if required
2024-07-01 17:29:32 +01:00
strawberry
ddcfcb9302
fixup! feat(federation): support /make_join and /send_join for restricted rooms
check event is join membership, state_key user is the same as the sender, and state_key is from the origin

Co-Authored-By: Matthias Ahouansou <matthias@ahouansou.cz>
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
7fa0025c43
fixup! feat(federation): support /make_join and /send_join for restricted rooms
collect iterator
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
062f138cfe
fixup! perf(knock): add database cache
iterate over all users in migration
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
960259720b
fixup! perf(knock): add database cache
rename duplicated invite to leave
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
96f3b88132
fixup! feat(federation): support /make_join and /send_join for restricted rooms
remove unnecessery clone
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
6565d3fa71
fixup! feat(federation): support /make_join and /send_join for restricted rooms
remove indents when getting authorized join user
2024-07-01 17:29:32 +01:00
Matthias Ahouansou
ba4fa81620
feat(federation): support /make_join and /send_join for restricted rooms 2024-07-01 17:29:32 +01:00
Matthias Ahouansou
ae0fa29b09
perf(knock): add database cache 2024-07-01 17:29:32 +01:00
7 changed files with 313 additions and 63 deletions

View file

@ -907,7 +907,6 @@ async fn join_room_by_id_helper(
.rooms .rooms
.state_accessor .state_accessor
.user_can_invite(room_id, &user, sender_user, &state_lock) .user_can_invite(room_id, &user, sender_user, &state_lock)
.await
.unwrap_or(false) .unwrap_or(false)
{ {
auth_user = Some(user); auth_user = Some(user);

View file

@ -39,7 +39,7 @@ use ruma::{
events::{ events::{
receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
room::{ room::{
join_rules::{JoinRule, RoomJoinRulesEventContent}, join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent}, member::{MembershipState, RoomMemberEventContent},
}, },
StateEventType, TimelineEventType, StateEventType, TimelineEventType,
@ -49,7 +49,7 @@ use ruma::{
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId,
ServerName, RoomVersionId, ServerName, UserId,
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use std::{ use std::{
@ -1517,36 +1517,46 @@ pub async fn create_join_event_template_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// TODO: Conduit does not implement restricted join rules yet, we always reject let room_version_id = services().rooms.state.get_room_version(&body.room_id)?;
let join_rules_event = services().rooms.state_accessor.room_state_get(
&body.room_id,
&StateEventType::RoomJoinRules,
"",
)?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event let join_authorized_via_users_server = if (services()
.as_ref() .rooms
.map(|join_rules_event| { .state_cache
serde_json::from_str(join_rules_event.content.get()).map_err(|e| { .is_left(&body.user_id, &body.room_id)
warn!("Invalid join rules event: {}", e); .unwrap_or(true)
Error::bad_database("Invalid join rules event in db.") || services()
}) .rooms
}) .state_cache
.transpose()?; .is_knocked(&body.user_id, &body.room_id)
.unwrap_or(false))
&& user_can_perform_restricted_join(&body.user_id, &body.room_id, &room_version_id)?
{
let auth_user = services()
.rooms
.state_cache
.room_members(&body.room_id)
.filter_map(Result::ok)
.filter(|user| user.server_name() == services().globals.server_name())
.find(|user| {
services()
.rooms
.state_accessor
.user_can_invite(&body.room_id, user, &body.user_id, &state_lock)
.unwrap_or(false)
});
if let Some(join_rules_event_content) = join_rules_event_content { if auth_user.is_some() {
if matches!( auth_user
join_rules_event_content.join_rule, } else {
JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnableToAuthorizeJoin, ErrorKind::UnableToGrantJoin,
"Conduit does not support restricted rooms yet.", "No user on this server is able to assist in joining.",
)); ));
} }
} } else {
None
};
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?;
if !body.ver.contains(&room_version_id) { if !body.ver.contains(&room_version_id) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion { ErrorKind::IncompatibleRoomVersion {
@ -1564,7 +1574,7 @@ pub async fn create_join_event_template_route(
membership: MembershipState::Join, membership: MembershipState::Join,
third_party_invite: None, third_party_invite: None,
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server,
}) })
.expect("member event is valid value"); .expect("member event is valid value");
@ -1609,35 +1619,6 @@ async fn create_join_event(
.event_handler .event_handler
.acl_check(sender_servername, room_id)?; .acl_check(sender_servername, room_id)?;
// TODO: Conduit does not implement restricted join rules yet, we always reject
let join_rules_event = services().rooms.state_accessor.room_state_get(
room_id,
&StateEventType::RoomJoinRules,
"",
)?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
})
})
.transpose()?;
if let Some(join_rules_event_content) = join_rules_event_content {
if matches!(
join_rules_event_content.join_rule,
JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
) {
return Err(Error::BadRequest(
ErrorKind::UnableToAuthorizeJoin,
"Conduit does not support restricted rooms yet.",
));
}
}
// We need to return the state prior to joining, let's keep a reference to that here // We need to return the state prior to joining, let's keep a reference to that here
let shortstatehash = services() let shortstatehash = services()
.rooms .rooms
@ -1653,7 +1634,8 @@ async fn create_join_event(
// We do not add the event_id field to the pdu here because of signature and hashes checks // We do not add the event_id field to the pdu here because of signature and hashes checks
let room_version_id = services().rooms.state.get_room_version(room_id)?; let room_version_id = services().rooms.state.get_room_version(room_id)?;
let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) {
let (event_id, mut value) = match gen_event_id_canonical_json(pdu, &room_version_id) {
Ok(t) => t, Ok(t) => t,
Err(_) => { Err(_) => {
// Event could not be converted to canonical json // Event could not be converted to canonical json
@ -1664,6 +1646,87 @@ async fn create_join_event(
} }
}; };
let state_key: OwnedUserId = serde_json::from_value(
value
.get("state_key")
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "State key is missing"))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "State key is not a valid user ID"))?;
let sender: OwnedUserId = serde_json::from_value(
value
.get("sender")
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Sender is missing"))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Sender is not a valid user ID"))?;
if state_key != sender {
return Err(Error::BadRequest(
ErrorKind::BadJson,
"Sender and state key don't match",
));
}
// Security-wise, we only really need to check the event is not from us, cause otherwise it must be signed by that server,
// but we might as well check this since this event shouldn't really be sent on behalf of another server
if state_key.server_name() != sender_servername {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"User's server and origin don't match",
));
}
let event_type: StateEventType = serde_json::from_value(
value
.get("type")
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Missing event type"))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event type"))?;
if event_type != StateEventType::RoomMember {
return Err(Error::BadRequest(
ErrorKind::BadJson,
"Event type is not membership",
));
}
let event_content: RoomMemberEventContent = serde_json::from_value(
value
.get("content")
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Missing event content"))?
.clone()
.into(),
)
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event content"))?;
if event_content.membership != MembershipState::Join {
return Err(Error::BadRequest(
ErrorKind::BadJson,
"Membership of sent event does not match that of the endpoint",
));
}
if event_content
.join_authorized_via_users_server
.map(|user| user.server_name() == services().globals.server_name())
.unwrap_or_default()
&& user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default()
{
ruma::signatures::hash_and_sign_event(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut value,
&room_version_id,
)
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
}
let origin: OwnedServerName = serde_json::from_value( let origin: OwnedServerName = serde_json::from_value(
serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( serde_json::to_value(value.get("origin").ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -1686,7 +1749,14 @@ async fn create_join_event(
let pdu_id: Vec<u8> = services() let pdu_id: Vec<u8> = services()
.rooms .rooms
.event_handler .event_handler
.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) .handle_incoming_pdu(
&origin,
&event_id,
room_id,
value.clone(),
true,
&pub_key_map,
)
.await? .await?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -1724,7 +1794,11 @@ async fn create_join_event(
.filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten())
.map(PduEvent::convert_to_outgoing_federation_event) .map(PduEvent::convert_to_outgoing_federation_event)
.collect(), .collect(),
event: None, // TODO: handle restricted joins // Event field is required if the room version supports restricted join rules.
event: Some(
to_raw_value(&CanonicalJsonValue::Object(value))
.expect("To raw json should not fail since only change was adding signature"),
),
}) })
} }
@ -1771,6 +1845,79 @@ pub async fn create_join_event_v2_route(
Ok(create_join_event::v2::Response { room_state }) Ok(create_join_event::v2::Response { room_state })
} }
/// Checks whether the given user can join the given room via a restricted join.
/// This doesn't check the current user's membership. This should be done externally,
/// either by using the state cache or attempting to authorize the event.
fn user_can_perform_restricted_join(
user_id: &UserId,
room_id: &RoomId,
room_version_id: &RoomVersionId,
) -> Result<bool> {
let join_rules_event = services().rooms.state_accessor.room_state_get(
room_id,
&StateEventType::RoomJoinRules,
"",
)?;
let Some(join_rules_event_content) = join_rules_event
.as_ref()
.map(|join_rules_event| {
serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get())
.map_err(|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
})
})
.transpose()?
else {
return Ok(false);
};
if matches!(
room_version_id,
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
) {
return Ok(false);
}
let (JoinRule::Restricted(r) | JoinRule::KnockRestricted(r)) =
join_rules_event_content.join_rule
else {
return Ok(false);
};
if r.allow
.iter()
.filter_map(|rule| {
if let AllowRule::RoomMembership(membership) = rule {
Some(membership)
} else {
None
}
})
.any(|m| {
services()
.rooms
.state_cache
.is_joined(user_id, &m.room_id)
.unwrap_or(false)
})
{
Ok(true)
} else {
Err(Error::BadRequest(
ErrorKind::UnableToAuthorizeJoin,
"User is not known to be in any required room.",
))
}
}
/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` /// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`
/// ///
/// Invites a remote user to a room. /// Invites a remote user to a room.

View file

@ -66,6 +66,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?;
self.userroomid_knockedstate.remove(&userroom_id)?;
self.roomuserid_knockedcount.remove(&roomuser_id)?;
Ok(()) Ok(())
} }
@ -91,6 +93,36 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?; self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_knockedstate.remove(&userroom_id)?;
self.roomuserid_knockedcount.remove(&roomuser_id)?;
Ok(())
}
fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xff);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_knockedstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new())
.expect("state to bytes always works"),
)?;
self.roomuserid_knockedcount.insert(
&roomuser_id,
&services().globals.next_count()?.to_be_bytes(),
)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(()) Ok(())
} }
@ -604,4 +636,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
} }
#[tracing::instrument(skip(self))]
fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_knockedstate.get(&userroom_id)?.is_some())
}
} }

View file

@ -12,7 +12,10 @@ use lru_cache::LruCache;
use ruma::{ use ruma::{
events::{ events::{
push_rules::{PushRulesEvent, PushRulesEventContent}, push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent, room::{
member::{MembershipState, RoomMemberEventContent},
message::RoomMessageEventContent,
},
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
}, },
push::Ruleset, push::Ruleset,
@ -100,6 +103,8 @@ pub struct KeyValueDatabase {
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub(super) userroomid_leftstate: Arc<dyn KvTree>, pub(super) userroomid_leftstate: Arc<dyn KvTree>,
pub(super) roomuserid_leftcount: Arc<dyn KvTree>, pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
pub(super) userroomid_knockedstate: Arc<dyn KvTree>,
pub(super) roomuserid_knockedcount: Arc<dyn KvTree>,
pub(super) alias_userid: Arc<dyn KvTree>, // User who created the alias pub(super) alias_userid: Arc<dyn KvTree>, // User who created the alias
@ -328,6 +333,8 @@ impl KeyValueDatabase {
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
userroomid_knockedstate: builder.open_tree("userroomid_knockedstate")?,
roomuserid_knockedcount: builder.open_tree("roomuserid_knockedcount")?,
alias_userid: builder.open_tree("alias_userid")?, alias_userid: builder.open_tree("alias_userid")?,
@ -424,7 +431,7 @@ impl KeyValueDatabase {
} }
// If the database has any data, perform data migrations before starting // If the database has any data, perform data migrations before starting
let latest_database_version = 13; let latest_database_version = 14;
if services().users.count()? > 0 { if services().users.count()? > 0 {
// MIGRATIONS // MIGRATIONS
@ -941,6 +948,49 @@ impl KeyValueDatabase {
warn!("Migration: 12 -> 13 finished"); warn!("Migration: 12 -> 13 finished");
} }
if services().globals.database_version()? < 14 {
for user in services().users.iter().filter_map(Result::ok) {
for room in
services()
.rooms
.metadata
.iter_ids()
.filter_map(|room_id| match room_id {
Ok(room_id) => Some(room_id),
Err(e) => {
warn!("Invalid room id: {e}");
None
}
})
{
if services()
.rooms
.state_accessor
.room_state_get(&room, &StateEventType::RoomMember, user.as_str())?
.map(|pdu| {
serde_json::from_str(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid PDU in database."))
})
.transpose()?
.map(|content: RoomMemberEventContent| content.membership)
== Some(MembershipState::Knock)
{
services().rooms.state_cache.update_membership(
&room,
&user,
MembershipState::Knock,
&user,
None,
false,
)?;
}
}
}
services().globals.bump_database_version(14)?;
warn!("Migration: 13 -> 14 finished");
}
assert_eq!( assert_eq!(
services().globals.database_version().unwrap(), services().globals.database_version().unwrap(),
latest_database_version latest_database_version

View file

@ -305,7 +305,7 @@ impl Service {
}) })
} }
pub async fn user_can_invite( pub fn user_can_invite(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
sender: &UserId, sender: &UserId,

View file

@ -18,6 +18,9 @@ pub trait Data: Send + Sync {
) -> Result<()>; ) -> Result<()>;
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
/// Marks a user as knocking on a room
fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; fn update_joined_count(&self, room_id: &RoomId) -> Result<()>;
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>; fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>;
@ -106,4 +109,6 @@ pub trait Data: Send + Sync {
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>; fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
} }

View file

@ -181,6 +181,9 @@ impl Service {
MembershipState::Leave | MembershipState::Ban => { MembershipState::Leave | MembershipState::Ban => {
self.db.mark_as_left(user_id, room_id)?; self.db.mark_as_left(user_id, room_id)?;
} }
MembershipState::Knock => {
self.db.mark_as_knocked(user_id, room_id)?;
}
_ => {} _ => {}
} }
@ -350,4 +353,9 @@ impl Service {
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.is_left(user_id, room_id) self.db.is_left(user_id, room_id)
} }
#[tracing::instrument(skip(self))]
pub fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.is_knocked(user_id, room_id)
}
} }