Compare commits
12 commits
next
...
restricted
Author | SHA1 | Date | |
---|---|---|---|
|
4a6805a517 | ||
|
3a8ed99246 | ||
|
a6da294c55 | ||
|
98c7b89fca | ||
|
ddcfcb9302 | ||
|
7fa0025c43 | ||
|
062f138cfe | ||
|
960259720b | ||
|
96f3b88132 | ||
|
6565d3fa71 | ||
|
ba4fa81620 | ||
|
ae0fa29b09 |
7 changed files with 313 additions and 63 deletions
|
@ -907,7 +907,6 @@ async fn join_room_by_id_helper(
|
||||||
.rooms
|
.rooms
|
||||||
.state_accessor
|
.state_accessor
|
||||||
.user_can_invite(room_id, &user, sender_user, &state_lock)
|
.user_can_invite(room_id, &user, sender_user, &state_lock)
|
||||||
.await
|
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
{
|
{
|
||||||
auth_user = Some(user);
|
auth_user = Some(user);
|
||||||
|
|
|
@ -39,7 +39,7 @@ use ruma::{
|
||||||
events::{
|
events::{
|
||||||
receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
|
receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
|
||||||
room::{
|
room::{
|
||||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent},
|
||||||
member::{MembershipState, RoomMemberEventContent},
|
member::{MembershipState, RoomMemberEventContent},
|
||||||
},
|
},
|
||||||
StateEventType, TimelineEventType,
|
StateEventType, TimelineEventType,
|
||||||
|
@ -49,7 +49,7 @@ use ruma::{
|
||||||
to_device::DeviceIdOrAllDevices,
|
to_device::DeviceIdOrAllDevices,
|
||||||
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
|
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
|
||||||
OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId,
|
OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId,
|
||||||
ServerName,
|
RoomVersionId, ServerName, UserId,
|
||||||
};
|
};
|
||||||
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
|
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
|
||||||
use std::{
|
use std::{
|
||||||
|
@ -1517,36 +1517,46 @@ pub async fn create_join_event_template_route(
|
||||||
);
|
);
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
// TODO: Conduit does not implement restricted join rules yet, we always reject
|
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?;
|
||||||
let join_rules_event = services().rooms.state_accessor.room_state_get(
|
|
||||||
&body.room_id,
|
|
||||||
&StateEventType::RoomJoinRules,
|
|
||||||
"",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
|
let join_authorized_via_users_server = if (services()
|
||||||
.as_ref()
|
.rooms
|
||||||
.map(|join_rules_event| {
|
.state_cache
|
||||||
serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
|
.is_left(&body.user_id, &body.room_id)
|
||||||
warn!("Invalid join rules event: {}", e);
|
.unwrap_or(true)
|
||||||
Error::bad_database("Invalid join rules event in db.")
|
|| services()
|
||||||
})
|
.rooms
|
||||||
})
|
.state_cache
|
||||||
.transpose()?;
|
.is_knocked(&body.user_id, &body.room_id)
|
||||||
|
.unwrap_or(false))
|
||||||
|
&& user_can_perform_restricted_join(&body.user_id, &body.room_id, &room_version_id)?
|
||||||
|
{
|
||||||
|
let auth_user = services()
|
||||||
|
.rooms
|
||||||
|
.state_cache
|
||||||
|
.room_members(&body.room_id)
|
||||||
|
.filter_map(Result::ok)
|
||||||
|
.filter(|user| user.server_name() == services().globals.server_name())
|
||||||
|
.find(|user| {
|
||||||
|
services()
|
||||||
|
.rooms
|
||||||
|
.state_accessor
|
||||||
|
.user_can_invite(&body.room_id, user, &body.user_id, &state_lock)
|
||||||
|
.unwrap_or(false)
|
||||||
|
});
|
||||||
|
|
||||||
if let Some(join_rules_event_content) = join_rules_event_content {
|
if auth_user.is_some() {
|
||||||
if matches!(
|
auth_user
|
||||||
join_rules_event_content.join_rule,
|
} else {
|
||||||
JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
|
|
||||||
) {
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::UnableToAuthorizeJoin,
|
ErrorKind::UnableToGrantJoin,
|
||||||
"Conduit does not support restricted rooms yet.",
|
"No user on this server is able to assist in joining.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?;
|
|
||||||
if !body.ver.contains(&room_version_id) {
|
if !body.ver.contains(&room_version_id) {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::IncompatibleRoomVersion {
|
ErrorKind::IncompatibleRoomVersion {
|
||||||
|
@ -1564,7 +1574,7 @@ pub async fn create_join_event_template_route(
|
||||||
membership: MembershipState::Join,
|
membership: MembershipState::Join,
|
||||||
third_party_invite: None,
|
third_party_invite: None,
|
||||||
reason: None,
|
reason: None,
|
||||||
join_authorized_via_users_server: None,
|
join_authorized_via_users_server,
|
||||||
})
|
})
|
||||||
.expect("member event is valid value");
|
.expect("member event is valid value");
|
||||||
|
|
||||||
|
@ -1609,35 +1619,6 @@ async fn create_join_event(
|
||||||
.event_handler
|
.event_handler
|
||||||
.acl_check(sender_servername, room_id)?;
|
.acl_check(sender_servername, room_id)?;
|
||||||
|
|
||||||
// TODO: Conduit does not implement restricted join rules yet, we always reject
|
|
||||||
let join_rules_event = services().rooms.state_accessor.room_state_get(
|
|
||||||
room_id,
|
|
||||||
&StateEventType::RoomJoinRules,
|
|
||||||
"",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
|
|
||||||
.as_ref()
|
|
||||||
.map(|join_rules_event| {
|
|
||||||
serde_json::from_str(join_rules_event.content.get()).map_err(|e| {
|
|
||||||
warn!("Invalid join rules event: {}", e);
|
|
||||||
Error::bad_database("Invalid join rules event in db.")
|
|
||||||
})
|
|
||||||
})
|
|
||||||
.transpose()?;
|
|
||||||
|
|
||||||
if let Some(join_rules_event_content) = join_rules_event_content {
|
|
||||||
if matches!(
|
|
||||||
join_rules_event_content.join_rule,
|
|
||||||
JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. }
|
|
||||||
) {
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::UnableToAuthorizeJoin,
|
|
||||||
"Conduit does not support restricted rooms yet.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need to return the state prior to joining, let's keep a reference to that here
|
// We need to return the state prior to joining, let's keep a reference to that here
|
||||||
let shortstatehash = services()
|
let shortstatehash = services()
|
||||||
.rooms
|
.rooms
|
||||||
|
@ -1653,7 +1634,8 @@ async fn create_join_event(
|
||||||
|
|
||||||
// We do not add the event_id field to the pdu here because of signature and hashes checks
|
// We do not add the event_id field to the pdu here because of signature and hashes checks
|
||||||
let room_version_id = services().rooms.state.get_room_version(room_id)?;
|
let room_version_id = services().rooms.state.get_room_version(room_id)?;
|
||||||
let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) {
|
|
||||||
|
let (event_id, mut value) = match gen_event_id_canonical_json(pdu, &room_version_id) {
|
||||||
Ok(t) => t,
|
Ok(t) => t,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
// Event could not be converted to canonical json
|
// Event could not be converted to canonical json
|
||||||
|
@ -1664,6 +1646,87 @@ async fn create_join_event(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let state_key: OwnedUserId = serde_json::from_value(
|
||||||
|
value
|
||||||
|
.get("state_key")
|
||||||
|
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "State key is missing"))?
|
||||||
|
.clone()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "State key is not a valid user ID"))?;
|
||||||
|
|
||||||
|
let sender: OwnedUserId = serde_json::from_value(
|
||||||
|
value
|
||||||
|
.get("sender")
|
||||||
|
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Sender is missing"))?
|
||||||
|
.clone()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Sender is not a valid user ID"))?;
|
||||||
|
|
||||||
|
if state_key != sender {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::BadJson,
|
||||||
|
"Sender and state key don't match",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Security-wise, we only really need to check the event is not from us, cause otherwise it must be signed by that server,
|
||||||
|
// but we might as well check this since this event shouldn't really be sent on behalf of another server
|
||||||
|
if state_key.server_name() != sender_servername {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::forbidden(),
|
||||||
|
"User's server and origin don't match",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let event_type: StateEventType = serde_json::from_value(
|
||||||
|
value
|
||||||
|
.get("type")
|
||||||
|
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Missing event type"))?
|
||||||
|
.clone()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event type"))?;
|
||||||
|
|
||||||
|
if event_type != StateEventType::RoomMember {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::BadJson,
|
||||||
|
"Event type is not membership",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let event_content: RoomMemberEventContent = serde_json::from_value(
|
||||||
|
value
|
||||||
|
.get("content")
|
||||||
|
.ok_or_else(|| Error::BadRequest(ErrorKind::BadJson, "Missing event content"))?
|
||||||
|
.clone()
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid event content"))?;
|
||||||
|
|
||||||
|
if event_content.membership != MembershipState::Join {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::BadJson,
|
||||||
|
"Membership of sent event does not match that of the endpoint",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if event_content
|
||||||
|
.join_authorized_via_users_server
|
||||||
|
.map(|user| user.server_name() == services().globals.server_name())
|
||||||
|
.unwrap_or_default()
|
||||||
|
&& user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default()
|
||||||
|
{
|
||||||
|
ruma::signatures::hash_and_sign_event(
|
||||||
|
services().globals.server_name().as_str(),
|
||||||
|
services().globals.keypair(),
|
||||||
|
&mut value,
|
||||||
|
&room_version_id,
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
|
||||||
|
}
|
||||||
|
|
||||||
let origin: OwnedServerName = serde_json::from_value(
|
let origin: OwnedServerName = serde_json::from_value(
|
||||||
serde_json::to_value(value.get("origin").ok_or(Error::BadRequest(
|
serde_json::to_value(value.get("origin").ok_or(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
|
@ -1686,7 +1749,14 @@ async fn create_join_event(
|
||||||
let pdu_id: Vec<u8> = services()
|
let pdu_id: Vec<u8> = services()
|
||||||
.rooms
|
.rooms
|
||||||
.event_handler
|
.event_handler
|
||||||
.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map)
|
.handle_incoming_pdu(
|
||||||
|
&origin,
|
||||||
|
&event_id,
|
||||||
|
room_id,
|
||||||
|
value.clone(),
|
||||||
|
true,
|
||||||
|
&pub_key_map,
|
||||||
|
)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
|
@ -1724,7 +1794,11 @@ async fn create_join_event(
|
||||||
.filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten())
|
.filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten())
|
||||||
.map(PduEvent::convert_to_outgoing_federation_event)
|
.map(PduEvent::convert_to_outgoing_federation_event)
|
||||||
.collect(),
|
.collect(),
|
||||||
event: None, // TODO: handle restricted joins
|
// Event field is required if the room version supports restricted join rules.
|
||||||
|
event: Some(
|
||||||
|
to_raw_value(&CanonicalJsonValue::Object(value))
|
||||||
|
.expect("To raw json should not fail since only change was adding signature"),
|
||||||
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1771,6 +1845,79 @@ pub async fn create_join_event_v2_route(
|
||||||
Ok(create_join_event::v2::Response { room_state })
|
Ok(create_join_event::v2::Response { room_state })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Checks whether the given user can join the given room via a restricted join.
|
||||||
|
/// This doesn't check the current user's membership. This should be done externally,
|
||||||
|
/// either by using the state cache or attempting to authorize the event.
|
||||||
|
fn user_can_perform_restricted_join(
|
||||||
|
user_id: &UserId,
|
||||||
|
room_id: &RoomId,
|
||||||
|
room_version_id: &RoomVersionId,
|
||||||
|
) -> Result<bool> {
|
||||||
|
let join_rules_event = services().rooms.state_accessor.room_state_get(
|
||||||
|
room_id,
|
||||||
|
&StateEventType::RoomJoinRules,
|
||||||
|
"",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let Some(join_rules_event_content) = join_rules_event
|
||||||
|
.as_ref()
|
||||||
|
.map(|join_rules_event| {
|
||||||
|
serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_event.content.get())
|
||||||
|
.map_err(|e| {
|
||||||
|
warn!("Invalid join rules event: {}", e);
|
||||||
|
Error::bad_database("Invalid join rules event in db.")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.transpose()?
|
||||||
|
else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
if matches!(
|
||||||
|
room_version_id,
|
||||||
|
RoomVersionId::V1
|
||||||
|
| RoomVersionId::V2
|
||||||
|
| RoomVersionId::V3
|
||||||
|
| RoomVersionId::V4
|
||||||
|
| RoomVersionId::V5
|
||||||
|
| RoomVersionId::V6
|
||||||
|
| RoomVersionId::V7
|
||||||
|
) {
|
||||||
|
return Ok(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (JoinRule::Restricted(r) | JoinRule::KnockRestricted(r)) =
|
||||||
|
join_rules_event_content.join_rule
|
||||||
|
else {
|
||||||
|
return Ok(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
if r.allow
|
||||||
|
.iter()
|
||||||
|
.filter_map(|rule| {
|
||||||
|
if let AllowRule::RoomMembership(membership) = rule {
|
||||||
|
Some(membership)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.any(|m| {
|
||||||
|
services()
|
||||||
|
.rooms
|
||||||
|
.state_cache
|
||||||
|
.is_joined(user_id, &m.room_id)
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
{
|
||||||
|
Ok(true)
|
||||||
|
} else {
|
||||||
|
Err(Error::BadRequest(
|
||||||
|
ErrorKind::UnableToAuthorizeJoin,
|
||||||
|
"User is not known to be in any required room.",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`
|
/// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}`
|
||||||
///
|
///
|
||||||
/// Invites a remote user to a room.
|
/// Invites a remote user to a room.
|
||||||
|
|
|
@ -66,6 +66,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||||
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||||
|
self.userroomid_knockedstate.remove(&userroom_id)?;
|
||||||
|
self.roomuserid_knockedcount.remove(&roomuser_id)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -91,6 +93,36 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||||
self.userroomid_invitestate.remove(&userroom_id)?;
|
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||||
self.roomuserid_invitecount.remove(&roomuser_id)?;
|
self.roomuserid_invitecount.remove(&roomuser_id)?;
|
||||||
|
self.userroomid_knockedstate.remove(&userroom_id)?;
|
||||||
|
self.roomuserid_knockedcount.remove(&roomuser_id)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
|
roomuser_id.push(0xff);
|
||||||
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
|
userroom_id.push(0xff);
|
||||||
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
|
self.userroomid_knockedstate.insert(
|
||||||
|
&userroom_id,
|
||||||
|
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new())
|
||||||
|
.expect("state to bytes always works"),
|
||||||
|
)?;
|
||||||
|
self.roomuserid_knockedcount.insert(
|
||||||
|
&roomuser_id,
|
||||||
|
&services().globals.next_count()?.to_be_bytes(),
|
||||||
|
)?;
|
||||||
|
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||||
|
self.roomuserid_invitecount.remove(&roomuser_id)?;
|
||||||
|
self.userroomid_joined.remove(&userroom_id)?;
|
||||||
|
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||||
|
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||||
|
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -604,4 +636,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
|
|
||||||
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self))]
|
||||||
|
fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
|
userroom_id.push(0xff);
|
||||||
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
|
Ok(self.userroomid_knockedstate.get(&userroom_id)?.is_some())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,7 +12,10 @@ use lru_cache::LruCache;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
events::{
|
events::{
|
||||||
push_rules::{PushRulesEvent, PushRulesEventContent},
|
push_rules::{PushRulesEvent, PushRulesEventContent},
|
||||||
room::message::RoomMessageEventContent,
|
room::{
|
||||||
|
member::{MembershipState, RoomMemberEventContent},
|
||||||
|
message::RoomMessageEventContent,
|
||||||
|
},
|
||||||
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
||||||
},
|
},
|
||||||
push::Ruleset,
|
push::Ruleset,
|
||||||
|
@ -100,6 +103,8 @@ pub struct KeyValueDatabase {
|
||||||
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
|
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
|
||||||
pub(super) userroomid_leftstate: Arc<dyn KvTree>,
|
pub(super) userroomid_leftstate: Arc<dyn KvTree>,
|
||||||
pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
|
pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
|
||||||
|
pub(super) userroomid_knockedstate: Arc<dyn KvTree>,
|
||||||
|
pub(super) roomuserid_knockedcount: Arc<dyn KvTree>,
|
||||||
|
|
||||||
pub(super) alias_userid: Arc<dyn KvTree>, // User who created the alias
|
pub(super) alias_userid: Arc<dyn KvTree>, // User who created the alias
|
||||||
|
|
||||||
|
@ -328,6 +333,8 @@ impl KeyValueDatabase {
|
||||||
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
|
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
|
||||||
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
|
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
|
||||||
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
|
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
|
||||||
|
userroomid_knockedstate: builder.open_tree("userroomid_knockedstate")?,
|
||||||
|
roomuserid_knockedcount: builder.open_tree("roomuserid_knockedcount")?,
|
||||||
|
|
||||||
alias_userid: builder.open_tree("alias_userid")?,
|
alias_userid: builder.open_tree("alias_userid")?,
|
||||||
|
|
||||||
|
@ -424,7 +431,7 @@ impl KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the database has any data, perform data migrations before starting
|
// If the database has any data, perform data migrations before starting
|
||||||
let latest_database_version = 13;
|
let latest_database_version = 14;
|
||||||
|
|
||||||
if services().users.count()? > 0 {
|
if services().users.count()? > 0 {
|
||||||
// MIGRATIONS
|
// MIGRATIONS
|
||||||
|
@ -941,6 +948,49 @@ impl KeyValueDatabase {
|
||||||
warn!("Migration: 12 -> 13 finished");
|
warn!("Migration: 12 -> 13 finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if services().globals.database_version()? < 14 {
|
||||||
|
for user in services().users.iter().filter_map(Result::ok) {
|
||||||
|
for room in
|
||||||
|
services()
|
||||||
|
.rooms
|
||||||
|
.metadata
|
||||||
|
.iter_ids()
|
||||||
|
.filter_map(|room_id| match room_id {
|
||||||
|
Ok(room_id) => Some(room_id),
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Invalid room id: {e}");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
{
|
||||||
|
if services()
|
||||||
|
.rooms
|
||||||
|
.state_accessor
|
||||||
|
.room_state_get(&room, &StateEventType::RoomMember, user.as_str())?
|
||||||
|
.map(|pdu| {
|
||||||
|
serde_json::from_str(pdu.content.get())
|
||||||
|
.map_err(|_| Error::bad_database("Invalid PDU in database."))
|
||||||
|
})
|
||||||
|
.transpose()?
|
||||||
|
.map(|content: RoomMemberEventContent| content.membership)
|
||||||
|
== Some(MembershipState::Knock)
|
||||||
|
{
|
||||||
|
services().rooms.state_cache.update_membership(
|
||||||
|
&room,
|
||||||
|
&user,
|
||||||
|
MembershipState::Knock,
|
||||||
|
&user,
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
services().globals.bump_database_version(14)?;
|
||||||
|
|
||||||
|
warn!("Migration: 13 -> 14 finished");
|
||||||
|
}
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
services().globals.database_version().unwrap(),
|
services().globals.database_version().unwrap(),
|
||||||
latest_database_version
|
latest_database_version
|
||||||
|
|
|
@ -305,7 +305,7 @@ impl Service {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn user_can_invite(
|
pub fn user_can_invite(
|
||||||
&self,
|
&self,
|
||||||
room_id: &RoomId,
|
room_id: &RoomId,
|
||||||
sender: &UserId,
|
sender: &UserId,
|
||||||
|
|
|
@ -18,6 +18,9 @@ pub trait Data: Send + Sync {
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
|
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
|
||||||
|
|
||||||
|
/// Marks a user as knocking on a room
|
||||||
|
fn mark_as_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>;
|
||||||
|
|
||||||
fn update_joined_count(&self, room_id: &RoomId) -> Result<()>;
|
fn update_joined_count(&self, room_id: &RoomId) -> Result<()>;
|
||||||
|
|
||||||
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>;
|
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>>;
|
||||||
|
@ -106,4 +109,6 @@ pub trait Data: Send + Sync {
|
||||||
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
|
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
|
||||||
|
|
||||||
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
|
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
|
||||||
|
|
||||||
|
fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -181,6 +181,9 @@ impl Service {
|
||||||
MembershipState::Leave | MembershipState::Ban => {
|
MembershipState::Leave | MembershipState::Ban => {
|
||||||
self.db.mark_as_left(user_id, room_id)?;
|
self.db.mark_as_left(user_id, room_id)?;
|
||||||
}
|
}
|
||||||
|
MembershipState::Knock => {
|
||||||
|
self.db.mark_as_knocked(user_id, room_id)?;
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -350,4 +353,9 @@ impl Service {
|
||||||
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
self.db.is_left(user_id, room_id)
|
self.db.is_left(user_id, room_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(self))]
|
||||||
|
pub fn is_knocked(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
|
self.db.is_knocked(user_id, room_id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue