use chain_width 60

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-03-25 17:05:11 -04:00 committed by June
parent 9d6b070f35
commit 868976a149
98 changed files with 4836 additions and 1767 deletions

View file

@ -25,3 +25,4 @@ newline_style = "Unix"
use_field_init_shorthand = true use_field_init_shorthand = true
use_small_heuristics = "Off" use_small_heuristics = "Off"
use_try_shorthand = true use_try_shorthand = true
chain_width = 60

View file

@ -36,7 +36,11 @@ where
"?" "?"
}; };
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap()); parts.path_and_query = Some(
(old_path_and_query + symbol + "access_token=" + hs_token)
.parse()
.unwrap(),
);
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
let mut reqwest_request = let mut reqwest_request =
@ -64,7 +68,13 @@ where
} }
} }
let mut response = match services().globals.client.appservice.execute(reqwest_request).await { let mut response = match services()
.globals
.client
.appservice
.execute(reqwest_request)
.await
{
Ok(r) => r, Ok(r) => r,
Err(e) => { Err(e) => {
warn!( warn!(
@ -77,10 +87,14 @@ where
// reqwest::Response -> http::Response conversion // reqwest::Response -> http::Response conversion
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder().status(status).version(response.version()); let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap( mem::swap(
response.headers_mut(), response.headers_mut(),
http_response_builder.headers_mut().expect("http::response::Builder is usable"), http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
); );
let body = response.bytes().await.unwrap_or_else(|e| { let body = response.bytes().await.unwrap_or_else(|e| {
@ -99,7 +113,9 @@ where
} }
let response = T::IncomingResponse::try_from_http_response( let response = T::IncomingResponse::try_from_http_response(
http_response_builder.body(body).expect("reqwest body is valid http body"), http_response_builder
.body(body)
.expect("reqwest body is valid http body"),
); );
Some(response.map_err(|_| { Some(response.map_err(|_| {

View file

@ -47,7 +47,11 @@ pub async fn get_register_available_route(
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
if services().globals.forbidden_usernames().is_match(user_id.localpart()) { if services()
.globals
.forbidden_usernames()
.is_match(user_id.localpart())
{
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
} }
@ -129,7 +133,11 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) { if services()
.globals
.forbidden_usernames()
.is_match(proposed_user_id.localpart())
{
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
} }
@ -220,7 +228,10 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix())); displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
} }
services().users.set_displayname(&user_id, Some(displayname.clone())).await?; services()
.users
.set_displayname(&user_id, Some(displayname.clone()))
.await?;
// Initial account data // Initial account data
services().account_data.update( services().account_data.update(
@ -258,20 +269,26 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account // Create device for this account
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; services()
.users
.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
info!("New user \"{}\" registered on this server.", user_id); info!("New user \"{}\" registered on this server.", user_id);
// log in conduit admin channel if a non-guest user registered // log in conduit admin channel if a non-guest user registered
if !body.from_appservice && !is_guest { if !body.from_appservice && !is_guest {
services().admin.send_message(RoomMessageEventContent::notice_plain(format!( services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"New user \"{user_id}\" registered on this server." "New user \"{user_id}\" registered on this server."
))); )));
} }
// log in conduit admin channel if a guest registered // log in conduit admin channel if a guest registered
if !body.from_appservice && is_guest { if !body.from_appservice && is_guest {
services().admin.send_message(RoomMessageEventContent::notice_plain(format!( services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.", "Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
body.initial_device_display_name body.initial_device_display_name
))); )));
@ -281,8 +298,16 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
// users Note: the server user, @conduit:servername, is generated first // users Note: the server user, @conduit:servername, is generated first
if !is_guest { if !is_guest {
if let Some(admin_room) = service::admin::Service::get_admin_room()? { if let Some(admin_room) = service::admin::Service::get_admin_room()? {
if services().rooms.state_cache.room_joined_count(&admin_room)? == Some(1) { if services()
services().admin.make_user_admin(&user_id, displayname).await?; .rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1)
{
services()
.admin
.make_user_admin(&user_id, displayname)
.await?;
warn!("Granting {} admin privileges as the first user", user_id); warn!("Granting {} admin privileges as the first user", user_id);
} }
@ -291,7 +316,11 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
if !services().globals.config.auto_join_rooms.is_empty() { if !services().globals.config.auto_join_rooms.is_empty() {
for room in &services().globals.config.auto_join_rooms { for room in &services().globals.config.auto_join_rooms {
if !services().rooms.state_cache.server_in_room(services().globals.server_name(), room)? { if !services()
.rooms
.state_cache
.server_in_room(services().globals.server_name(), room)?
{
warn!("Skipping room {room} to automatically join as we have never joined before."); warn!("Skipping room {room} to automatically join as we have never joined before.");
continue; continue;
} }
@ -359,30 +388,43 @@ pub async fn change_password_route(body: Ruma<change_password::v3::Request>) ->
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; let (worked, uiaainfo) = services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services().users.set_password(sender_user, Some(&body.new_password))?; services()
.users
.set_password(sender_user, Some(&body.new_password))?;
if body.logout_devices { if body.logout_devices {
// Logout all devices except the current one // Logout all devices except the current one
for id in services().users.all_device_ids(sender_user).filter_map(Result::ok).filter(|id| id != sender_device) { for id in services()
.users
.all_device_ids(sender_user)
.filter_map(Result::ok)
.filter(|id| id != sender_device)
{
services().users.remove_device(sender_user, &id)?; services().users.remove_device(sender_user, &id)?;
} }
} }
info!("User {} changed their password.", sender_user); info!("User {} changed their password.", sender_user);
services().admin.send_message(RoomMessageEventContent::notice_plain(format!( services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} changed their password." "User {sender_user} changed their password."
))); )));
@ -431,14 +473,18 @@ pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<dea
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; let (worked, uiaainfo) = services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -451,7 +497,9 @@ pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<dea
services().users.deactivate_account(sender_user)?; services().users.deactivate_account(sender_user)?;
info!("User {} deactivated their account.", sender_user); info!("User {} deactivated their account.", sender_user);
services().admin.send_message(RoomMessageEventContent::notice_plain(format!( services()
.admin
.send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} deactivated their account." "User {sender_user} deactivated their account."
))); )));

View file

@ -21,15 +21,29 @@ pub async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
} }
if services().globals.forbidden_alias_names().is_match(body.room_alias.alias()) { if services()
.globals
.forbidden_alias_names()
.is_match(body.room_alias.alias())
{
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
} }
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() { if services()
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.is_some()
{
return Err(Error::Conflict("Alias already exists.")); return Err(Error::Conflict("Alias already exists."));
} }
if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() { if services()
.rooms
.alias
.set_alias(&body.room_alias, &body.room_id)
.is_err()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Invalid room alias. Alias must be in the form of '#localpart:server_name'", "Invalid room alias. Alias must be in the form of '#localpart:server_name'",
@ -50,11 +64,21 @@ pub async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
} }
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() { if services()
.rooms
.alias
.resolve_local_alias(&body.room_alias)?
.is_none()
{
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
} }
if services().rooms.alias.remove_alias(&body.room_alias).is_err() { if services()
.rooms
.alias
.remove_alias(&body.room_alias)
.is_err()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Invalid room alias. Alias must be in the form of '#localpart:server_name'", "Invalid room alias. Alias must be in the form of '#localpart:server_name'",
@ -90,13 +114,20 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get
let mut servers = response.servers; let mut servers = response.servers;
// find active servers in room state cache to suggest // find active servers in room state cache to suggest
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(Result::ok) { for extra_servers in services()
.rooms
.state_cache
.room_servers(&room_id)
.filter_map(Result::ok)
{
servers.push(extra_servers); servers.push(extra_servers);
} }
// insert our server as the very first choice if in list // insert our server as the very first choice if in list
if let Some(server_index) = if let Some(server_index) = servers
servers.clone().into_iter().position(|server| server == services().globals.server_name()) .clone()
.into_iter()
.position(|server| server == services().globals.server_name())
{ {
servers.remove(server_index); servers.remove(server_index);
servers.insert(0, services().globals.server_name().to_owned()); servers.insert(0, services().globals.server_name().to_owned());
@ -151,13 +182,20 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get
let mut servers: Vec<OwnedServerName> = Vec::new(); let mut servers: Vec<OwnedServerName> = Vec::new();
// find active servers in room state cache to suggest // find active servers in room state cache to suggest
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(Result::ok) { for extra_servers in services()
.rooms
.state_cache
.room_servers(&room_id)
.filter_map(Result::ok)
{
servers.push(extra_servers); servers.push(extra_servers);
} }
// insert our server as the very first choice if in list // insert our server as the very first choice if in list
if let Some(server_index) = if let Some(server_index) = servers
servers.clone().into_iter().position(|server| server == services().globals.server_name()) .clone()
.into_iter()
.position(|server| server == services().globals.server_name())
{ {
servers.remove(server_index); servers.remove(server_index);
servers.insert(0, services().globals.server_name().to_owned()); servers.insert(0, services().globals.server_name().to_owned());

View file

@ -17,7 +17,9 @@ pub async fn create_backup_version_route(
body: Ruma<create_backup_version::v3::Request>, body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> { ) -> Result<create_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = services().key_backups.create_backup(sender_user, &body.algorithm)?; let version = services()
.key_backups
.create_backup(sender_user, &body.algorithm)?;
Ok(create_backup_version::v3::Response { Ok(create_backup_version::v3::Response {
version, version,
@ -32,7 +34,9 @@ pub async fn update_backup_version_route(
body: Ruma<update_backup_version::v3::Request>, body: Ruma<update_backup_version::v3::Request>,
) -> Result<update_backup_version::v3::Response> { ) -> Result<update_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?; services()
.key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?;
Ok(update_backup_version::v3::Response {}) Ok(update_backup_version::v3::Response {})
} }
@ -70,8 +74,13 @@ pub async fn get_backup_info_route(body: Ruma<get_backup_info::v3::Request>) ->
Ok(get_backup_info::v3::Response { Ok(get_backup_info::v3::Response {
algorithm, algorithm,
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
version: body.version.clone(), version: body.version.clone(),
}) })
} }
@ -87,7 +96,9 @@ pub async fn delete_backup_version_route(
) -> Result<delete_backup_version::v3::Response> { ) -> Result<delete_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.delete_backup(sender_user, &body.version)?; services()
.key_backups
.delete_backup(sender_user, &body.version)?;
Ok(delete_backup_version::v3::Response {}) Ok(delete_backup_version::v3::Response {})
} }
@ -103,7 +114,12 @@ pub async fn delete_backup_version_route(
pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> { pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { if Some(&body.version)
!= services()
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the backup.",
@ -112,13 +128,20 @@ pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) ->
for (room_id, room) in &body.rooms { for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions { for (session_id, key_data) in &room.sessions {
services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?; services()
.key_backups
.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
} }
} }
Ok(add_backup_keys::v3::Response { Ok(add_backup_keys::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -135,7 +158,12 @@ pub async fn add_backup_keys_for_room_route(
) -> Result<add_backup_keys_for_room::v3::Response> { ) -> Result<add_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { if Some(&body.version)
!= services()
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the backup.",
@ -143,12 +171,19 @@ pub async fn add_backup_keys_for_room_route(
} }
for (session_id, key_data) in &body.sessions { for (session_id, key_data) in &body.sessions {
services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; services()
.key_backups
.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
} }
Ok(add_backup_keys_for_room::v3::Response { Ok(add_backup_keys_for_room::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -165,18 +200,30 @@ pub async fn add_backup_keys_for_session_route(
) -> Result<add_backup_keys_for_session::v3::Response> { ) -> Result<add_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { if Some(&body.version)
!= services()
.key_backups
.get_latest_backup_version(sender_user)?
.as_ref()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"You may only manipulate the most recently created version of the backup.", "You may only manipulate the most recently created version of the backup.",
)); ));
} }
services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; services()
.key_backups
.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
Ok(add_backup_keys_for_session::v3::Response { Ok(add_backup_keys_for_session::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -201,7 +248,9 @@ pub async fn get_backup_keys_for_room_route(
) -> Result<get_backup_keys_for_room::v3::Response> { ) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?; let sessions = services()
.key_backups
.get_room(sender_user, &body.version, &body.room_id)?;
Ok(get_backup_keys_for_room::v3::Response { Ok(get_backup_keys_for_room::v3::Response {
sessions, sessions,
@ -216,10 +265,13 @@ pub async fn get_backup_keys_for_session_route(
) -> Result<get_backup_keys_for_session::v3::Response> { ) -> Result<get_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let key_data = let key_data = services()
services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or( .key_backups
Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."), .get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
)?; .ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Backup key not found for this user's session.",
))?;
Ok(get_backup_keys_for_session::v3::Response { Ok(get_backup_keys_for_session::v3::Response {
key_data, key_data,
@ -234,11 +286,18 @@ pub async fn delete_backup_keys_route(
) -> Result<delete_backup_keys::v3::Response> { ) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.delete_all_keys(sender_user, &body.version)?; services()
.key_backups
.delete_all_keys(sender_user, &body.version)?;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -250,11 +309,18 @@ pub async fn delete_backup_keys_for_room_route(
) -> Result<delete_backup_keys_for_room::v3::Response> { ) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?; services()
.key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -266,10 +332,17 @@ pub async fn delete_backup_keys_for_session_route(
) -> Result<delete_backup_keys_for_session::v3::Response> { ) -> Result<delete_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; services()
.key_backups
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
Ok(delete_backup_keys_for_session::v3::Response { Ok(delete_backup_keys_for_session::v3::Response {
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services()
etag: services().key_backups.get_etag(sender_user, &body.version)?, .key_backups
.count_keys(sender_user, &body.version)? as u32)
.into(),
etag: services()
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }

View file

@ -42,7 +42,11 @@ pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<g
let room_id = base_event.room_id.clone(); let room_id = base_event.room_id.clone();
if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? { if !services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &room_id, &body.event_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this event.", "You don't have permission to view this event.",
@ -91,9 +95,14 @@ pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<g
} }
} }
let start_token = events_before.last().map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); let start_token = events_before
.last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let events_before: Vec<_> = events_before
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_after: Vec<_> = services() let events_after: Vec<_> = services()
.rooms .rooms
@ -122,25 +131,41 @@ pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<g
} }
} }
let shortstatehash = match services() let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
.rooms events_after
.state_accessor .last()
.pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))? .map_or(&*body.event_id, |(_, e)| &*e.event_id),
{ )? {
Some(s) => s, Some(s) => s,
None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"), None => services()
.rooms
.state
.get_room_shortstatehash(&room_id)?
.expect("All rooms have state"),
}; };
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let state_ids = services()
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await?;
let end_token = events_after.last().map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()); let end_token = events_after
.last()
.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let events_after: Vec<_> = events_after
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let mut state = Vec::new(); let mut state = Vec::new();
for (shortstatekey, id) in state_ids { for (shortstatekey, id) in state_ids {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; let (event_type, state_key) = services()
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {

View file

@ -53,7 +53,9 @@ pub async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Resu
device.display_name.clone_from(&body.display_name); device.display_name.clone_from(&body.display_name);
services().users.update_device_metadata(sender_user, &body.device_id, &device)?; services()
.users
.update_device_metadata(sender_user, &body.device_id, &device)?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
@ -84,20 +86,26 @@ pub async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Resu
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; let (worked, uiaainfo) = services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services().users.remove_device(sender_user, &body.device_id)?; services()
.users
.remove_device(sender_user, &body.device_id)?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -130,14 +138,18 @@ pub async fn delete_devices_route(body: Ruma<delete_devices::v3::Request>) -> Re
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; let (worked, uiaainfo) = services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));

View file

@ -34,7 +34,11 @@ use crate::{services, Error, Result, Ruma};
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
body: Ruma<get_public_rooms_filtered::v3::Request>, body: Ruma<get_public_rooms_filtered::v3::Request>,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if !services().globals.config.allow_public_room_directory_without_auth { if !services()
.globals
.config
.allow_public_room_directory_without_auth
{
let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
} }
@ -56,7 +60,11 @@ pub async fn get_public_rooms_filtered_route(
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
body: Ruma<get_public_rooms::v3::Request>, body: Ruma<get_public_rooms::v3::Request>,
) -> Result<get_public_rooms::v3::Response> { ) -> Result<get_public_rooms::v3::Response> {
if !services().globals.config.allow_public_room_directory_without_auth { if !services()
.globals
.config
.allow_public_room_directory_without_auth
{
let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
} }
@ -325,7 +333,11 @@ pub(crate) async fn get_public_rooms_filtered_helper(
let total_room_count_estimate = (all_rooms.len() as u32).into(); let total_room_count_estimate = (all_rooms.len() as u32).into();
let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect(); let chunk: Vec<_> = all_rooms
.into_iter()
.skip(num_since as usize)
.take(limit as usize)
.collect();
let prev_batch = if num_since == 0 { let prev_batch = if num_since == 0 {
None None

View file

@ -34,19 +34,29 @@ pub async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<u
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?; services()
.users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
} }
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
// TODO: merge this and the existing event? // TODO: merge this and the existing event?
// This check is needed to assure that signatures are kept // This check is needed to assure that signatures are kept
if services().users.get_device_keys(sender_user, sender_device)?.is_none() { if services()
services().users.add_device_keys(sender_user, sender_device, device_keys)?; .users
.get_device_keys(sender_user, sender_device)?
.is_none()
{
services()
.users
.add_device_keys(sender_user, sender_device, device_keys)?;
} }
} }
Ok(upload_keys::v3::Response { Ok(upload_keys::v3::Response {
one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?, one_time_key_counts: services()
.users
.count_one_time_keys(sender_user, sender_device)?,
}) })
} }
@ -104,14 +114,18 @@ pub async fn upload_signing_keys_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; let (worked, uiaainfo) = services()
.uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} }
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; services()
.uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
@ -161,7 +175,9 @@ pub async fn upload_signatures_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
.to_owned(), .to_owned(),
); );
services().users.sign_key(user_id, key_id, signature, sender_user)?; services()
.users
.sign_key(user_id, key_id, signature, sender_user)?;
} }
} }
} }
@ -187,20 +203,37 @@ pub async fn get_key_changes_route(body: Ruma<get_key_changes::v3::Request>) ->
.users .users
.keys_changed( .keys_changed(
sender_user.as_str(), sender_user.as_str(),
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, body.from
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?), .parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
) )
.filter_map(Result::ok), .filter_map(Result::ok),
); );
for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok) { for room_id in services()
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
{
device_list_updates.extend( device_list_updates.extend(
services() services()
.users .users
.keys_changed( .keys_changed(
room_id.as_ref(), room_id.as_ref(),
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, body.from
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?), .parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
Some(
body.to
.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
),
) )
.filter_map(Result::ok), .filter_map(Result::ok),
); );
@ -226,7 +259,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let user_id: &UserId = user_id; let user_id: &UserId = user_id;
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids)); get_over_federation
.entry(user_id.server_name())
.or_insert_with(Vec::new)
.push((user_id, device_ids));
continue; continue;
} }
@ -251,9 +287,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or( let metadata = services()
Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."), .users
)?; .get_device_metadata(user_id, device_id)?
.ok_or(Error::BadRequest(
ErrorKind::InvalidParam,
"Tried to get keys for nonexistent device.",
))?;
add_unsigned_device_display_name(&mut keys, metadata, include_display_names) add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
.map_err(|_| Error::bad_database("invalid device keys in database"))?; .map_err(|_| Error::bad_database("invalid device keys in database"))?;
@ -263,11 +303,16 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} }
} }
if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? { if let Some(master_key) = services()
.users
.get_master_key(sender_user, user_id, &allowed_signatures)?
{
master_keys.insert(user_id.to_owned(), master_key); master_keys.insert(user_id.to_owned(), master_key);
} }
if let Some(self_signing_key) = if let Some(self_signing_key) =
services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)? services()
.users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
{ {
self_signing_keys.insert(user_id.to_owned(), self_signing_key); self_signing_keys.insert(user_id.to_owned(), self_signing_key);
} }
@ -281,7 +326,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
let back_off = |id| async { let back_off = |id| async {
match services().globals.bad_query_ratelimiter.write().await.entry(id) { match services()
.globals
.bad_query_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -292,7 +343,13 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut futures: FuturesUnordered<_> = get_over_federation let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter() .into_iter()
.map(|(server, vec)| async move { .map(|(server, vec)| async move {
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().await.get(server) { if let Some((time, tries)) = services()
.globals
.bad_query_ratelimiter
.read()
.await
.get(server)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -336,7 +393,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?; let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) = if let Some(our_master_key) =
services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)? services()
.users
.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
{ {
let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?; let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures); master_key.signatures.extend(our_master_key.signatures);
@ -404,12 +463,18 @@ pub(crate) async fn claim_keys_helper(
for (user_id, map) in one_time_keys_input { for (user_id, map) in one_time_keys_input {
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map)); get_over_federation
.entry(user_id.server_name())
.or_insert_with(Vec::new)
.push((user_id, map));
} }
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? { if let Some(one_time_keys) = services()
.users
.take_one_time_key(user_id, device_id, key_algorithm)?
{
let mut c = BTreeMap::new(); let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1); c.insert(one_time_keys.0, one_time_keys.1);
container.insert(device_id.clone(), c); container.insert(device_id.clone(), c);

View file

@ -152,7 +152,10 @@ pub async fn create_content_route(body: Ruma<create_content::v3::Request>) -> Re
.create( .create(
Some(sender_user.clone()), Some(sender_user.clone()),
mxc.clone(), mxc.clone(),
body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(), body.filename
.as_ref()
.map(|filename| "inline; filename=".to_owned() + filename)
.as_deref(),
body.content_type.as_deref(), body.content_type.as_deref(),
&body.file, &body.file,
) )
@ -192,7 +195,10 @@ pub async fn create_content_v1_route(
.create( .create(
Some(sender_user.clone()), Some(sender_user.clone()),
mxc.clone(), mxc.clone(),
body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(), body.filename
.as_ref()
.map(|filename| "inline; filename=".to_owned() + filename)
.as_deref(),
body.content_type.as_deref(), body.content_type.as_deref(),
&body.file, &body.file,
) )
@ -213,7 +219,11 @@ pub async fn get_remote_content(
) -> Result<get_content::v3::Response, Error> { ) -> Result<get_content::v3::Response, Error> {
// we'll lie to the client and say the blocked server's media was not found and // we'll lie to the client and say the blocked server's media was not found and
// log. the client has no way of telling anyways so this is a security bonus. // log. the client has no way of telling anyways so this is a security bonus.
if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) { if services()
.globals
.prevent_media_downloads_from()
.contains(&server_name.to_owned())
{
info!( info!(
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
mxc mxc
@ -451,8 +461,12 @@ pub async fn get_content_thumbnail_route(
.media .media
.get_thumbnail( .get_thumbnail(
mxc.clone(), mxc.clone(),
body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, body.width
body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
body.height
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
) )
.await? .await?
{ {
@ -464,7 +478,11 @@ pub async fn get_content_thumbnail_route(
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
// we'll lie to the client and say the blocked server's media was not found and // we'll lie to the client and say the blocked server's media was not found and
// log. the client has no way of telling anyways so this is a security bonus. // log. the client has no way of telling anyways so this is a security bonus.
if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) { if services()
.globals
.prevent_media_downloads_from()
.contains(&body.server_name.clone())
{
info!( info!(
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
mxc mxc
@ -533,8 +551,12 @@ pub async fn get_content_thumbnail_v1_route(
.media .media
.get_thumbnail( .get_thumbnail(
mxc.clone(), mxc.clone(),
body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, body.width
body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
body.height
.try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
) )
.await? .await?
{ {
@ -547,7 +569,11 @@ pub async fn get_content_thumbnail_v1_route(
} else if &*body.server_name != services().globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
// we'll lie to the client and say the blocked server's media was not found and // we'll lie to the client and say the blocked server's media was not found and
// log. the client has no way of telling anyways so this is a security bonus. // log. the client has no way of telling anyways so this is a security bonus.
if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) { if services()
.globals
.prevent_media_downloads_from()
.contains(&body.server_name.clone())
{
info!( info!(
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
mxc mxc
@ -599,7 +625,10 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
utils::random_string(MXC_LENGTH) utils::random_string(MXC_LENGTH)
); );
services().media.create(None, mxc.clone(), None, None, &image).await?; services()
.media
.create(None, mxc.clone(), None, None, &image)
.await?;
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() { let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
Err(_) => (None, None), Err(_) => (None, None),
@ -749,14 +778,21 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
} }
} }
if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) { if !response
.remote_addr()
.map_or(false, |a| url_request_allowed(&a.ip()))
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Requesting from this address is forbidden", "Requesting from this address is forbidden",
)); ));
} }
let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) { let content_type = match response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|x| x.to_str().ok())
{
Some(ct) => ct, Some(ct) => ct,
None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")), None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
}; };
@ -777,7 +813,15 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
} }
// ensure that only one request is made per URL // ensure that only one request is made per URL
let mutex_request = Arc::clone(services().media.url_preview_mutex.write().await.entry(url.to_owned()).or_default()); let mutex_request = Arc::clone(
services()
.media
.url_preview_mutex
.write()
.await
.entry(url.to_owned())
.or_default(),
);
let _request_lock = mutex_request.lock().await; let _request_lock = mutex_request.lock().await;
match services().media.get_url_preview(url).await { match services().media.get_url_preview(url).await {
@ -795,7 +839,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
}, },
}; };
if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) { if ["http", "https"]
.iter()
.all(|&scheme| scheme != url.scheme().to_lowercase())
{
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url); debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
return false; return false;
} }
@ -826,12 +873,18 @@ fn url_preview_allowed(url_str: &str) -> bool {
return true; return true;
} }
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) { if allowlist_domain_contains
.iter()
.any(|domain_s| domain_s.contains(&host.clone()))
{
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host); debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
return true; return true;
} }
if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) { if allowlist_url_contains
.iter()
.any(|url_s| url.to_string().contains(&url_s.to_string()))
{
debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host); debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
return true; return true;
} }
@ -850,7 +903,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
return true; return true;
} }
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) { if allowlist_domain_contains
.iter()
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
{
debug!( debug!(
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", "Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
&root_domain &root_domain

View file

@ -242,8 +242,15 @@ pub async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Result<kick_
event.membership = MembershipState::Leave; event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason); event.reason.clone_from(&body.reason);
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
services() services()
@ -293,8 +300,14 @@ pub async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_use
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
.map(|event: RoomMemberEventContent| RoomMemberEventContent { .map(|event: RoomMemberEventContent| RoomMemberEventContent {
membership: MembershipState::Ban, membership: MembershipState::Ban,
displayname: services().users.displayname(&body.user_id).unwrap_or_default(), displayname: services()
avatar_url: services().users.avatar_url(&body.user_id).unwrap_or_default(), .users
.displayname(&body.user_id)
.unwrap_or_default(),
avatar_url: services()
.users
.avatar_url(&body.user_id)
.unwrap_or_default(),
blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(), blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(),
reason: body.reason.clone(), reason: body.reason.clone(),
..event ..event
@ -303,8 +316,15 @@ pub async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_use
}, },
)?; )?;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
services() services()
@ -349,8 +369,15 @@ pub async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unb
event.membership = MembershipState::Leave; event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason); event.reason.clone_from(&body.reason);
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
services() services()
@ -387,7 +414,10 @@ pub async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unb
pub async fn forget_room_route(body: Ruma<forget_room::v3::Request>) -> Result<forget_room::v3::Response> { pub async fn forget_room_route(body: Ruma<forget_room::v3::Request>) -> Result<forget_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().rooms.state_cache.forget(&body.room_id, sender_user)?; services()
.rooms
.state_cache
.forget(&body.room_id, sender_user)?;
Ok(forget_room::v3::Response::new()) Ok(forget_room::v3::Response::new())
} }
@ -399,7 +429,12 @@ pub async fn joined_rooms_route(body: Ruma<joined_rooms::v3::Request>) -> Result
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(joined_rooms::v3::Response { Ok(joined_rooms::v3::Response {
joined_rooms: services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok).collect(), joined_rooms: services()
.rooms
.state_cache
.rooms_joined(sender_user)
.filter_map(Result::ok)
.collect(),
}) })
} }
@ -414,7 +449,11 @@ pub async fn get_member_events_route(
) -> Result<get_member_events::v3::Response> { ) -> Result<get_member_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -443,7 +482,11 @@ pub async fn get_member_events_route(
pub async fn joined_members_route(body: Ruma<joined_members::v3::Request>) -> Result<joined_members::v3::Response> { pub async fn joined_members_route(body: Ruma<joined_members::v3::Request>) -> Result<joined_members::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -451,7 +494,12 @@ pub async fn joined_members_route(body: Ruma<joined_members::v3::Request>) -> Re
} }
let mut joined = BTreeMap::new(); let mut joined = BTreeMap::new();
for user_id in services().rooms.state_cache.room_members(&body.room_id).filter_map(Result::ok) { for user_id in services()
.rooms
.state_cache
.room_members(&body.room_id)
.filter_map(Result::ok)
{
let display_name = services().users.displayname(&user_id)?; let display_name = services().users.displayname(&user_id)?;
let avatar_url = services().users.avatar_url(&user_id)?; let avatar_url = services().users.avatar_url(&user_id)?;
@ -475,12 +523,23 @@ pub(crate) async fn join_room_by_id_helper(
) -> Result<join_room_by_id::v3::Response> { ) -> Result<join_room_by_id::v3::Response> {
let sender_user = sender_user.expect("user is authenticated"); let sender_user = sender_user.expect("user is authenticated");
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Ask a remote server if we are not participating in this room // Ask a remote server if we are not participating in this room
if !services().rooms.state_cache.server_in_room(services().globals.server_name(), room_id)? { if !services()
.rooms
.state_cache
.server_in_room(services().globals.server_name(), room_id)?
{
info!("Joining {room_id} over federation."); info!("Joining {room_id} over federation.");
let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?;
@ -488,7 +547,14 @@ pub(crate) async fn join_room_by_id_helper(
info!("make_join finished"); info!("make_join finished");
let room_version_id = match make_join_response.room_version { let room_version_id = match make_join_response.room_version {
Some(room_version) if services().globals.supported_room_versions().contains(&room_version) => room_version, Some(room_version)
if services()
.globals
.supported_room_versions()
.contains(&room_version) =>
{
room_version
},
_ => return Err(Error::BadServerResponse("Room version is not supported")), _ => return Err(Error::BadServerResponse("Room version is not supported")),
}; };
@ -497,7 +563,11 @@ pub(crate) async fn join_room_by_id_helper(
let join_authorized_via_users_server = join_event_stub let join_authorized_via_users_server = join_event_stub
.get("content") .get("content")
.map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str()) .map(|s| {
s.as_object()?
.get("join_authorised_via_users_server")?
.as_str()
})
.and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok());
// TODO: Is origin needed? // TODO: Is origin needed?
@ -508,7 +578,9 @@ pub(crate) async fn join_room_by_id_helper(
join_event_stub.insert( join_event_stub.insert(
"origin_server_ts".to_owned(), "origin_server_ts".to_owned(),
CanonicalJsonValue::Integer( CanonicalJsonValue::Integer(
utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), utils::millis_since_unix_epoch()
.try_into()
.expect("Timestamp is valid js_int value"),
), ),
); );
join_event_stub.insert( join_event_stub.insert(
@ -691,10 +763,15 @@ pub(crate) async fn join_room_by_id_helper(
Error::BadServerResponse("Invalid PDU in send_join response.") Error::BadServerResponse("Invalid PDU in send_join response.")
})?; })?;
services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; services()
.rooms
.outlier
.add_pdu_outlier(&event_id, &value)?;
if let Some(state_key) = &pdu.state_key { if let Some(state_key) = &pdu.state_key {
let shortstatekey = let shortstatekey = services()
services().rooms.short.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; .rooms
.short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
state.insert(shortstatekey, pdu.event_id.clone()); state.insert(shortstatekey, pdu.event_id.clone());
} }
} }
@ -711,7 +788,10 @@ pub(crate) async fn join_room_by_id_helper(
Err(_) => continue, Err(_) => continue,
}; };
services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; services()
.rooms
.outlier
.add_pdu_outlier(&event_id, &value)?;
} }
info!("Running send_join auth check"); info!("Running send_join auth check");
@ -725,8 +805,13 @@ pub(crate) async fn join_room_by_id_helper(
.rooms .rooms
.timeline .timeline
.get_pdu( .get_pdu(
state state.get(
.get(&services().rooms.short.get_or_create_shortstatekey(&k.to_string().into(), s).ok()?)?, &services()
.rooms
.short
.get_or_create_shortstatekey(&k.to_string().into(), s)
.ok()?,
)?,
) )
.ok()? .ok()?
}, },
@ -746,12 +831,21 @@ pub(crate) async fn join_room_by_id_helper(
Arc::new( Arc::new(
state state
.into_iter() .into_iter()
.map(|(k, id)| services().rooms.state_compressor.compress_state_event(k, &id)) .map(|(k, id)| {
services()
.rooms
.state_compressor
.compress_state_event(k, &id)
})
.collect::<Result<_>>()?, .collect::<Result<_>>()?,
), ),
)?; )?;
services().rooms.state.force_state(room_id, statehash_before_join, new, removed, &state_lock).await?; services()
.rooms
.state
.force_state(room_id, statehash_before_join, new, removed, &state_lock)
.await?;
info!("Updating joined counts for new room"); info!("Updating joined counts for new room");
services().rooms.state_cache.update_joined_count(room_id)?; services().rooms.state_cache.update_joined_count(room_id)?;
@ -776,14 +870,23 @@ pub(crate) async fn join_room_by_id_helper(
info!("Setting final room state for new room"); info!("Setting final room state for new room");
// We set the room state after inserting the pdu, so that we never have a moment // We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist // in time where events in the current room state do not exist
services().rooms.state.set_room_state(room_id, statehash_after_join, &state_lock)?; services()
.rooms
.state
.set_room_state(room_id, statehash_after_join, &state_lock)?;
} else { } else {
info!("We can join locally"); info!("We can join locally");
let join_rules_event = let join_rules_event =
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
let power_levels_event = let power_levels_event =
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?; services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
.as_ref() .as_ref()
@ -821,7 +924,12 @@ pub(crate) async fn join_room_by_id_helper(
let authorized_user = restriction_rooms let authorized_user = restriction_rooms
.iter() .iter()
.find_map(|restriction_room_id| { .find_map(|restriction_room_id| {
if !services().rooms.state_cache.is_joined(sender_user, restriction_room_id).ok()? { if !services()
.rooms
.state_cache
.is_joined(sender_user, restriction_room_id)
.ok()?
{
return None; return None;
} }
let authorized_user = power_levels_event_content let authorized_user = power_levels_event_content
@ -888,7 +996,10 @@ pub(crate) async fn join_room_by_id_helper(
}; };
if !restriction_rooms.is_empty() if !restriction_rooms.is_empty()
&& servers.iter().filter(|s| *s != services().globals.server_name()).count() > 0 && servers
.iter()
.filter(|s| *s != services().globals.server_name())
.count() > 0
{ {
info!( info!(
"We couldn't do the join locally, maybe federation can help to satisfy the restricted join \ "We couldn't do the join locally, maybe federation can help to satisfy the restricted join \
@ -897,7 +1008,12 @@ pub(crate) async fn join_room_by_id_helper(
let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?;
let room_version_id = match make_join_response.room_version { let room_version_id = match make_join_response.room_version {
Some(room_version_id) if services().globals.supported_room_versions().contains(&room_version_id) => { Some(room_version_id)
if services()
.globals
.supported_room_versions()
.contains(&room_version_id) =>
{
room_version_id room_version_id
}, },
_ => return Err(Error::BadServerResponse("Room version is not supported")), _ => return Err(Error::BadServerResponse("Room version is not supported")),
@ -906,7 +1022,11 @@ pub(crate) async fn join_room_by_id_helper(
.map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?;
let join_authorized_via_users_server = join_event_stub let join_authorized_via_users_server = join_event_stub
.get("content") .get("content")
.map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str()) .map(|s| {
s.as_object()?
.get("join_authorised_via_users_server")?
.as_str()
})
.and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok());
// TODO: Is origin needed? // TODO: Is origin needed?
join_event_stub.insert( join_event_stub.insert(
@ -916,7 +1036,9 @@ pub(crate) async fn join_room_by_id_helper(
join_event_stub.insert( join_event_stub.insert(
"origin_server_ts".to_owned(), "origin_server_ts".to_owned(),
CanonicalJsonValue::Integer( CanonicalJsonValue::Integer(
utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), utils::millis_since_unix_epoch()
.try_into()
.expect("Timestamp is valid js_int value"),
), ),
); );
join_event_stub.insert( join_event_stub.insert(
@ -1002,7 +1124,11 @@ pub(crate) async fn join_room_by_id_helper(
drop(state_lock); drop(state_lock);
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
services().rooms.event_handler.fetch_required_signing_keys([&signed_value], &pub_key_map).await?; services()
.rooms
.event_handler
.fetch_required_signing_keys([&signed_value], &pub_key_map)
.await?;
services() services()
.rooms .rooms
.event_handler .event_handler
@ -1065,7 +1191,13 @@ async fn validate_and_add_event_id(
.expect("ruma's reference hashes are valid event ids"); .expect("ruma's reference hashes are valid event ids");
let back_off = |id| async { let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) { match services()
.globals
.bad_event_ratelimiter
.write()
.await
.entry(id)
{
Entry::Vacant(e) => { Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -1073,7 +1205,13 @@ async fn validate_and_add_event_id(
} }
}; };
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&event_id) { if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&event_id)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1110,8 +1248,15 @@ pub(crate) async fn invite_helper(
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
let (pdu, pdu_json, invite_room_state) = { let (pdu, pdu_json, invite_room_state) = {
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let content = to_raw_value(&RoomMemberEventContent { let content = to_raw_value(&RoomMemberEventContent {
@ -1196,7 +1341,11 @@ pub(crate) async fn invite_helper(
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?;
services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pdu_id: Vec<u8> = services() let pdu_id: Vec<u8> = services()
.rooms .rooms
@ -1221,15 +1370,26 @@ pub(crate) async fn invite_helper(
return Ok(()); return Ok(());
} }
if !services().rooms.state_cache.is_joined(sender_user, room_id)? { if !services()
.rooms
.state_cache
.is_joined(sender_user, room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
services() services()
@ -1270,7 +1430,13 @@ pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> {
.rooms .rooms
.state_cache .state_cache
.rooms_joined(user_id) .rooms_joined(user_id)
.chain(services().rooms.state_cache.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) .chain(
services()
.rooms
.state_cache
.rooms_invited(user_id)
.map(|t| t.map(|(r, _)| r)),
)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for room_id in all_rooms { for room_id in all_rooms {
@ -1313,12 +1479,22 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
) )
.await?; .await?;
} else { } else {
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let member_event = let member_event =
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?;
// Fix for broken rooms // Fix for broken rooms
let member_event = match member_event { let member_event = match member_event {
@ -1411,7 +1587,14 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
let (make_leave_response, remote_server) = make_leave_response_and_server?; let (make_leave_response, remote_server) = make_leave_response_and_server?;
let room_version_id = match make_leave_response.room_version { let room_version_id = match make_leave_response.room_version {
Some(version) if services().globals.supported_room_versions().contains(&version) => version, Some(version)
if services()
.globals
.supported_room_versions()
.contains(&version) =>
{
version
},
_ => return Err(Error::BadServerResponse("Room version is not supported")), _ => return Err(Error::BadServerResponse("Room version is not supported")),
}; };
@ -1426,7 +1609,9 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
leave_event_stub.insert( leave_event_stub.insert(
"origin_server_ts".to_owned(), "origin_server_ts".to_owned(),
CanonicalJsonValue::Integer( CanonicalJsonValue::Integer(
utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), utils::millis_since_unix_epoch()
.try_into()
.expect("Timestamp is valid js_int value"),
), ),
); );

View file

@ -32,8 +32,15 @@ pub async fn send_message_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Forbid m.room.encrypted if encryption is disabled // Forbid m.room.encrypted if encryption is disabled
@ -90,7 +97,10 @@ pub async fn send_message_event_route(
}; };
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? { if let Some(response) = services()
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)?
{
// The client might have sent a txnid of the /sendToDevice endpoint // The client might have sent a txnid of the /sendToDevice endpoint
// This txnid has no response associated with it // This txnid has no response associated with it
if response.is_empty() { if response.is_empty() {
@ -130,7 +140,9 @@ pub async fn send_message_event_route(
) )
.await?; .await?;
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; services()
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
drop(state_lock); drop(state_lock);
@ -158,9 +170,16 @@ pub async fn get_message_events_route(
}, },
}; };
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from).await?; services()
.rooms
.lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
.await?;
let limit = u64::from(body.limit).min(100) as usize; let limit = u64::from(body.limit).min(100) as usize;
@ -208,14 +227,21 @@ pub async fn get_message_events_route(
next_token = events_after.last().map(|(count, _)| count).copied(); next_token = events_after.last().map(|(count, _)| count).copied();
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let events_after: Vec<_> = events_after
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
resp.start = from.stringify(); resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify()); resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_after; resp.chunk = events_after;
}, },
ruma::api::Direction::Backward => { ruma::api::Direction::Backward => {
services().rooms.timeline.backfill_if_required(&body.room_id, from).await?; services()
.rooms
.timeline
.backfill_if_required(&body.room_id, from)
.await?;
let events_before: Vec<_> = services() let events_before: Vec<_> = services()
.rooms .rooms
.timeline .timeline
@ -252,7 +278,10 @@ pub async fn get_message_events_route(
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); let events_before: Vec<_> = events_before
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
resp.start = from.stringify(); resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify()); resp.end = next_token.map(|count| count.stringify());

View file

@ -46,10 +46,19 @@ pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result
let mut presence_event = None; let mut presence_event = None;
for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { for room_id in services()
.rooms
.user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{
let room_id = room_id?; let room_id = room_id?;
if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? { if let Some(presence) = services()
.rooms
.edus
.presence
.get_presence(&room_id, sender_user)?
{
presence_event = Some(presence); presence_event = Some(presence);
break; break;
} }
@ -60,7 +69,10 @@ pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result
// TODO: Should ruma just use the presenceeventcontent type here? // TODO: Should ruma just use the presenceeventcontent type here?
status_msg: presence.content.status_msg, status_msg: presence.content.status_msg,
currently_active: presence.content.currently_active, currently_active: presence.content.currently_active,
last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())), last_active_ago: presence
.content
.last_active_ago
.map(|millis| Duration::from_millis(millis.into())),
presence: presence.content.presence, presence: presence.content.presence,
}) })
} else { } else {

View file

@ -25,7 +25,10 @@ pub async fn set_displayname_route(
) -> Result<set_display_name::v3::Response> { ) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().users.set_displayname(sender_user, body.displayname.clone()).await?; services()
.users
.set_displayname(sender_user, body.displayname.clone())
.await?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_rooms_joined: Vec<_> = services() let all_rooms_joined: Vec<_> = services()
@ -64,16 +67,31 @@ pub async fn set_displayname_route(
.collect(); .collect();
for (pdu_builder, room_id) in all_rooms_joined { for (pdu_builder, room_id) in all_rooms_joined {
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
_ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; _ = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.await;
} }
if services().globals.allow_local_presence() { if services().globals.allow_local_presence() {
// Presence update // Presence update
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?; services()
.rooms
.edus
.presence
.ping_presence(sender_user, PresenceState::Online)?;
} }
Ok(set_display_name::v3::Response {}) Ok(set_display_name::v3::Response {})
@ -105,9 +123,18 @@ pub async fn get_displayname_route(
services().users.create(&body.user_id, None)?; services().users.create(&body.user_id, None)?;
} }
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; services()
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; .users
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; .set_displayname(&body.user_id, response.displayname.clone())
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_display_name::v3::Response { return Ok(get_display_name::v3::Response {
displayname: response.displayname, displayname: response.displayname,
@ -134,9 +161,15 @@ pub async fn get_displayname_route(
pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> { pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?; services()
.users
.set_avatar_url(sender_user, body.avatar_url.clone())
.await?;
services().users.set_blurhash(sender_user, body.blurhash.clone()).await?; services()
.users
.set_blurhash(sender_user, body.blurhash.clone())
.await?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_joined_rooms: Vec<_> = services() let all_joined_rooms: Vec<_> = services()
@ -175,16 +208,31 @@ pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Re
.collect(); .collect();
for (pdu_builder, room_id) in all_joined_rooms { for (pdu_builder, room_id) in all_joined_rooms {
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
_ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; _ = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.await;
} }
if services().globals.allow_local_presence() { if services().globals.allow_local_presence() {
// Presence update // Presence update
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?; services()
.rooms
.edus
.presence
.ping_presence(sender_user, PresenceState::Online)?;
} }
Ok(set_avatar_url::v3::Response {}) Ok(set_avatar_url::v3::Response {})
@ -214,9 +262,18 @@ pub async fn get_avatar_url_route(body: Ruma<get_avatar_url::v3::Request>) -> Re
services().users.create(&body.user_id, None)?; services().users.create(&body.user_id, None)?;
} }
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; services()
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; .users
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; .set_displayname(&body.user_id, response.displayname.clone())
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_avatar_url::v3::Response { return Ok(get_avatar_url::v3::Response {
avatar_url: response.avatar_url, avatar_url: response.avatar_url,
@ -261,9 +318,18 @@ pub async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<g
services().users.create(&body.user_id, None)?; services().users.create(&body.user_id, None)?;
} }
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; services()
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; .users
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; .set_displayname(&body.user_id, response.displayname.clone())
.await?;
services()
.users
.set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?;
services()
.users
.set_blurhash(&body.user_id, response.blurhash.clone())
.await?;
return Ok(get_profile::v3::Response { return Ok(get_profile::v3::Response {
displayname: response.displayname, displayname: response.displayname,

View file

@ -49,7 +49,10 @@ pub async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result
.map_err(|_| Error::bad_database("Invalid account data event in db."))? .map_err(|_| Error::bad_database("Invalid account data event in db."))?
.content; .content;
let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into); let rule = account_data
.global
.get(body.kind.clone(), &body.rule_id)
.map(Into::into);
if let Some(rule) = rule { if let Some(rule) = rule {
Ok(get_pushrule::v3::Response { Ok(get_pushrule::v3::Response {
@ -83,7 +86,10 @@ pub async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if let Err(error) = if let Err(error) =
account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref()) account_data
.content
.global
.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref())
{ {
let err = match error { let err = match error {
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
@ -178,7 +184,12 @@ pub async fn set_pushrule_actions_route(
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() { if account_data
.content
.global
.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone())
.is_err()
{
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
@ -249,7 +260,12 @@ pub async fn set_pushrule_enabled_route(
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() { if account_data
.content
.global
.set_enabled(body.kind.clone(), &body.rule_id, body.enabled)
.is_err()
{
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
@ -284,7 +300,11 @@ pub async fn delete_pushrule_route(body: Ruma<delete_pushrule::v3::Request>) ->
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get()) let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))?; .map_err(|_| Error::bad_database("Invalid account data event in db."))?;
if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) { if let Err(error) = account_data
.content
.global
.remove(body.kind.clone(), &body.rule_id)
{
let err = match error { let err = match error {
RemovePushRuleError::ServerDefault => { RemovePushRuleError::ServerDefault => {
Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.") Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.")
@ -325,7 +345,9 @@ pub async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<g
pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> { pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().pusher.set_pusher(sender_user, body.action.clone())?; services()
.pusher
.set_pusher(sender_user, body.action.clone())?;
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }

View file

@ -36,7 +36,10 @@ pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) ->
} }
if body.private_read_receipt.is_some() || body.read_receipt.is_some() { if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; services()
.rooms
.user
.reset_notification_counts(sender_user, &body.room_id)?;
} }
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
@ -54,7 +57,11 @@ pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) ->
}, },
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; services()
.rooms
.edus
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
} }
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
@ -98,7 +105,10 @@ pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Re
&body.receipt_type, &body.receipt_type,
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
) { ) {
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; services()
.rooms
.user
.reset_notification_counts(sender_user, &body.room_id)?;
} }
match body.receipt_type { match body.receipt_type {
@ -156,7 +166,11 @@ pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Re
}, },
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; services()
.rooms
.edus
.read_receipt
.private_read_set(&body.room_id, sender_user, count)?;
}, },
_ => return Err(Error::bad_database("Unsupported receipt type")), _ => return Err(Error::bad_database("Unsupported receipt type")),
} }

View file

@ -17,8 +17,15 @@ pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body; let body = body.body;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let event_id = services() let event_id = services()

View file

@ -19,12 +19,22 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
}, },
}; };
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services().rooms.pdu_metadata.paginate_relations_with_filter( let res = services()
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
@ -57,12 +67,22 @@ pub async fn get_relating_events_with_rel_type_route(
}, },
}; };
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services().rooms.pdu_metadata.paginate_relations_with_filter( let res = services()
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
@ -95,19 +115,20 @@ pub async fn get_relating_events_route(
}, },
}; };
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
services().rooms.pdu_metadata.paginate_relations_with_filter( services()
sender_user, .rooms
&body.room_id, .pdu_metadata
&body.event_id, .paginate_relations_with_filter(sender_user, &body.room_id, &body.event_id, &None, &None, from, to, limit)
&None,
&None,
from,
to,
limit,
)
} }

View file

@ -71,7 +71,9 @@ pub async fn report_event_route(body: Ruma<report_content::v3::Request>) -> Resu
// send admin room message that we received the report with an @room ping for // send admin room message that we received the report with an @room ping for
// urgency // urgency
services().admin.send_message(message::RoomMessageEventContent::text_html( services()
.admin
.send_message(message::RoomMessageEventContent::text_html(
format!( format!(
"@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \
Reason: {}", Reason: {}",

View file

@ -80,7 +80,11 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
} }
// apply forbidden room alias checks to custom room IDs too // apply forbidden room alias checks to custom room IDs too
if services().globals.forbidden_alias_names().is_match(&custom_room_id_s) { if services()
.globals
.forbidden_alias_names()
.is_match(&custom_room_id_s)
{
return Err(Error::BadRequest(ErrorKind::Unknown, "Custom room ID is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Custom room ID is forbidden."));
} }
@ -110,17 +114,27 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
services().rooms.short.get_or_create_shortroomid(&room_id)?; services().rooms.short.get_or_create_shortroomid(&room_id)?;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let alias: Option<OwnedRoomAliasId> = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| { let alias: Option<OwnedRoomAliasId> = body
.room_alias_name
.as_ref()
.map_or(Ok(None), |localpart| {
// Basic checks on the room alias validity // Basic checks on the room alias validity
if localpart.contains(':') { if localpart.contains(':') {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the \ "Room alias contained `:` which is not allowed. Please note that this expects a localpart, not \
full room alias.", the full room alias.",
)); ));
} else if localpart.contains(char::is_whitespace) { } else if localpart.contains(char::is_whitespace) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -144,7 +158,11 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
} }
// check if room alias is forbidden // check if room alias is forbidden
if services().globals.forbidden_alias_names().is_match(localpart) { if services()
.globals
.forbidden_alias_names()
.is_match(localpart)
{
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden."));
} }
@ -154,7 +172,12 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.")
})?; })?;
if services().rooms.alias.resolve_local_alias(&alias)?.is_some() { if services()
.rooms
.alias
.resolve_local_alias(&alias)?
.is_some()
{
Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")) Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists."))
} else { } else {
Ok(Some(alias)) Ok(Some(alias))
@ -163,7 +186,11 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
let room_version = match body.room_version.clone() { let room_version = match body.room_version.clone() {
Some(room_version) => { Some(room_version) => {
if services().globals.supported_room_versions().contains(&room_version) { if services()
.globals
.supported_room_versions()
.contains(&room_version)
{
room_version room_version
} else { } else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -177,7 +204,9 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
let content = match &body.creation_content { let content = match &body.creation_content {
Some(content) => { Some(content) => {
let mut content = content.deserialize_as::<CanonicalJsonObject>().map_err(|e| { let mut content = content
.deserialize_as::<CanonicalJsonObject>()
.map_err(|e| {
error!("Failed to deserialise content as canonical JSON: {}", e); error!("Failed to deserialise content as canonical JSON: {}", e);
Error::bad_database("Failed to deserialise content as canonical JSON.") Error::bad_database("Failed to deserialise content as canonical JSON.")
})?; })?;
@ -256,8 +285,11 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
}; };
// Validate creation content // Validate creation content
let de_result = let de_result = serde_json::from_str::<CanonicalJsonObject>(
serde_json::from_str::<CanonicalJsonObject>(to_raw_value(&content).expect("Invalid creation content").get()); to_raw_value(&content)
.expect("Invalid creation content")
.get(),
);
if de_result.is_err() { if de_result.is_err() {
return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")); return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"));
@ -463,7 +495,11 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
continue; continue;
} }
services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await?; services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
.await?;
} }
// 7. Events implied by name and topic // 7. Events implied by name and topic
@ -538,12 +574,20 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
pub async fn get_room_event_route(body: Ruma<get_room_event::v3::Request>) -> Result<get_room_event::v3::Response> { pub async fn get_room_event_route(body: Ruma<get_room_event::v3::Request>) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(|| { let event = services()
.rooms
.timeline
.get_pdu(&body.event_id)?
.ok_or_else(|| {
warn!("Event not found, event ID: {:?}", &body.event_id); warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.") Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?; })?;
if !services().rooms.state_accessor.user_can_see_event(sender_user, &event.room_id, &body.event_id)? { if !services()
.rooms
.state_accessor
.user_can_see_event(sender_user, &event.room_id, &body.event_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this event.", "You don't have permission to view this event.",
@ -567,7 +611,11 @@ pub async fn get_room_event_route(body: Ruma<get_room_event::v3::Request>) -> Re
pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> { pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -575,7 +623,12 @@ pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<
} }
Ok(aliases::v3::Response { Ok(aliases::v3::Response {
aliases: services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(Result::ok).collect(), aliases: services()
.rooms
.alias
.local_aliases_for_room(&body.room_id)
.filter_map(Result::ok)
.collect(),
}) })
} }
@ -592,7 +645,11 @@ pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<
pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> { pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().globals.supported_room_versions().contains(&body.new_version) { if !services()
.globals
.supported_room_versions()
.contains(&body.new_version)
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnsupportedRoomVersion, ErrorKind::UnsupportedRoomVersion,
"This server does not support that room version.", "This server does not support that room version.",
@ -601,10 +658,20 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
// Create a replacement room // Create a replacement room
let replacement_room = RoomId::new(services().globals.server_name()); let replacement_room = RoomId::new(services().globals.server_name());
services().rooms.short.get_or_create_shortroomid(&replacement_room)?; services()
.rooms
.short
.get_or_create_shortroomid(&replacement_room)?;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Send a m.room.tombstone event to the old room to indicate that it is not // Send a m.room.tombstone event to the old room to indicate that it is not
@ -633,8 +700,15 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
// Change lock to replacement room // Change lock to replacement room
drop(state_lock); drop(state_lock);
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(replacement_room.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(replacement_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Get the old room creation event // Get the old room creation event
@ -704,7 +778,9 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
// Validate creation event content // Validate creation event content
let de_result = serde_json::from_str::<CanonicalJsonObject>( let de_result = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&create_event_content).expect("Error forming creation event").get(), to_raw_value(&create_event_content)
.expect("Error forming creation event")
.get(),
); );
if de_result.is_err() { if de_result.is_err() {
@ -771,7 +847,11 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
// Replicate transferable state events to the new room // Replicate transferable state events to the new room
for event_type in transferable_state_events { for event_type in transferable_state_events {
let event_content = match services().rooms.state_accessor.room_state_get(&body.room_id, &event_type, "")? { let event_content = match services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &event_type, "")?
{
Some(v) => v.content.clone(), Some(v) => v.content.clone(),
None => continue, // Skipping missing events. None => continue, // Skipping missing events.
}; };
@ -795,8 +875,16 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
} }
// Moves any local aliases to the new room // Moves any local aliases to the new room
for alias in services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(Result::ok) { for alias in services()
services().rooms.alias.set_alias(&alias, &replacement_room)?; .rooms
.alias
.local_aliases_for_room(&body.room_id)
.filter_map(Result::ok)
{
services()
.rooms
.alias
.set_alias(&alias, &replacement_room)?;
} }
// Get the old room power levels // Get the old room power levels

View file

@ -29,10 +29,14 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
let filter = &search_criteria.filter; let filter = &search_criteria.filter;
let include_state = &search_criteria.include_state; let include_state = &search_criteria.include_state;
let room_ids = filter let room_ids = filter.rooms.clone().unwrap_or_else(|| {
services()
.rooms .rooms
.clone() .state_cache
.unwrap_or_else(|| services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok).collect()); .rooms_joined(sender_user)
.filter_map(Result::ok)
.collect()
});
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = filter.limit.map_or(10, u64::from).min(100) as usize; let limit = filter.limit.map_or(10, u64::from).min(100) as usize;
@ -41,7 +45,11 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
if include_state.is_some_and(|include_state| include_state) { if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids { for room_id in &room_ids {
if !services().rooms.state_cache.is_joined(sender_user, room_id)? { if !services()
.rooms
.state_cache
.is_joined(sender_user, room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -49,7 +57,11 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
} }
// check if sender_user can see state events // check if sender_user can see state events
if services().rooms.state_accessor.user_can_see_state_events(sender_user, room_id)? { if services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, room_id)?
{
let room_state = services() let room_state = services()
.rooms .rooms
.state_accessor .state_accessor
@ -74,14 +86,22 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in &room_ids { for room_id in &room_ids {
if !services().rooms.state_cache.is_joined(sender_user, room_id)? { if !services()
.rooms
.state_cache
.is_joined(sender_user, room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
if let Some(search) = services().rooms.search.search_pdus(room_id, &search_criteria.search_term)? { if let Some(search) = services()
.rooms
.search
.search_pdus(room_id, &search_criteria.search_term)?
{
searches.push(search.0.peekable()); searches.push(search.0.peekable());
} }
} }

View file

@ -101,7 +101,11 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
return Err(Error::BadServerResponse("could not hash")); return Err(Error::BadServerResponse("could not hash"));
}; };
let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok(); let hash_matches = services()
.globals
.argon
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok();
if !hash_matches { if !hash_matches {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."));
@ -174,25 +178,37 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
}; };
// Generate new device id if the user didn't specify one // Generate new device id if the user didn't specify one
let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); let device_id = body
.device_id
.clone()
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
// Generate a new token for the device // Generate a new token for the device
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| { let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id)) services()
.users
.all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
}); });
if device_exists { if device_exists {
services().users.set_token(&user_id, &device_id, &token)?; services().users.set_token(&user_id, &device_id, &token)?;
} else { } else {
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; services()
.users
.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
} }
// send client well-known if specified so the client knows to reconfigure itself // send client well-known if specified so the client knows to reconfigure itself
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new( let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
services().globals.well_known_client().to_owned().unwrap_or_default(), services()
.globals
.well_known_client()
.to_owned()
.unwrap_or_default(),
)); ));
info!("{} logged in", user_id); info!("{} logged in", user_id);

View file

@ -14,11 +14,20 @@ use crate::{service::rooms::spaces::PagnationToken, services, Error, Result, Rum
pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = body.limit.unwrap_or_else(|| UInt::from(10_u32)).min(UInt::from(100_u32)); let limit = body
.limit
.unwrap_or_else(|| UInt::from(10_u32))
.min(UInt::from(100_u32));
let max_depth = body.max_depth.unwrap_or_else(|| UInt::from(3_u32)).min(UInt::from(10_u32)); let max_depth = body
.max_depth
.unwrap_or_else(|| UInt::from(3_u32))
.min(UInt::from(10_u32));
let key = body.from.as_ref().and_then(|s| PagnationToken::from_str(s).ok()); let key = body
.from
.as_ref()
.and_then(|s| PagnationToken::from_str(s).ok());
// Should prevent unexpeded behaviour in (bad) clients // Should prevent unexpeded behaviour in (bad) clients
if let Some(ref token) = key { if let Some(ref token) = key {

View file

@ -86,7 +86,11 @@ pub async fn get_state_events_route(
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view the room state.", "You don't have permission to view the room state.",
@ -118,21 +122,30 @@ pub async fn get_state_events_for_key_route(
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view the room state.", "You don't have permission to view the room state.",
)); ));
} }
let event = let event = services()
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else( .rooms
|| { .state_accessor
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
.ok_or_else(|| {
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
Error::BadRequest(ErrorKind::NotFound, "State event not found.") Error::BadRequest(ErrorKind::NotFound, "State event not found.")
}, })?;
)?; if body
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { .format
.as_ref()
.is_some_and(|f| f.to_lowercase().eq("event"))
{
Ok(get_state_events_for_key::v3::Response { Ok(get_state_events_for_key::v3::Response {
content: None, content: None,
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
@ -164,20 +177,31 @@ pub async fn get_state_events_for_empty_key_route(
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { if !services()
.rooms
.state_accessor
.user_can_see_state_events(sender_user, &body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view the room state.", "You don't have permission to view the room state.",
)); ));
} }
let event = let event = services()
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| { .rooms
.state_accessor
.room_state_get(&body.room_id, &body.event_type, "")?
.ok_or_else(|| {
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
Error::BadRequest(ErrorKind::NotFound, "State event not found.") Error::BadRequest(ErrorKind::NotFound, "State event not found.")
})?; })?;
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { if body
.format
.as_ref()
.is_some_and(|f| f.to_lowercase().eq("event"))
{
Ok(get_state_events_for_key::v3::Response { Ok(get_state_events_for_key::v3::Response {
content: None, content: None,
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
@ -229,8 +253,15 @@ async fn send_state_event_for_key_helper(
} }
} }
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let event_id = services() let event_id = services()

View file

@ -81,8 +81,13 @@ pub async fn sync_events_route(
let sender_device = body.sender_device.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated");
let body = body.body; let body = body.body;
let mut rx = let mut rx = match services()
match services().globals.sync_receivers.write().await.entry((sender_user.clone(), sender_device.clone())) { .globals
.sync_receivers
.write()
.await
.entry((sender_user.clone(), sender_device.clone()))
{
Entry::Vacant(v) => { Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(None); let (tx, rx) = tokio::sync::watch::channel(None);
@ -116,7 +121,11 @@ pub async fn sync_events_route(
} }
} }
let result = match rx.borrow().as_ref().expect("When sync channel changes it's always set to some") { let result = match rx
.borrow()
.as_ref()
.expect("When sync channel changes it's always set to some")
{
Ok(response) => Ok(response.clone()), Ok(response) => Ok(response.clone()),
Err(error) => Err(error.to_response()), Err(error) => Err(error.to_response()),
}; };
@ -134,7 +143,13 @@ async fn sync_helper_wrapper(
if let Ok((_, caching_allowed)) = r { if let Ok((_, caching_allowed)) = r {
if !caching_allowed { if !caching_allowed {
match services().globals.sync_receivers.write().await.entry((sender_user, sender_device)) { match services()
.globals
.sync_receivers
.write()
.await
.entry((sender_user, sender_device))
{
Entry::Occupied(o) => { Entry::Occupied(o) => {
// Only remove if the device didn't start a different /sync already // Only remove if the device didn't start a different /sync already
if o.get().0 == since { if o.get().0 == since {
@ -157,7 +172,11 @@ async fn sync_helper(
) -> Result<(sync_events::v3::Response, bool), Error> { ) -> Result<(sync_events::v3::Response, bool), Error> {
// Presence update // Presence update
if services().globals.allow_local_presence() { if services().globals.allow_local_presence() {
services().rooms.edus.presence.ping_presence(&sender_user, body.set_presence)?; services()
.rooms
.edus
.presence
.ping_presence(&sender_user, body.set_presence)?;
} }
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
@ -171,7 +190,10 @@ async fn sync_helper(
let filter = match body.filter { let filter = match body.filter {
None => FilterDefinition::default(), None => FilterDefinition::default(),
Some(Filter::FilterDefinition(filter)) => filter, Some(Filter::FilterDefinition(filter)) => filter,
Some(Filter::FilterId(filter_id)) => services().users.get_filter(&sender_user, &filter_id)?.unwrap_or_default(), Some(Filter::FilterId(filter_id)) => services()
.users
.get_filter(&sender_user, &filter_id)?
.unwrap_or_default(),
}; };
let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options {
@ -184,7 +206,11 @@ async fn sync_helper(
let full_state = body.full_state; let full_state = body.full_state;
let mut joined_rooms = BTreeMap::new(); let mut joined_rooms = BTreeMap::new();
let since = body.since.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); let since = body
.since
.as_ref()
.and_then(|string| string.parse().ok())
.unwrap_or(0);
let sincecount = PduCount::Normal(since); let sincecount = PduCount::Normal(since);
let mut presence_updates = HashMap::new(); let mut presence_updates = HashMap::new();
@ -193,9 +219,18 @@ async fn sync_helper(
let mut device_list_left = HashSet::new(); let mut device_list_left = HashSet::new();
// Look for device list updates of this account // Look for device list updates of this account
device_list_updates.extend(services().users.keys_changed(sender_user.as_ref(), since, None).filter_map(Result::ok)); device_list_updates.extend(
services()
.users
.keys_changed(sender_user.as_ref(), since, None)
.filter_map(Result::ok),
);
let all_joined_rooms = services().rooms.state_cache.rooms_joined(&sender_user).collect::<Vec<_>>(); let all_joined_rooms = services()
.rooms
.state_cache
.rooms_joined(&sender_user)
.collect::<Vec<_>>();
// Coalesce database writes for the remainder of this scope. // Coalesce database writes for the remainder of this scope.
let _cork = services().globals.db.cork_and_flush()?; let _cork = services().globals.db.cork_and_flush()?;
@ -229,7 +264,11 @@ async fn sync_helper(
} }
let mut left_rooms = BTreeMap::new(); let mut left_rooms = BTreeMap::new();
let all_left_rooms: Vec<_> = services().rooms.state_cache.rooms_left(&sender_user).collect(); let all_left_rooms: Vec<_> = services()
.rooms
.state_cache
.rooms_left(&sender_user)
.collect();
for result in all_left_rooms { for result in all_left_rooms {
let (room_id, _) = result?; let (room_id, _) = result?;
@ -237,13 +276,23 @@ async fn sync_helper(
{ {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let mutex_insert = let mutex_insert = Arc::clone(
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
drop(insert_lock); drop(insert_lock);
}; };
let left_count = services().rooms.state_cache.get_left_count(&room_id, &sender_user)?; let left_count = services()
.rooms
.state_cache
.get_left_count(&room_id, &sender_user)?;
// Left before last sync // Left before last sync
if Some(since) >= left_count { if Some(since) >= left_count {
@ -255,7 +304,10 @@ async fn sync_helper(
continue; continue;
} }
let since_shortstatehash = services().rooms.user.get_token_shortstatehash(&room_id, since)?; let since_shortstatehash = services()
.rooms
.user
.get_token_shortstatehash(&room_id, since)?;
let since_state_ids = match since_shortstatehash { let since_state_ids = match since_shortstatehash {
Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, Some(s) => services().rooms.state_accessor.state_full_ids(s).await?,
@ -274,7 +326,11 @@ async fn sync_helper(
}, },
}; };
let left_shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(&left_event_id)? { let left_shortstatehash = match services()
.rooms
.state_accessor
.pdu_shortstatehash(&left_event_id)?
{
Some(s) => s, Some(s) => s,
None => { None => {
error!("Leave event has no state"); error!("Leave event has no state");
@ -282,10 +338,16 @@ async fn sync_helper(
}, },
}; };
let mut left_state_ids = services().rooms.state_accessor.state_full_ids(left_shortstatehash).await?; let mut left_state_ids = services()
.rooms
.state_accessor
.state_full_ids(left_shortstatehash)
.await?;
let leave_shortstatekey = let leave_shortstatekey = services()
services().rooms.short.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; .rooms
.short
.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?;
left_state_ids.insert(leave_shortstatekey, left_event_id); left_state_ids.insert(leave_shortstatekey, left_event_id);
@ -334,19 +396,33 @@ async fn sync_helper(
} }
let mut invited_rooms = BTreeMap::new(); let mut invited_rooms = BTreeMap::new();
let all_invited_rooms: Vec<_> = services().rooms.state_cache.rooms_invited(&sender_user).collect(); let all_invited_rooms: Vec<_> = services()
.rooms
.state_cache
.rooms_invited(&sender_user)
.collect();
for result in all_invited_rooms { for result in all_invited_rooms {
let (room_id, invite_state_events) = result?; let (room_id, invite_state_events) = result?;
{ {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let mutex_insert = let mutex_insert = Arc::clone(
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
drop(insert_lock); drop(insert_lock);
}; };
let invite_count = services().rooms.state_cache.get_invite_count(&room_id, &sender_user)?; let invite_count = services()
.rooms
.state_cache
.get_invite_count(&room_id, &sender_user)?;
// Invited before last sync // Invited before last sync
if Some(since) >= invite_count { if Some(since) >= invite_count {
@ -388,7 +464,9 @@ async fn sync_helper(
} }
// Remove all to-device events the device received *last time* // Remove all to-device events the device received *last time*
services().users.remove_to_device_events(&sender_user, &sender_device, since)?; services()
.users
.remove_to_device_events(&sender_user, &sender_device, since)?;
let response = sync_events::v3::Response { let response = sync_events::v3::Response {
next_batch: next_batch_string, next_batch: next_batch_string,
@ -420,9 +498,13 @@ async fn sync_helper(
changed: device_list_updates.into_iter().collect(), changed: device_list_updates.into_iter().collect(),
left: device_list_left.into_iter().collect(), left: device_list_left.into_iter().collect(),
}, },
device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, device_one_time_keys_count: services()
.users
.count_one_time_keys(&sender_user, &sender_device)?,
to_device: ToDevice { to_device: ToDevice {
events: services().users.get_to_device_events(&sender_user, &sender_device)?, events: services()
.users
.get_to_device_events(&sender_user, &sender_device)?,
}, },
// Fallback keys are not yet supported // Fallback keys are not yet supported
device_unused_fallback_key_types: None, device_unused_fallback_key_types: None,
@ -453,7 +535,12 @@ async fn process_room_presence_updates(
presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, room_id: &RoomId, since: u64, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, room_id: &RoomId, since: u64,
) -> Result<()> { ) -> Result<()> {
// Take presence updates from this room // Take presence updates from this room
for (user_id, _, presence_event) in services().rooms.edus.presence.presence_since(room_id, since) { for (user_id, _, presence_event) in services()
.rooms
.edus
.presence
.presence_since(room_id, since)
{
match presence_updates.entry(user_id) { match presence_updates.entry(user_id) {
Entry::Vacant(slot) => { Entry::Vacant(slot) => {
slot.insert(presence_event); slot.insert(presence_event);
@ -465,11 +552,19 @@ async fn process_room_presence_updates(
// Update existing presence event with more info // Update existing presence event with more info
curr_content.presence = new_content.presence; curr_content.presence = new_content.presence;
curr_content.status_msg = new_content.status_msg.or_else(|| curr_content.status_msg.take()); curr_content.status_msg = new_content
.status_msg
.or_else(|| curr_content.status_msg.take());
curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago); curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago);
curr_content.displayname = new_content.displayname.or_else(|| curr_content.displayname.take()); curr_content.displayname = new_content
curr_content.avatar_url = new_content.avatar_url.or_else(|| curr_content.avatar_url.take()); .displayname
curr_content.currently_active = new_content.currently_active.or(curr_content.currently_active); .or_else(|| curr_content.displayname.take());
curr_content.avatar_url = new_content
.avatar_url
.or_else(|| curr_content.avatar_url.take());
curr_content.currently_active = new_content
.currently_active
.or(curr_content.currently_active);
}, },
} }
} }
@ -486,23 +581,38 @@ async fn load_joined_room(
{ {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
// This will make sure the we have all events until next_batch // This will make sure the we have all events until next_batch
let mutex_insert = let mutex_insert = Arc::clone(
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
drop(insert_lock); drop(insert_lock);
}; };
let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?;
let send_notification_counts = let send_notification_counts = !timeline_pdus.is_empty()
!timeline_pdus.is_empty() || services().rooms.user.last_notification_read(sender_user, room_id)? > since; || services()
.rooms
.user
.last_notification_read(sender_user, room_id)?
> since;
let mut timeline_users = HashSet::new(); let mut timeline_users = HashSet::new();
for (_, event) in &timeline_pdus { for (_, event) in &timeline_pdus {
timeline_users.insert(event.sender.as_str().to_owned()); timeline_users.insert(event.sender.as_str().to_owned());
} }
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount).await?; services()
.rooms
.lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount)
.await?;
// Database queries: // Database queries:
@ -513,19 +623,28 @@ async fn load_joined_room(
return Err(Error::BadDatabase("Room has no state")); return Err(Error::BadDatabase("Room has no state"));
}; };
let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, since)?; let since_shortstatehash = services()
.rooms
.user
.get_token_shortstatehash(room_id, since)?;
let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) =
.is_empty() if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) {
&& since_shortstatehash == Some(current_shortstatehash)
{
// No state changes // No state changes
(Vec::new(), None, None, false, Vec::new()) (Vec::new(), None, None, false, Vec::new())
} else { } else {
// Calculates joined_member_count, invited_member_count and heroes // Calculates joined_member_count, invited_member_count and heroes
let calculate_counts = || { let calculate_counts = || {
let joined_member_count = services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0); let joined_member_count = services()
let invited_member_count = services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0); .rooms
.state_cache
.room_joined_count(room_id)?
.unwrap_or(0);
let invited_member_count = services()
.rooms
.state_cache
.room_invited_count(room_id)?
.unwrap_or(0);
// Recalculate heroes (first 5 members) // Recalculate heroes (first 5 members)
let mut heroes = Vec::new(); let mut heroes = Vec::new();
@ -600,14 +719,21 @@ async fn load_joined_room(
let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; let current_state_ids = services()
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut state_events = Vec::new(); let mut state_events = Vec::new();
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let mut i = 0; let mut i = 0;
for (shortstatekey, id) in current_state_ids { for (shortstatekey, id) in current_state_ids {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; let (event_type, state_key) = services()
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let pdu = match services().rooms.timeline.get_pdu(&id)? { let pdu = match services().rooms.timeline.get_pdu(&id)? {
@ -651,7 +777,10 @@ async fn load_joined_room(
} }
// Reset lazy loading because this is an initial sync // Reset lazy loading because this is an initial sync
services().rooms.lazy_loading.lazy_load_reset(sender_user, sender_device, room_id)?; services()
.rooms
.lazy_loading
.lazy_load_reset(sender_user, sender_device, room_id)?;
// The state_events above should contain all timeline_users, let's mark them as // The state_events above should contain all timeline_users, let's mark them as
// lazy loaded. // lazy loaded.
@ -670,8 +799,16 @@ async fn load_joined_room(
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
if since_shortstatehash != current_shortstatehash { if since_shortstatehash != current_shortstatehash {
let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; let current_state_ids = services()
let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?; .rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let since_state_ids = services()
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.await?;
for (key, id) in current_state_ids { for (key, id) in current_state_ids {
if full_state || since_state_ids.get(&key) != Some(&id) { if full_state || since_state_ids.get(&key) != Some(&id) {
@ -684,7 +821,12 @@ async fn load_joined_room(
}; };
if pdu.kind == TimelineEventType::RoomMember { if pdu.kind == TimelineEventType::RoomMember {
match UserId::parse(pdu.state_key.as_ref().expect("State event has state key").clone()) { match UserId::parse(
pdu.state_key
.as_ref()
.expect("State event has state key")
.clone(),
) {
Ok(state_key_userid) => { Ok(state_key_userid) => {
lazy_loaded.insert(state_key_userid); lazy_loaded.insert(state_key_userid);
}, },
@ -733,13 +875,18 @@ async fn load_joined_room(
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")?
.is_some(); .is_some();
let since_encryption = let since_encryption = services().rooms.state_accessor.state_get(
services().rooms.state_accessor.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; since_shortstatehash,
&StateEventType::RoomEncryption,
"",
)?;
// Calculations: // Calculations:
let new_encrypted_room = encrypted_room && since_encryption.is_none(); let new_encrypted_room = encrypted_room && since_encryption.is_none();
let send_member_count = state_events.iter().any(|event| event.kind == TimelineEventType::RoomMember); let send_member_count = state_events
.iter()
.any(|event| event.kind == TimelineEventType::RoomMember);
if encrypted_room { if encrypted_room {
for state_event in &state_events { for state_event in &state_events {
@ -755,7 +902,8 @@ async fn load_joined_room(
continue; continue;
} }
let new_membership = serde_json::from_str::<RoomMemberEventContent>(state_event.content.get()) let new_membership =
serde_json::from_str::<RoomMemberEventContent>(state_event.content.get())
.map_err(|_| Error::bad_database("Invalid PDU in database."))? .map_err(|_| Error::bad_database("Invalid PDU in database."))?
.membership; .membership;
@ -813,7 +961,12 @@ async fn load_joined_room(
}; };
// Look for device list updates in this room // Look for device list updates in this room
device_list_updates.extend(services().users.keys_changed(room_id.as_ref(), since, None).filter_map(Result::ok)); device_list_updates.extend(
services()
.users
.keys_changed(room_id.as_ref(), since, None)
.filter_map(Result::ok),
);
let notification_count = if send_notification_counts { let notification_count = if send_notification_counts {
Some( Some(
@ -841,7 +994,9 @@ async fn load_joined_room(
None None
}; };
let prev_batch = timeline_pdus.first().map_or(Ok::<_, Error>(None), |(pdu_count, _)| { let prev_batch = timeline_pdus
.first()
.map_or(Ok::<_, Error>(None), |(pdu_count, _)| {
Ok(Some(match pdu_count { Ok(Some(match pdu_count {
PduCount::Backfilled(_) => { PduCount::Backfilled(_) => {
error!("timeline in backfill state?!"); error!("timeline in backfill state?!");
@ -851,7 +1006,10 @@ async fn load_joined_room(
})) }))
})?; })?;
let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); let room_events: Vec<_> = timeline_pdus
.iter()
.map(|(_, pdu)| pdu.to_sync_room_event())
.collect();
let mut edus: Vec<_> = services() let mut edus: Vec<_> = services()
.rooms .rooms
@ -862,7 +1020,13 @@ async fn load_joined_room(
.map(|(_, _, v)| v) .map(|(_, _, v)| v)
.collect(); .collect();
if services().rooms.edus.typing.last_typing_update(room_id).await? > since { if services()
.rooms
.edus
.typing
.last_typing_update(room_id)
.await? > since
{
edus.push( edus.push(
serde_json::from_str( serde_json::from_str(
&serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id).await?) &serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id).await?)
@ -874,7 +1038,10 @@ async fn load_joined_room(
// Save the state after this sync so we can send the correct state diff next // Save the state after this sync so we can send the correct state diff next
// sync // sync
services().rooms.user.associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; services()
.rooms
.user
.associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?;
Ok(JoinedRoom { Ok(JoinedRoom {
account_data: RoomAccountData { account_data: RoomAccountData {
@ -904,7 +1071,10 @@ async fn load_joined_room(
events: room_events, events: room_events,
}, },
state: State { state: State {
events: state_events.iter().map(|pdu| pdu.to_sync_state_event()).collect(), events: state_events
.iter()
.map(|pdu| pdu.to_sync_state_event())
.collect(),
}, },
ephemeral: Ephemeral { ephemeral: Ephemeral {
events: edus, events: edus,
@ -918,7 +1088,12 @@ fn load_timeline(
) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> {
let timeline_pdus; let timeline_pdus;
let limited; let limited;
if services().rooms.timeline.last_timeline_count(sender_user, room_id)? > roomsincecount { if services()
.rooms
.timeline
.last_timeline_count(sender_user, room_id)?
> roomsincecount
{
let mut non_timeline_pdus = services() let mut non_timeline_pdus = services()
.rooms .rooms
.timeline .timeline
@ -933,8 +1108,13 @@ fn load_timeline(
.take_while(|(pducount, _)| pducount > &roomsincecount); .take_while(|(pducount, _)| pducount > &roomsincecount);
// Take the last events for the timeline // Take the last events for the timeline
timeline_pdus = timeline_pdus = non_timeline_pdus
non_timeline_pdus.by_ref().take(limit as usize).collect::<Vec<_>>().into_iter().rev().collect::<Vec<_>>(); .by_ref()
.take(limit as usize)
.collect::<Vec<_>>()
.into_iter()
.rev()
.collect::<Vec<_>>();
// They /sync response doesn't always return all messages, so we say the output // They /sync response doesn't always return all messages, so we say the output
// is limited unless there are events in non_timeline_pdus // is limited unless there are events in non_timeline_pdus
@ -980,7 +1160,11 @@ pub async fn sync_events_v4_route(
let next_batch = services().globals.next_count()?; let next_batch = services().globals.next_count()?;
let globalsince = body.pos.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); let globalsince = body
.pos
.as_ref()
.and_then(|string| string.parse().ok())
.unwrap_or(0);
if globalsince == 0 { if globalsince == 0 {
if let Some(conn_id) = &body.conn_id { if let Some(conn_id) = &body.conn_id {
@ -994,13 +1178,21 @@ pub async fn sync_events_v4_route(
// Get sticky parameters from cache // Get sticky parameters from cache
let known_rooms = let known_rooms =
services().users.update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); services()
.users
.update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body);
let all_joined_rooms = let all_joined_rooms = services()
services().rooms.state_cache.rooms_joined(&sender_user).filter_map(Result::ok).collect::<Vec<_>>(); .rooms
.state_cache
.rooms_joined(&sender_user)
.filter_map(Result::ok)
.collect::<Vec<_>>();
if body.extensions.to_device.enabled.unwrap_or(false) { if body.extensions.to_device.enabled.unwrap_or(false) {
services().users.remove_to_device_events(&sender_user, &sender_device, globalsince)?; services()
.users
.remove_to_device_events(&sender_user, &sender_device, globalsince)?;
} }
let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
@ -1009,8 +1201,12 @@ pub async fn sync_events_v4_route(
if body.extensions.e2ee.enabled.unwrap_or(false) { if body.extensions.e2ee.enabled.unwrap_or(false) {
// Look for device list updates of this account // Look for device list updates of this account
device_list_changes device_list_changes.extend(
.extend(services().users.keys_changed(sender_user.as_ref(), globalsince, None).filter_map(Result::ok)); services()
.users
.keys_changed(sender_user.as_ref(), globalsince, None)
.filter_map(Result::ok),
);
for room_id in &all_joined_rooms { for room_id in &all_joined_rooms {
let current_shortstatehash = if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? { let current_shortstatehash = if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? {
@ -1020,7 +1216,10 @@ pub async fn sync_events_v4_route(
continue; continue;
}; };
let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, globalsince)?; let since_shortstatehash = services()
.rooms
.user
.get_token_shortstatehash(room_id, globalsince)?;
let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash
.and_then(|shortstatehash| { .and_then(|shortstatehash| {
@ -1060,9 +1259,16 @@ pub async fn sync_events_v4_route(
let new_encrypted_room = encrypted_room && since_encryption.is_none(); let new_encrypted_room = encrypted_room && since_encryption.is_none();
if encrypted_room { if encrypted_room {
let current_state_ids = let current_state_ids = services()
services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; .rooms
let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?; .state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let since_state_ids = services()
.rooms
.state_accessor
.state_full_ids(since_shortstatehash)
.await?;
for (key, id) in current_state_ids { for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) { if since_state_ids.get(&key) != Some(&id) {
@ -1126,8 +1332,12 @@ pub async fn sync_events_v4_route(
} }
} }
// Look for device list updates in this room // Look for device list updates in this room
device_list_changes device_list_changes.extend(
.extend(services().users.keys_changed(room_id.as_ref(), globalsince, None).filter_map(Result::ok)); services()
.users
.keys_changed(room_id.as_ref(), globalsince, None)
.filter_map(Result::ok),
);
} }
for user_id in left_encrypted_users { for user_id in left_encrypted_users {
let dont_share_encrypted_room = services() let dont_share_encrypted_room = services()
@ -1171,19 +1381,33 @@ pub async fn sync_events_v4_route(
.ranges .ranges
.into_iter() .into_iter()
.map(|mut r| { .map(|mut r| {
r.0 = r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1)); r.0 =
r.1 = r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1)); r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1));
r.1 =
r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1));
let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec(); let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec();
new_known_rooms.extend(room_ids.iter().cloned()); new_known_rooms.extend(room_ids.iter().cloned());
for room_id in &room_ids { for room_id in &room_ids {
let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX)); let todo_room = todo_rooms
let limit = list.room_details.timeline_limit.map_or(10, u64::from).min(100); .entry(room_id.clone())
todo_room.0.extend(list.room_details.required_state.iter().cloned()); .or_insert((BTreeSet::new(), 0, u64::MAX));
let limit = list
.room_details
.timeline_limit
.map_or(10, u64::from)
.min(100);
todo_room
.0
.extend(list.room_details.required_state.iter().cloned());
todo_room.1 = todo_room.1.max(limit); todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date // 0 means unknown because it got out of date
todo_room.2 = todo_room todo_room.2 = todo_room.2.min(
.2 known_rooms
.min(known_rooms.get(&list_id).and_then(|k| k.get(room_id)).copied().unwrap_or(0)); .get(&list_id)
.and_then(|k| k.get(room_id))
.copied()
.unwrap_or(0),
);
} }
sync_events::v4::SyncOp { sync_events::v4::SyncOp {
op: SlidingOp::Sync, op: SlidingOp::Sync,
@ -1215,13 +1439,20 @@ pub async fn sync_events_v4_route(
if !services().rooms.metadata.exists(room_id)? { if !services().rooms.metadata.exists(room_id)? {
continue; continue;
} }
let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX)); let todo_room = todo_rooms
.entry(room_id.clone())
.or_insert((BTreeSet::new(), 0, u64::MAX));
let limit = room.timeline_limit.map_or(10, u64::from).min(100); let limit = room.timeline_limit.map_or(10, u64::from).min(100);
todo_room.0.extend(room.required_state.iter().cloned()); todo_room.0.extend(room.required_state.iter().cloned());
todo_room.1 = todo_room.1.max(limit); todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date // 0 means unknown because it got out of date
todo_room.2 = todo_room.2 = todo_room.2.min(
todo_room.2.min(known_rooms.get("subscriptions").and_then(|k| k.get(room_id)).copied().unwrap_or(0)); known_rooms
.get("subscriptions")
.and_then(|k| k.get(room_id))
.copied()
.unwrap_or(0),
);
known_subscription_rooms.insert(room_id.clone()); known_subscription_rooms.insert(room_id.clone());
} }
@ -1279,11 +1510,19 @@ pub async fn sync_events_v4_route(
} }
}); });
let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); let room_events: Vec<_> = timeline_pdus
.iter()
.map(|(_, pdu)| pdu.to_sync_room_event())
.collect();
let required_state = required_state_request let required_state = required_state_request
.iter() .iter()
.map(|state| services().rooms.state_accessor.room_state_get(room_id, &state.0, &state.1)) .map(|state| {
services()
.rooms
.state_accessor
.room_state_get(room_id, &state.0, &state.1)
})
.filter_map(Result::ok) .filter_map(Result::ok)
.flatten() .flatten()
.map(|state| state.to_sync_state_event()) .map(|state| state.to_sync_state_event())
@ -1298,9 +1537,15 @@ pub async fn sync_events_v4_route(
.filter(|member| member != &sender_user) .filter(|member| member != &sender_user)
.map(|member| { .map(|member| {
Ok::<_, Error>( Ok::<_, Error>(
services().rooms.state_accessor.get_member(room_id, &member)?.map(|memberevent| { services()
.rooms
.state_accessor
.get_member(room_id, &member)?
.map(|memberevent| {
( (
memberevent.displayname.unwrap_or_else(|| member.to_string()), memberevent
.displayname
.unwrap_or_else(|| member.to_string()),
memberevent.avatar_url, memberevent.avatar_url,
) )
}), }),
@ -1313,7 +1558,13 @@ pub async fn sync_events_v4_route(
let name = match heroes.len().cmp(&(1_usize)) { let name = match heroes.len().cmp(&(1_usize)) {
Ordering::Greater => { Ordering::Greater => {
let last = heroes[0].0.clone(); let last = heroes[0].0.clone();
Some(heroes[1..].iter().map(|h| h.0.clone()).collect::<Vec<_>>().join(", ") + " and " + &last) Some(
heroes[1..]
.iter()
.map(|h| h.0.clone())
.collect::<Vec<_>>()
.join(", ") + " and " + &last,
)
}, },
Ordering::Equal => Some(heroes[0].0.clone()), Ordering::Equal => Some(heroes[0].0.clone()),
Ordering::Less => None, Ordering::Less => None,
@ -1364,10 +1615,20 @@ pub async fn sync_events_v4_route(
prev_batch, prev_batch,
limited, limited,
joined_count: Some( joined_count: Some(
(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0) as u32).into(), (services()
.rooms
.state_cache
.room_joined_count(room_id)?
.unwrap_or(0) as u32)
.into(),
), ),
invited_count: Some( invited_count: Some(
(services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0) as u32).into(), (services()
.rooms
.state_cache
.room_invited_count(room_id)?
.unwrap_or(0) as u32)
.into(),
), ),
num_live: None, // Count events in timeline greater than global sync counter num_live: None, // Count events in timeline greater than global sync counter
timestamp: None, timestamp: None,
@ -1375,7 +1636,10 @@ pub async fn sync_events_v4_route(
); );
} }
if rooms.iter().all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) { if rooms
.iter()
.all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty())
{
// Hang a few seconds so requests are not spammed // Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives // Stop hanging if new info arrives
let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); let mut duration = body.timeout.unwrap_or(Duration::from_secs(30));
@ -1394,7 +1658,9 @@ pub async fn sync_events_v4_route(
extensions: sync_events::v4::Extensions { extensions: sync_events::v4::Extensions {
to_device: if body.extensions.to_device.enabled.unwrap_or(false) { to_device: if body.extensions.to_device.enabled.unwrap_or(false) {
Some(sync_events::v4::ToDevice { Some(sync_events::v4::ToDevice {
events: services().users.get_to_device_events(&sender_user, &sender_device)?, events: services()
.users
.get_to_device_events(&sender_user, &sender_device)?,
next_batch: next_batch.to_string(), next_batch: next_batch.to_string(),
}) })
} else { } else {
@ -1405,7 +1671,9 @@ pub async fn sync_events_v4_route(
changed: device_list_changes.into_iter().collect(), changed: device_list_changes.into_iter().collect(),
left: device_list_left.into_iter().collect(), left: device_list_left.into_iter().collect(),
}, },
device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, device_one_time_keys_count: services()
.users
.count_one_time_keys(&sender_user, &sender_device)?,
// Fallback keys are not yet supported // Fallback keys are not yet supported
device_unused_fallback_key_types: None, device_unused_fallback_key_types: None,
}, },

View file

@ -18,7 +18,9 @@ use crate::{services, Error, Result, Ruma};
pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> { pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let event = services()
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| { || {
@ -31,7 +33,10 @@ pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<cre
|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")), |e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")),
)?; )?;
tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone()); tags_event
.content
.tags
.insert(body.tag.clone().into(), body.tag_info.clone());
services().account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
@ -51,7 +56,9 @@ pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<cre
pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> { pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let event = services()
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
let mut tags_event = event.map_or_else( let mut tags_event = event.map_or_else(
|| { || {
@ -84,7 +91,9 @@ pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<del
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> { pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let event = services()
.account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
let tags_event = event.map_or_else( let tags_event = event.map_or_else(
|| { || {

View file

@ -7,10 +7,15 @@ pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<g
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100); let limit = body
.limit
.and_then(|l| l.try_into().ok())
.unwrap_or(10)
.min(100);
let from = if let Some(from) = &body.from { let from = if let Some(from) = &body.from {
from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? from.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
} else { } else {
u64::MAX u64::MAX
}; };
@ -33,7 +38,10 @@ pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<g
let next_batch = threads.last().map(|(count, _)| count.to_string()); let next_batch = threads.last().map(|(count, _)| count.to_string());
Ok(get_threads::v1::Response { Ok(get_threads::v1::Response {
chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(), chunk: threads
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect(),
next_batch, next_batch,
}) })
} }

View file

@ -20,7 +20,11 @@ pub async fn send_event_to_device_route(
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
// Check if this is a new transaction id // Check if this is a new transaction id
if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() { if services()
.transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)?
.is_some()
{
return Ok(send_event_to_device::v3::Response {}); return Ok(send_event_to_device::v3::Response {});
} }
@ -79,7 +83,9 @@ pub async fn send_event_to_device_route(
} }
// Save transaction id with empty data // Save transaction id with empty data
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; services()
.transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
Ok(send_event_to_device::v3::Response {}) Ok(send_event_to_device::v3::Response {})
} }

View file

@ -12,7 +12,11 @@ pub async fn create_typing_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { if !services()
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room."));
} }
@ -28,7 +32,12 @@ pub async fn create_typing_event_route(
) )
.await?; .await?;
} else { } else {
services().rooms.edus.typing.typing_remove(sender_user, &body.room_id).await?; services()
.rooms
.edus
.typing
.typing_remove(sender_user, &body.room_id)
.await?;
} }
Ok(create_typing_event::v3::Response {}) Ok(create_typing_event::v3::Response {})

View file

@ -29,12 +29,19 @@ pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result
avatar_url: services().users.avatar_url(&user_id).ok()?, avatar_url: services().users.avatar_url(&user_id).ok()?,
}; };
let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase()); let user_id_matches = user
.user_id
.to_string()
.to_lowercase()
.contains(&body.search_term.to_lowercase());
let user_displayname_matches = user let user_displayname_matches = user
.display_name .display_name
.as_ref() .as_ref()
.filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase())) .filter(|name| {
name.to_lowercase()
.contains(&body.search_term.to_lowercase())
})
.is_some(); .is_some();
if !user_id_matches && !user_displayname_matches { if !user_id_matches && !user_displayname_matches {
@ -44,24 +51,34 @@ pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result
// It's a matching user, but is the sender allowed to see them? // It's a matching user, but is the sender allowed to see them?
let mut user_visible = false; let mut user_visible = false;
let user_is_in_public_rooms = let user_is_in_public_rooms = services()
services().rooms.state_cache.rooms_joined(&user_id).filter_map(Result::ok).any(|room| { .rooms
services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or( .state_cache
false, .rooms_joined(&user_id)
|event| { .filter_map(Result::ok)
.any(|room| {
services()
.rooms
.state_accessor
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
.map_or(false, |event| {
event.map_or(false, |event| { event.map_or(false, |event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
}) })
}, })
)
}); });
if user_is_in_public_rooms { if user_is_in_public_rooms {
user_visible = true; user_visible = true;
} else { } else {
let user_is_in_shared_rooms = let user_is_in_shared_rooms = services()
services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some(); .rooms
.user
.get_shared_rooms(vec![sender_user.clone(), user_id])
.ok()?
.next()
.is_some();
if user_is_in_shared_rooms { if user_is_in_shared_rooms {
user_visible = true; user_visible = true;

View file

@ -44,14 +44,16 @@ where
let (mut parts, mut body) = match req.with_limited_body() { let (mut parts, mut body) = match req.with_limited_body() {
Ok(limited_req) => { Ok(limited_req) => {
let (parts, body) = limited_req.into_parts(); let (parts, body) = limited_req.into_parts();
let body = let body = to_bytes(body)
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
(parts, body) (parts, body)
}, },
Err(original_req) => { Err(original_req) => {
let (parts, body) = original_req.into_parts(); let (parts, body) = original_req.into_parts();
let body = let body = to_bytes(body)
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; .await
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
(parts, body) (parts, body)
}, },
}; };
@ -180,8 +182,10 @@ where
return Err(Error::bad_config("Federation is disabled.")); return Err(Error::bad_config("Federation is disabled."));
} }
let TypedHeader(Authorization(x_matrix)) = let TypedHeader(Authorization(x_matrix)) = parts
parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| { .extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e); warn!("Missing or invalid Authorization header: {}", e);
let msg = match e.reason() { let msg = match e.reason() {
@ -265,7 +269,11 @@ where
AuthScheme::None => match parts.uri.path() { AuthScheme::None => match parts.uri.path() {
// allow_public_room_directory_without_auth // allow_public_room_directory_without_auth
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
if !services().globals.config.allow_public_room_directory_without_auth { if !services()
.globals
.config
.allow_public_room_directory_without_auth
{
let token = match token { let token = match token {
Some(token) => token, Some(token) => token,
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
@ -362,7 +370,9 @@ impl Credentials for XMatrix {
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
); );
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start(); let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
.ok()?
.trim_start();
let mut origin = None; let mut origin = None;
let mut destination = None; let mut destination = None;
@ -374,7 +384,10 @@ impl Credentials for XMatrix {
// It's not at all clear why some fields are quoted and others not in the spec, // It's not at all clear why some fields are quoted and others not in the spec,
// let's simply accept either form for every field. // let's simply accept either form for every field.
let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value); let value = value
.strip_prefix('"')
.and_then(|rest| rest.strip_suffix('"'))
.unwrap_or(value);
// FIXME: Catch multiple fields of the same name // FIXME: Catch multiple fields of the same name
match name { match name {

View file

@ -159,7 +159,13 @@ where
let mut write_destination_to_cache = false; let mut write_destination_to_cache = false;
let cached_result = services().globals.actual_destinations().read().await.get(destination).cloned(); let cached_result = services()
.globals
.actual_destinations()
.read()
.await
.get(destination)
.cloned();
let (actual_destination, host) = if let Some(result) = cached_result { let (actual_destination, host) = if let Some(result) = cached_result {
result result
@ -196,7 +202,12 @@ where
request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); request_map.insert("method".to_owned(), T::METADATA.method.to_string().into());
request_map.insert( request_map.insert(
"uri".to_owned(), "uri".to_owned(),
http_request.uri().path_and_query().expect("all requests have a path").to_string().into(), http_request
.uri()
.path_and_query()
.expect("all requests have a path")
.to_string()
.into(),
); );
request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
request_map.insert("destination".to_owned(), destination.as_str().into()); request_map.insert("destination".to_owned(), destination.as_str().into());
@ -217,7 +228,12 @@ where
.as_object() .as_object()
.unwrap() .unwrap()
.values() .values()
.map(|v| v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap()))); .map(|v| {
v.as_object()
.unwrap()
.iter()
.map(|(k, v)| (k, v.as_str().unwrap()))
});
for signature_server in signatures { for signature_server in signatures {
for s in signature_server { for s in signature_server {
@ -257,7 +273,12 @@ where
} }
debug!("Sending request to {destination} at {url}"); debug!("Sending request to {destination} at {url}");
let response = services().globals.client.federation.execute(reqwest_request).await; let response = services()
.globals
.client
.federation
.execute(reqwest_request)
.await;
debug!("Received response from {destination} at {url}"); debug!("Received response from {destination} at {url}");
match response { match response {
@ -283,10 +304,14 @@ where
} }
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder().status(status).version(response.version()); let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap( mem::swap(
response.headers_mut(), response.headers_mut(),
http_response_builder.headers_mut().expect("http::response::Builder is usable"), http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
); );
debug!("Getting response bytes from {destination}"); debug!("Getting response bytes from {destination}");
@ -301,11 +326,16 @@ where
"Response not successful\n{} {}: {}", "Response not successful\n{} {}: {}",
url, url,
status, status,
String::from_utf8_lossy(&body).lines().collect::<Vec<_>>().join(" ") String::from_utf8_lossy(&body)
.lines()
.collect::<Vec<_>>()
.join(" ")
); );
} }
let http_response = http_response_builder.body(body).expect("reqwest body is valid http body"); let http_response = http_response_builder
.body(body)
.expect("reqwest body is valid http body");
if status.is_success() { if status.is_success() {
debug!("Parsing response bytes from {destination}"); debug!("Parsing response bytes from {destination}");
@ -490,7 +520,12 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe
} }
async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) {
match services().globals.dns_resolver().lookup_ip(hostname.to_owned()).await { match services()
.globals
.dns_resolver()
.lookup_ip(hostname.to_owned())
.await
{
Ok(override_ip) => { Ok(override_ip) => {
debug!("Caching result of {:?} overriding {:?}", hostname, overname); debug!("Caching result of {:?} overriding {:?}", hostname, overname);
@ -521,7 +556,11 @@ async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
async fn lookup_srv(hostname: &str) -> Result<SrvLookup, ResolveError> { async fn lookup_srv(hostname: &str) -> Result<SrvLookup, ResolveError> {
debug!("querying SRV for {:?}", hostname); debug!("querying SRV for {:?}", hostname);
let hostname = hostname.trim_end_matches('.'); let hostname = hostname.trim_end_matches('.');
services().globals.dns_resolver().srv_lookup(hostname.to_owned()).await services()
.globals
.dns_resolver()
.srv_lookup(hostname.to_owned())
.await
} }
let first_hostname = format!("_matrix-fed._tcp.{hostname}."); let first_hostname = format!("_matrix-fed._tcp.{hostname}.");
@ -539,7 +578,14 @@ async fn query_srv_record(hostname: &'_ str) -> Option<FedDest> {
} }
async fn request_well_known(destination: &str) -> Option<String> { async fn request_well_known(destination: &str) -> Option<String> {
if !services().globals.resolver.overrides.read().unwrap().contains_key(destination) { if !services()
.globals
.resolver
.overrides
.read()
.unwrap()
.contains_key(destination)
{
query_and_cache_override(destination, destination, 8448).await; query_and_cache_override(destination, destination, 8448).await;
} }
@ -663,7 +709,10 @@ pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_serve
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
body: Ruma<get_public_rooms_filtered::v1::Request>, body: Ruma<get_public_rooms_filtered::v1::Request>,
) -> Result<get_public_rooms_filtered::v1::Response> { ) -> Result<get_public_rooms_filtered::v1::Response> {
if !services().globals.allow_public_room_directory_over_federation() { if !services()
.globals
.allow_public_room_directory_over_federation()
{
return Err(Error::bad_config("Room directory is not public.")); return Err(Error::bad_config("Room directory is not public."));
} }
@ -690,7 +739,10 @@ pub async fn get_public_rooms_filtered_route(
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
body: Ruma<get_public_rooms::v1::Request>, body: Ruma<get_public_rooms::v1::Request>,
) -> Result<get_public_rooms::v1::Response> { ) -> Result<get_public_rooms::v1::Response> {
if !services().globals.allow_public_room_directory_over_federation() { if !services()
.globals
.allow_public_room_directory_over_federation()
{
return Err(Error::bad_config("Room directory is not public.")); return Err(Error::bad_config("Room directory is not public."));
} }
@ -743,7 +795,10 @@ pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, Canonical
pub async fn send_transaction_message_route( pub async fn send_transaction_message_route(
body: Ruma<send_transaction_message::v1::Request>, body: Ruma<send_transaction_message::v1::Request>,
) -> Result<send_transaction_message::v1::Response> { ) -> Result<send_transaction_message::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
let mut resolved_map = BTreeMap::new(); let mut resolved_map = BTreeMap::new();
@ -799,8 +854,15 @@ pub async fn send_transaction_message_route(
}); });
for (event_id, value, room_id) in parsed_pdus { for (event_id, value, room_id) in parsed_pdus {
let mutex = let mutex = Arc::clone(
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let mutex_lock = mutex.lock().await; let mutex_lock = mutex.lock().await;
let start_time = Instant::now(); let start_time = Instant::now();
resolved_map.insert( resolved_map.insert(
@ -831,7 +893,11 @@ pub async fn send_transaction_message_route(
} }
} }
for edu in body.edus.iter().filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) { for edu in body
.edus
.iter()
.filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok())
{
match edu { match edu {
Edu::Presence(presence) => { Edu::Presence(presence) => {
if !services().globals.allow_incoming_presence() { if !services().globals.allow_incoming_presence() {
@ -862,7 +928,13 @@ pub async fn send_transaction_message_route(
.event_ids .event_ids
.iter() .iter()
.filter_map(|id| { .filter_map(|id| {
services().rooms.timeline.get_pdu_count(id).ok().flatten().map(|r| (id, r)) services()
.rooms
.timeline
.get_pdu_count(id)
.ok()
.flatten()
.map(|r| (id, r))
}) })
.max_by_key(|(_, count)| *count) .max_by_key(|(_, count)| *count)
{ {
@ -879,7 +951,11 @@ pub async fn send_transaction_message_route(
content: ReceiptEventContent(receipt_content), content: ReceiptEventContent(receipt_content),
room_id: room_id.clone(), room_id: room_id.clone(),
}; };
services().rooms.edus.read_receipt.readreceipt_update(&user_id, &room_id, event)?; services()
.rooms
.edus
.read_receipt
.readreceipt_update(&user_id, &room_id, event)?;
} else { } else {
// TODO fetch missing events // TODO fetch missing events
debug!("No known event ids in read receipt: {:?}", user_updates); debug!("No known event ids in read receipt: {:?}", user_updates);
@ -888,7 +964,11 @@ pub async fn send_transaction_message_route(
} }
}, },
Edu::Typing(typing) => { Edu::Typing(typing) => {
if services().rooms.state_cache.is_joined(&typing.user_id, &typing.room_id)? { if services()
.rooms
.state_cache
.is_joined(&typing.user_id, &typing.room_id)?
{
if typing.typing { if typing.typing {
services() services()
.rooms .rooms
@ -897,7 +977,12 @@ pub async fn send_transaction_message_route(
.typing_add(&typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch()) .typing_add(&typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch())
.await?; .await?;
} else { } else {
services().rooms.edus.typing.typing_remove(&typing.user_id, &typing.room_id).await?; services()
.rooms
.edus
.typing
.typing_remove(&typing.user_id, &typing.room_id)
.await?;
} }
} }
}, },
@ -914,7 +999,11 @@ pub async fn send_transaction_message_route(
messages, messages,
}) => { }) => {
// Check if this is a new transaction id // Check if this is a new transaction id
if services().transaction_ids.existing_txnid(&sender, None, &message_id)?.is_some() { if services()
.transaction_ids
.existing_txnid(&sender, None, &message_id)?
.is_some()
{
continue; continue;
} }
@ -952,7 +1041,9 @@ pub async fn send_transaction_message_route(
} }
// Save transaction id with empty data // Save transaction id with empty data
services().transaction_ids.add_txnid(&sender, None, &message_id, &[])?; services()
.transaction_ids
.add_txnid(&sender, None, &message_id, &[])?;
}, },
Edu::SigningKeyUpdate(SigningKeyUpdateContent { Edu::SigningKeyUpdate(SigningKeyUpdateContent {
user_id, user_id,
@ -963,7 +1054,9 @@ pub async fn send_transaction_message_route(
continue; continue;
} }
if let Some(master_key) = master_key { if let Some(master_key) = master_key {
services().users.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; services()
.users
.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?;
} }
}, },
Edu::_Custom(_) => {}, Edu::_Custom(_) => {},
@ -971,7 +1064,10 @@ pub async fn send_transaction_message_route(
} }
Ok(send_transaction_message::v1::Response { Ok(send_transaction_message::v1::Response {
pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))).collect(), pdus: resolved_map
.into_iter()
.map(|(e, r)| (e, r.map_err(|e| e.sanitized_error())))
.collect(),
}) })
} }
@ -982,9 +1078,16 @@ pub async fn send_transaction_message_route(
/// - Only works if a user of this server is currently invited or joined the /// - Only works if a user of this server is currently invited or joined the
/// room /// room
pub async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_event::v1::Response> { pub async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_event::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { let event = services()
.rooms
.timeline
.get_pdu_json(&body.event_id)?
.ok_or_else(|| {
warn!("Event not found, event ID: {:?}", &body.event_id); warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.") Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?; })?;
@ -997,11 +1100,19 @@ pub async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_e
let room_id = <&RoomId>::try_from(room_id_str) let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
if !services().rooms.state_cache.server_in_room(sender_servername, room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room"));
} }
if !services().rooms.state_accessor.server_can_see_event(sender_servername, room_id, &body.event_id)? { if !services()
.rooms
.state_accessor
.server_can_see_event(sender_servername, room_id, &body.event_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not allowed to see event.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not allowed to see event."));
} }
@ -1017,15 +1128,25 @@ pub async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_e
/// Retrieves events from before the sender joined the room, if the room's /// Retrieves events from before the sender joined the room, if the room's
/// history visibility allows. /// history visibility allows.
pub async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result<get_backfill::v1::Response> { pub async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result<get_backfill::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
debug!("Got backfill request from: {}", sender_servername); debug!("Got backfill request from: {}", sender_servername);
if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room."));
} }
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let until = body let until = body
.v .v
@ -1047,7 +1168,10 @@ pub async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|(_, e)| { .filter(|(_, e)| {
matches!( matches!(
services().rooms.state_accessor.server_can_see_event(sender_servername, &e.room_id, &e.event_id,), services()
.rooms
.state_accessor
.server_can_see_event(sender_servername, &e.room_id, &e.event_id,),
Ok(true), Ok(true),
) )
}) })
@ -1069,13 +1193,23 @@ pub async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result
pub async fn get_missing_events_route( pub async fn get_missing_events_route(
body: Ruma<get_missing_events::v1::Request>, body: Ruma<get_missing_events::v1::Request>,
) -> Result<get_missing_events::v1::Response> { ) -> Result<get_missing_events::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room"));
} }
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let mut queued_events = body.latest_events.clone(); let mut queued_events = body.latest_events.clone();
let mut events = Vec::new(); let mut events = Vec::new();
@ -1142,15 +1276,29 @@ pub async fn get_missing_events_route(
pub async fn get_event_authorization_route( pub async fn get_event_authorization_route(
body: Ruma<get_event_authorization::v1::Request>, body: Ruma<get_event_authorization::v1::Request>,
) -> Result<get_event_authorization::v1::Response> { ) -> Result<get_event_authorization::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room."));
} }
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { let event = services()
.rooms
.timeline
.get_pdu_json(&body.event_id)?
.ok_or_else(|| {
warn!("Event not found, event ID: {:?}", &body.event_id); warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.") Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?; })?;
@ -1163,7 +1311,11 @@ pub async fn get_event_authorization_route(
let room_id = <&RoomId>::try_from(room_id_str) let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; let auth_chain_ids = services()
.rooms
.auth_chain
.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)])
.await?;
Ok(get_event_authorization::v1::Response { Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
@ -1177,13 +1329,23 @@ pub async fn get_event_authorization_route(
/// ///
/// Retrieves the current state of the room. /// Retrieves the current state of the room.
pub async fn get_room_state_route(body: Ruma<get_room_state::v1::Request>) -> Result<get_room_state::v1::Response> { pub async fn get_room_state_route(body: Ruma<get_room_state::v1::Request>) -> Result<get_room_state::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room."));
} }
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let shortstatehash = services() let shortstatehash = services()
.rooms .rooms
@ -1199,13 +1361,21 @@ pub async fn get_room_state_route(body: Ruma<get_room_state::v1::Request>) -> Re
.into_values() .into_values()
.map(|id| { .map(|id| {
PduEvent::convert_to_outgoing_federation_event( PduEvent::convert_to_outgoing_federation_event(
services().rooms.timeline.get_pdu_json(&id).unwrap().unwrap(), services()
.rooms
.timeline
.get_pdu_json(&id)
.unwrap()
.unwrap(),
) )
}) })
.collect(); .collect();
let auth_chain_ids = let auth_chain_ids = services()
services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; .rooms
.auth_chain
.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?;
Ok(get_room_state::v1::Response { Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
@ -1227,13 +1397,23 @@ pub async fn get_room_state_route(body: Ruma<get_room_state::v1::Request>) -> Re
pub async fn get_room_state_ids_route( pub async fn get_room_state_ids_route(
body: Ruma<get_room_state_ids::v1::Request>, body: Ruma<get_room_state_ids::v1::Request>,
) -> Result<get_room_state_ids::v1::Response> { ) -> Result<get_room_state_ids::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(sender_servername, &body.room_id)?
{
return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room."));
} }
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let shortstatehash = services() let shortstatehash = services()
.rooms .rooms
@ -1250,8 +1430,11 @@ pub async fn get_room_state_ids_route(
.map(|id| (*id).to_owned()) .map(|id| (*id).to_owned())
.collect(); .collect();
let auth_chain_ids = let auth_chain_ids = services()
services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; .rooms
.auth_chain
.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)])
.await?;
Ok(get_room_state_ids::v1::Response { Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),
@ -1269,17 +1452,33 @@ pub async fn create_join_event_template_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// TODO: Conduit does not implement restricted join rules yet, we always reject // TODO: Conduit does not implement restricted join rules yet, we always reject
let join_rules_event = let join_rules_event =
services().rooms.state_accessor.room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; services()
.rooms
.state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
.as_ref() .as_ref()
@ -1376,11 +1575,17 @@ async fn create_join_event(
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
services().rooms.event_handler.acl_check(sender_servername, room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, room_id)?;
// TODO: Conduit does not implement restricted join rules yet, we always reject // TODO: Conduit does not implement restricted join rules yet, we always reject
let join_rules_event = let join_rules_event =
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event
.as_ref() .as_ref()
@ -1431,16 +1636,29 @@ async fn create_join_event(
let origin: OwnedServerName = serde_json::from_value( let origin: OwnedServerName = serde_json::from_value(
serde_json::to_value( serde_json::to_value(
value.get("origin").ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?, value
.get("origin")
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?,
) )
.expect("CanonicalJson is valid json value"), .expect("CanonicalJson is valid json value"),
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?;
services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let mutex = let mutex = Arc::clone(
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let mutex_lock = mutex.lock().await; let mutex_lock = mutex.lock().await;
let pdu_id: Vec<u8> = services() let pdu_id: Vec<u8> = services()
.rooms .rooms
@ -1453,9 +1671,16 @@ async fn create_join_event(
))?; ))?;
drop(mutex_lock); drop(mutex_lock);
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let state_ids = services()
let auth_chain_ids = .rooms
services().rooms.auth_chain.get_auth_chain(room_id, state_ids.values().cloned().collect()).await?; .state_accessor
.state_full_ids(shortstatehash)
.await?;
let auth_chain_ids = services()
.rooms
.auth_chain
.get_auth_chain(room_id, state_ids.values().cloned().collect())
.await?;
let servers = services() let servers = services()
.rooms .rooms
@ -1486,7 +1711,10 @@ async fn create_join_event(
pub async fn create_join_event_v1_route( pub async fn create_join_event_v1_route(
body: Ruma<create_join_event::v1::Request>, body: Ruma<create_join_event::v1::Request>,
) -> Result<create_join_event::v1::Response> { ) -> Result<create_join_event::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?;
@ -1501,7 +1729,10 @@ pub async fn create_join_event_v1_route(
pub async fn create_join_event_v2_route( pub async fn create_join_event_v2_route(
body: Ruma<create_join_event::v2::Request>, body: Ruma<create_join_event::v2::Request>,
) -> Result<create_join_event::v2::Response> { ) -> Result<create_join_event::v2::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
let create_join_event::v1::RoomState { let create_join_event::v1::RoomState {
auth_chain, auth_chain,
@ -1525,11 +1756,21 @@ pub async fn create_join_event_v2_route(
/// ///
/// Invites a remote user to a room. /// Invites a remote user to a room.
pub async fn create_invite_route(body: Ruma<create_invite::v2::Request>) -> Result<create_invite::v2::Response> { pub async fn create_invite_route(body: Ruma<create_invite::v2::Request>) -> Result<create_invite::v2::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; services()
.rooms
.event_handler
.acl_check(sender_servername, &body.room_id)?;
if !services().globals.supported_room_versions().contains(&body.room_version) { if !services()
.globals
.supported_room_versions()
.contains(&body.room_version)
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion { ErrorKind::IncompatibleRoomVersion {
room_version: body.room_version.clone(), room_version: body.room_version.clone(),
@ -1619,7 +1860,11 @@ pub async fn create_invite_route(body: Ruma<create_invite::v2::Request>) -> Resu
// If we are active in the room, the remote server will notify us about the join // If we are active in the room, the remote server will notify us about the join
// via /send // via /send
if !services().rooms.state_cache.server_in_room(services().globals.server_name(), &body.room_id)? { if !services()
.rooms
.state_cache
.server_in_room(services().globals.server_name(), &body.room_id)?
{
services() services()
.rooms .rooms
.state_cache .state_cache
@ -1650,7 +1895,10 @@ pub async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> Result<g
)); ));
} }
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
Ok(get_devices::v1::Response { Ok(get_devices::v1::Response {
user_id: body.user_id.clone(), user_id: body.user_id.clone(),
@ -1672,13 +1920,18 @@ pub async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> Result<g
Some(device_id_string) Some(device_id_string)
}; };
Some(UserDevice { Some(UserDevice {
keys: services().users.get_device_keys(&body.user_id, &metadata.device_id).ok()??, keys: services()
.users
.get_device_keys(&body.user_id, &metadata.device_id)
.ok()??,
device_id: metadata.device_id, device_id: metadata.device_id,
device_display_name, device_display_name,
}) })
}) })
.collect(), .collect(),
master_key: services().users.get_master_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?, master_key: services()
.users
.get_master_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?,
self_signing_key: services() self_signing_key: services()
.users .users
.get_self_signing_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?, .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?,
@ -1748,7 +2001,11 @@ pub async fn get_profile_information_route(
/// ///
/// Gets devices and identity keys for the given users. /// Gets devices and identity keys for the given users.
pub async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_keys::v1::Response> { pub async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_keys::v1::Response> {
if body.device_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { if body
.device_keys
.iter()
.any(|(u, _)| u.server_name() != services().globals.server_name())
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"User does not belong to this server.", "User does not belong to this server.",
@ -1774,7 +2031,11 @@ pub async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_key
/// ///
/// Claims one-time keys. /// Claims one-time keys.
pub async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Result<claim_keys::v1::Response> { pub async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Result<claim_keys::v1::Response> {
if body.one_time_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { if body
.one_time_keys
.iter()
.any(|(u, _)| u.server_name() != services().globals.server_name())
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Tried to access user from other server.", "Tried to access user from other server.",
@ -1809,10 +2070,17 @@ pub async fn well_known_server_route() -> Result<impl IntoResponse> {
/// Gets the space tree in a depth-first manner to locate child rooms of a given /// Gets the space tree in a depth-first manner to locate child rooms of a given
/// space. /// space.
pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); let sender_servername = body
.sender_servername
.as_ref()
.expect("server is authenticated");
if services().rooms.metadata.exists(&body.room_id)? { if services().rooms.metadata.exists(&body.room_id)? {
services().rooms.spaces.get_federation_hierarchy(&body.room_id, sender_servername, body.suggested_only).await services()
.rooms
.spaces
.get_federation_hierarchy(&body.room_id, sender_servername, body.suggested_only)
.await
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist."))
} }

View file

@ -251,7 +251,11 @@ impl Config {
pub fn warn_deprecated(&self) { pub fn warn_deprecated(&self) {
debug!("Checking for deprecated config keys"); debug!("Checking for deprecated config keys");
let mut was_deprecated = false; let mut was_deprecated = false;
for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) { for key in self
.catchall
.keys()
.filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key))
{
warn!("Config parameter \"{}\" is deprecated, ignoring.", key); warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
was_deprecated = true; was_deprecated = true;
} }
@ -268,8 +272,10 @@ impl Config {
/// if there are any. /// if there are any.
pub fn warn_unknown_key(&self) { pub fn warn_unknown_key(&self) {
debug!("Checking for unknown config keys"); debug!("Checking for unknown config keys");
for key in for key in self
self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) .catchall
.keys()
.filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */)
{ {
warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key);
} }

View file

@ -155,7 +155,8 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
let db = rust_rocksdb::DBWithThreadMode::<rust_rocksdb::MultiThreaded>::open_cf_descriptors( let db = rust_rocksdb::DBWithThreadMode::<rust_rocksdb::MultiThreaded>::open_cf_descriptors(
&db_opts, &db_opts,
&config.database_path, &config.database_path,
cfs.iter().map(|name| rust_rocksdb::ColumnFamilyDescriptor::new(name, db_opts.clone())), cfs.iter()
.map(|name| rust_rocksdb::ColumnFamilyDescriptor::new(name, db_opts.clone())),
)?; )?;
Ok(Arc::new(Engine { Ok(Arc::new(Engine {
@ -200,13 +201,15 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
fn corked(&self) -> bool { self.corks.load(std::sync::atomic::Ordering::Relaxed) > 0 } fn corked(&self) -> bool { self.corks.load(std::sync::atomic::Ordering::Relaxed) > 0 }
fn cork(&self) -> Result<()> { fn cork(&self) -> Result<()> {
self.corks.fetch_add(1, std::sync::atomic::Ordering::Relaxed); self.corks
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Ok(()) Ok(())
} }
fn uncork(&self) -> Result<()> { fn uncork(&self) -> Result<()> {
self.corks.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); self.corks
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
Ok(()) Ok(())
} }
@ -293,7 +296,9 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
format_args!( format_args!(
"#{} {}: {} bytes, {} files\n", "#{} {}: {} bytes, {} files\n",
info.backup_id, info.backup_id,
DateTime::<Utc>::from_timestamp(info.timestamp, 0).unwrap_or_default().to_rfc2822(), DateTime::<Utc>::from_timestamp(info.timestamp, 0)
.unwrap_or_default()
.to_rfc2822(),
info.size, info.size,
info.num_files, info.num_files,
), ),
@ -358,7 +363,9 @@ impl KvTree for RocksDbEngineTree<'_> {
let writeoptions = rust_rocksdb::WriteOptions::default(); let writeoptions = rust_rocksdb::WriteOptions::default();
let lock = self.write_lock.read().unwrap(); let lock = self.write_lock.read().unwrap();
self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?; self.db
.rocks
.put_cf_opt(&self.cf(), key, value, &writeoptions)?;
drop(lock); drop(lock);
@ -465,7 +472,9 @@ impl KvTree for RocksDbEngineTree<'_> {
let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?; let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?;
let new = utils::increment(old.as_deref()); let new = utils::increment(old.as_deref());
self.db.rocks.put_cf_opt(&self.cf(), key, &new, &writeoptions)?; self.db
.rocks
.put_cf_opt(&self.cf(), key, &new, &writeoptions)?;
drop(lock); drop(lock);

View file

@ -68,15 +68,18 @@ impl Engine {
fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
fn read_lock(&self) -> &Connection { fn read_lock(&self) -> &Connection {
self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) self.read_conn_tls
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
} }
fn read_lock_iterator(&self) -> &Connection { fn read_lock_iterator(&self) -> &Connection {
self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) self.read_iterator_conn_tls
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
} }
pub fn flush_wal(self: &Arc<Self>) -> Result<()> { pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; self.write_lock()
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
Ok(()) Ok(())
} }
} }
@ -153,7 +156,9 @@ impl SqliteTable {
pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
let statement = Box::leak(Box::new( let statement = Box::leak(Box::new(
guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(), guard
.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name))
.unwrap(),
)); ));
let statement_ref = NonAliasingBox(statement); let statement_ref = NonAliasingBox(statement);
@ -161,7 +166,10 @@ impl SqliteTable {
//let name = self.name.clone(); //let name = self.name.clone();
let iterator = Box::new( let iterator = Box::new(
statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(Result::unwrap), statement
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
.unwrap()
.map(Result::unwrap),
); );
Box::new(PreparedStatementIterator { Box::new(PreparedStatementIterator {
@ -293,7 +301,10 @@ impl KvTree for SqliteTable {
} }
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> { fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix))) Box::new(
self.iter_from(&prefix, false)
.take_while(move |(key, _)| key.starts_with(&prefix)),
)
} }
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> { fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
@ -302,7 +313,9 @@ impl KvTree for SqliteTable {
fn clear(&self) -> Result<()> { fn clear(&self) -> Result<()> {
debug!("clear: running"); debug!("clear: running");
self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?; self.engine
.write_lock()
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
debug!("clear: ran"); debug!("clear: ran");
Ok(()) Ok(())
} }

View file

@ -18,7 +18,11 @@ impl service::account_data::Data for KeyValueDatabase {
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
data: &serde_json::Value, data: &serde_json::Value,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes()); prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF); prefix.push(0xFF);
@ -45,7 +49,8 @@ impl service::account_data::Data for KeyValueDatabase {
let prev = self.roomusertype_roomuserdataid.get(&key)?; let prev = self.roomusertype_roomuserdataid.get(&key)?;
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; self.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
// Remove old entry // Remove old entry
if let Some(prev) = prev { if let Some(prev) = prev {
@ -60,7 +65,11 @@ impl service::account_data::Data for KeyValueDatabase {
fn get( fn get(
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
) -> Result<Option<Box<serde_json::value::RawValue>>> { ) -> Result<Option<Box<serde_json::value::RawValue>>> {
let mut key = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); let mut key = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
key.push(0xFF); key.push(0xFF);
@ -68,7 +77,11 @@ impl service::account_data::Data for KeyValueDatabase {
self.roomusertype_roomuserdataid self.roomusertype_roomuserdataid
.get(&key)? .get(&key)?
.and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose()) .and_then(|roomuserdataid| {
self.roomuserdataid_accountdata
.get(&roomuserdataid)
.transpose()
})
.transpose()? .transpose()?
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
.transpose() .transpose()
@ -81,7 +94,11 @@ impl service::account_data::Data for KeyValueDatabase {
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> { ) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
let mut userdata = HashMap::new(); let mut userdata = HashMap::new();
let mut prefix = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); let mut prefix = room_id
.map(ToString::to_string)
.unwrap_or_default()
.as_bytes()
.to_vec();
prefix.push(0xFF); prefix.push(0xFF);
prefix.extend_from_slice(user_id.as_bytes()); prefix.extend_from_slice(user_id.as_bytes());
prefix.push(0xFF); prefix.push(0xFF);

View file

@ -6,7 +6,8 @@ impl service::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
fn register_appservice(&self, yaml: Registration) -> Result<String> { fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str(); let id = yaml.id.as_str();
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; self.id_appserviceregistrations
.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
Ok(id.to_owned()) Ok(id.to_owned())
} }
@ -17,7 +18,8 @@ impl service::appservice::Data for KeyValueDatabase {
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> { fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations.remove(service_name.as_bytes())?; self.id_appserviceregistrations
.remove(service_name.as_bytes())?;
Ok(()) Ok(())
} }
@ -44,7 +46,8 @@ impl service::appservice::Data for KeyValueDatabase {
.map(move |id| { .map(move |id| {
Ok(( Ok((
id.clone(), id.clone(),
self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"), self.get_registration(&id)?
.expect("iter_ids only returns appservices that exist"),
)) ))
}) })
.collect() .collect()

View file

@ -31,14 +31,17 @@ impl service::globals::Data for KeyValueDatabase {
} }
fn last_check_for_updates_id(&self) -> Result<u64> { fn last_check_for_updates_id(&self) -> Result<u64> {
self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| { self.global
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
.map_or(Ok(0_u64), |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) .map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
}) })
} }
fn update_check_for_updates_id(&self, id: u64) -> Result<()> { fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; self.global
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -62,11 +65,19 @@ impl service::globals::Data for KeyValueDatabase {
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix)); futures.push(
self.userroomid_notificationcount
.watch_prefix(&userid_prefix),
);
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in // Events for rooms we are in
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(Result::ok) { for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services() let short_roomid = services()
.rooms .rooms
.short .short
@ -98,13 +109,19 @@ impl service::globals::Data for KeyValueDatabase {
let mut roomuser_prefix = roomid_prefix.clone(); let mut roomuser_prefix = roomid_prefix.clone();
roomuser_prefix.extend_from_slice(&userid_prefix); roomuser_prefix.extend_from_slice(&userid_prefix);
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix)); futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&roomuser_prefix),
);
} }
let mut globaluserdata_prefix = vec![0xFF]; let mut globaluserdata_prefix = vec![0xFF];
globaluserdata_prefix.extend_from_slice(&userid_prefix); globaluserdata_prefix.extend_from_slice(&userid_prefix);
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix)); futures.push(
self.roomusertype_roomuserdataid
.watch_prefix(&globaluserdata_prefix),
);
// More key changes (used when user is not joined to any rooms) // More key changes (used when user is not joined to any rooms)
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
@ -207,7 +224,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
utils::string_from_bytes( utils::string_from_bytes(
// 1. version // 1. version
parts.next().expect("splitn always returns at least one element"), parts
.next()
.expect("splitn always returns at least one element"),
) )
.map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
.and_then(|version| { .and_then(|version| {
@ -231,7 +250,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
// Not atomic, but this is not critical // Not atomic, but this is not critical
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| { let mut keys = signingkeys
.and_then(|keys| serde_json::from_slice(&keys).ok())
.unwrap_or_else(|| {
// Just insert "now", it doesn't matter // Just insert "now", it doesn't matter
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
}); });
@ -251,7 +272,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
)?; )?;
let mut tree = keys.verify_keys; let mut tree = keys.verify_keys;
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
Ok(tree) Ok(tree)
} }
@ -265,7 +290,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
.and_then(|bytes| serde_json::from_slice(&bytes).ok()) .and_then(|bytes| serde_json::from_slice(&bytes).ok())
.map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| {
let mut tree = keys.verify_keys; let mut tree = keys.verify_keys;
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); tree.extend(
keys.old_verify_keys
.into_iter()
.map(|old| (old.0, VerifyKey::new(old.1.key))),
);
tree tree
}); });

View file

@ -23,7 +23,8 @@ impl service::key_backups::Data for KeyValueDatabase {
&key, &key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version) Ok(version)
} }
@ -53,8 +54,10 @@ impl service::key_backups::Data for KeyValueDatabase {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
} }
self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_algorithm
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; .insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned()) Ok(version.to_owned())
} }
@ -69,7 +72,11 @@ impl service::key_backups::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix)) .take_while(move |(k, _)| k.starts_with(&prefix))
.next() .next()
.map(|(key, _)| { .map(|(key, _)| {
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
}) })
.transpose() .transpose()
@ -87,7 +94,9 @@ impl service::key_backups::Data for KeyValueDatabase {
.next() .next()
.map(|(key, value)| { .map(|(key, value)| {
let version = utils::string_from_bytes( let version = utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"), key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
) )
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
@ -105,7 +114,9 @@ impl service::key_backups::Data for KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| { self.backupid_algorithm
.get(&key)?
.map_or(Ok(None), |bytes| {
serde_json::from_slice(&bytes) serde_json::from_slice(&bytes)
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
}) })
@ -122,14 +133,16 @@ impl service::key_backups::Data for KeyValueDatabase {
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
} }
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(session_id.as_bytes()); key.extend_from_slice(session_id.as_bytes());
self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?; self.backupkeyid_backup
.insert(&key, key_data.json().get().as_bytes())?;
Ok(()) Ok(())
} }
@ -148,7 +161,10 @@ impl service::key_backups::Data for KeyValueDatabase {
key.extend_from_slice(version.as_bytes()); key.extend_from_slice(version.as_bytes());
Ok(utils::u64_from_bytes( Ok(utils::u64_from_bytes(
&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?, &self
.backupid_etag
.get(&key)?
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
) )
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
.to_string()) .to_string())
@ -162,17 +178,24 @@ impl service::key_backups::Data for KeyValueDatabase {
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new(); let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| { for result in self
.backupkeyid_backup
.scan_prefix(prefix)
.map(|(key, value)| {
let mut parts = key.rsplit(|&b| b == 0xFF); let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes( let session_id = utils::string_from_bytes(
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
) )
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
let room_id = RoomId::parse( let room_id = RoomId::parse(
utils::string_from_bytes( utils::string_from_bytes(
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
) )
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
) )
@ -213,7 +236,9 @@ impl service::key_backups::Data for KeyValueDatabase {
let mut parts = key.rsplit(|&b| b == 0xFF); let mut parts = key.rsplit(|&b| b == 0xFF);
let session_id = utils::string_from_bytes( let session_id = utils::string_from_bytes(
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, parts
.next()
.ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
) )
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;

View file

@ -18,9 +18,19 @@ impl service::media::Data for KeyValueDatabase {
key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
self.mediaid_file.insert(&key, &[])?; self.mediaid_file.insert(&key, &[])?;
@ -107,8 +117,9 @@ impl service::media::Data for KeyValueDatabase {
}) })
.transpose()?; .transpose()?;
let content_disposition_bytes = let content_disposition_bytes = parts
parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; .next()
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
let content_disposition = if content_disposition_bytes.is_empty() { let content_disposition = if content_disposition_bytes.is_empty() {
None None
@ -139,11 +150,26 @@ impl service::media::Data for KeyValueDatabase {
let mut value = Vec::<u8>::new(); let mut value = Vec::<u8>::new();
value.extend_from_slice(&timestamp.as_secs().to_be_bytes()); value.extend_from_slice(&timestamp.as_secs().to_be_bytes());
value.push(0xFF); value.push(0xFF);
value.extend_from_slice(data.title.as_ref().map(String::as_bytes).unwrap_or_default()); value.extend_from_slice(
data.title
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF); value.push(0xFF);
value.extend_from_slice(data.description.as_ref().map(String::as_bytes).unwrap_or_default()); value.extend_from_slice(
data.description
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF); value.push(0xFF);
value.extend_from_slice(data.image.as_ref().map(String::as_bytes).unwrap_or_default()); value.extend_from_slice(
data.image
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
);
value.push(0xFF); value.push(0xFF);
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
value.push(0xFF); value.push(0xFF);
@ -166,27 +192,45 @@ impl service::media::Data for KeyValueDatabase {
x => x, x => x,
};*/ };*/
let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { let title = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None, Some(s) if s.is_empty() => None,
x => x, x => x,
}; };
let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { let description = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None, Some(s) if s.is_empty() => None,
x => x, x => x,
}; };
let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { let image = match values
.next()
.and_then(|b| String::from_utf8(b.to_vec()).ok())
{
Some(s) if s.is_empty() => None, Some(s) if s.is_empty() => None,
x => x, x => x,
}; };
let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default())) { let image_size = match values
.next()
.map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None, Some(0) => None,
x => x, x => x,
}; };
let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) { let image_width = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None, Some(0) => None,
x => x, x => x,
}; };
let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) { let image_height = match values
.next()
.map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default()))
{
Some(0) => None, Some(0) => None,
x => x, x => x,
}; };

View file

@ -53,7 +53,9 @@ impl service::pusher::Data for KeyValueDatabase {
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
let mut parts = k.splitn(2, |&b| b == 0xFF); let mut parts = k.splitn(2, |&b| b == 0xFF);
let _senderkey = parts.next(); let _senderkey = parts.next();
let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; let push_key = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
let push_key_string = utils::string_from_bytes(push_key) let push_key_string = utils::string_from_bytes(push_key)
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;

View file

@ -4,7 +4,8 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
impl service::rooms::alias::Data for KeyValueDatabase { impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?; self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF); aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
@ -55,7 +56,10 @@ impl service::rooms::alias::Data for KeyValueDatabase {
} }
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> { fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| { Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
@ -65,6 +69,7 @@ impl service::rooms::alias::Data for KeyValueDatabase {
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart)) Ok((room_id, room_alias_localpart))
})) }),
)
} }
} }

View file

@ -12,7 +12,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
// We only save auth chains for single events in the db // We only save auth chains for single events in the db
if key.len() == 1 { if key.len() == 1 {
// Check DB cache // Check DB cache
let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| { let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain chain
.chunks_exact(size_of::<u64>()) .chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
@ -23,7 +26,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
let chain = Arc::new(chain); let chain = Arc::new(chain);
// Cache in RAM // Cache in RAM
self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain)); self.auth_chain_cache
.lock()
.unwrap()
.insert(vec![key[0]], Arc::clone(&chain));
return Ok(Some(chain)); return Ok(Some(chain));
} }
@ -37,12 +43,18 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
if key.len() == 1 { if key.len() == 1 {
self.shorteventid_authchain.insert( self.shorteventid_authchain.insert(
&key[0].to_be_bytes(), &key[0].to_be_bytes(),
&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(), &auth_chain
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?; )?;
} }
// Cache in RAM // Cache in RAM
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); self.auth_chain_cache
.lock()
.unwrap()
.insert(key, auth_chain);
Ok(()) Ok(())
} }

View file

@ -65,7 +65,8 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None), None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
}; };
self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?; self.roomuserid_presence
.insert(&key, &new_presence.to_json_bytes()?)?;
} }
let timeout = match new_state { let timeout = match new_state {
@ -73,7 +74,9 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
_ => services().globals.config.presence_offline_timeout_s, _ => services().globals.config.presence_offline_timeout_s,
}; };
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { self.presence_timer_sender
.send((user_id.to_owned(), Duration::from_secs(timeout)))
.map_err(|e| {
error!("Failed to add presence timer: {}", e); error!("Failed to add presence timer: {}", e);
Error::bad_database("Failed to add presence timer") Error::bad_database("Failed to add presence timer")
}) })
@ -104,12 +107,15 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
_ => services().globals.config.presence_offline_timeout_s, _ => services().globals.config.presence_offline_timeout_s,
}; };
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { self.presence_timer_sender
.send((user_id.to_owned(), Duration::from_secs(timeout)))
.map_err(|e| {
error!("Failed to add presence timer: {}", e); error!("Failed to add presence timer: {}", e);
Error::bad_database("Failed to add presence timer") Error::bad_database("Failed to add presence timer")
})?; })?;
self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?; self.roomuserid_presence
.insert(&key, &presence.to_json_bytes()?)?;
Ok(()) Ok(())
} }

View file

@ -18,7 +18,10 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.iter_from(&last_possible_key, true) .iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix)) .take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| { .find(|(key, _)| {
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes() key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) { }) {
// This is the old room_latest // This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?; self.readreceiptid_readreceipt.remove(&old)?;
@ -78,9 +81,11 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?; self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes()) self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
} }
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
@ -88,7 +93,9 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| { self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some( Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
)) ))

View file

@ -11,7 +11,12 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
}; };
// Look for PDUs in that room. // Look for PDUs in that room.
Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some()) Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
} }
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {

View file

@ -4,13 +4,17 @@ use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
impl service::rooms::outlier::Data for KeyValueDatabase { impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
}) })
} }
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
}) })
} }

View file

@ -32,8 +32,10 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
current.extend_from_slice(&count_raw.to_be_bytes()); current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new( Ok(Box::new(
self.tofrom_relation.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( self.tofrom_relation
move |(tofrom, _data)| { .iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..]) let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
@ -49,8 +51,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
Ok((PduCount::Normal(from), pdu)) Ok((PduCount::Normal(from), pdu))
}, }),
),
)) ))
} }
@ -75,6 +76,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
} }
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some()) self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
} }
} }

View file

@ -23,7 +23,13 @@ impl service::rooms::search::Data for KeyValueDatabase {
} }
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let words: Vec<_> = search_string let words: Vec<_> = search_string
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())

View file

@ -17,20 +17,28 @@ impl service::rooms::short::Data for KeyValueDatabase {
}, },
None => { None => {
let shorteventid = services().globals.next_count()?; let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; self.eventid_shorteventid
self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid shorteventid
}, },
}; };
self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short); self.eventidshort_cache
.lock()
.unwrap()
.insert(event_id.to_owned(), short);
Ok(short) Ok(short)
} }
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> { fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
if let Some(short) = if let Some(short) = self
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) .statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{ {
return Ok(Some(*short)); return Ok(Some(*short));
} }
@ -48,15 +56,21 @@ impl service::rooms::short::Data for KeyValueDatabase {
.transpose()?; .transpose()?;
if let Some(s) = short { if let Some(s) = short {
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s); self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), s);
} }
Ok(short) Ok(short)
} }
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> { fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
if let Some(short) = if let Some(short) = self
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) .statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{ {
return Ok(*short); return Ok(*short);
} }
@ -70,19 +84,29 @@ impl service::rooms::short::Data for KeyValueDatabase {
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => { None => {
let shortstatekey = services().globals.next_count()?; let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?; self.statekey_shortstatekey
self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; .insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey shortstatekey
}, },
}; };
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short); self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), short);
Ok(short) Ok(short)
} }
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { if let Some(id) = self
.shorteventid_cache
.lock()
.unwrap()
.get_mut(&shorteventid)
{
return Ok(Arc::clone(id)); return Ok(Arc::clone(id));
} }
@ -97,13 +121,21 @@ impl service::rooms::short::Data for KeyValueDatabase {
) )
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id)); self.shorteventid_cache
.lock()
.unwrap()
.insert(shorteventid, Arc::clone(&event_id));
Ok(event_id) Ok(event_id)
} }
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) { if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone()); return Ok(id.clone());
} }
@ -114,8 +146,9 @@ impl service::rooms::short::Data for KeyValueDatabase {
let mut parts = bytes.splitn(2, |&b| b == 0xFF); let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry"); let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = let statekey_bytes = parts
parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; .next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e); warn!("Event type in shortstatekey_statekey is invalid: {}", e);
@ -127,7 +160,10 @@ impl service::rooms::short::Data for KeyValueDatabase {
let result = (event_type, state_key); let result = (event_type, state_key);
self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone()); self.shortstatekey_cache
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result) Ok(result)
} }
@ -142,7 +178,8 @@ impl service::rooms::short::Data for KeyValueDatabase {
), ),
None => { None => {
let shortstatehash = services().globals.next_count()?; let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?; self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false) (shortstatehash, false)
}, },
}) })
@ -162,7 +199,8 @@ impl service::rooms::short::Data for KeyValueDatabase {
}, },
None => { None => {
let short = services().globals.next_count()?; let short = services().globals.next_count()?;
self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?; self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short short
}, },
}) })

View file

@ -7,7 +7,9 @@ use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase { impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| { self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?)) })?))
@ -20,12 +22,14 @@ impl service::rooms::state::Data for KeyValueDatabase {
new_shortstatehash: u64, new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }

View file

@ -19,7 +19,10 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i = 0; let mut i = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; let parsed = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1); result.insert(parsed.0, parsed.1);
i += 1; i += 1;
@ -43,7 +46,10 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
let mut result = HashMap::new(); let mut result = HashMap::new();
let mut i = 0; let mut i = 0;
for compressed in full_state.iter() { for compressed in full_state.iter() {
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; let (_, eventid) = services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
result.insert( result.insert(
( (
@ -71,7 +77,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
fn state_get_id( fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> { ) -> Result<Option<Arc<EventId>>> {
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? { let shortstatekey = match services()
.rooms
.short
.get_shortstatekey(event_type, state_key)?
{
Some(s) => s, Some(s) => s,
None => return Ok(None), None => return Ok(None),
}; };
@ -82,11 +92,17 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
.pop() .pop()
.expect("there is always one layer") .expect("there is always one layer")
.1; .1;
Ok( Ok(full_state
full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| { .iter()
services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id) .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
}), .and_then(|compressed| {
) services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
} }
/// Returns a single PDU from `room_id` with key (`event_type`, /// Returns a single PDU from `room_id` with key (`event_type`,
@ -100,12 +116,15 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
/// Returns the state hash for this pdu. /// Returns the state hash for this pdu.
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| { self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash self.shorteventid_shortstatehash
.get(&shorteventid)? .get(&shorteventid)?
.map(|bytes| { .map(|bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes).map_err(|_| {
.map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")) Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
}) })
.transpose() .transpose()
}) })

View file

@ -58,7 +58,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
&userroom_id, &userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?; )?;
self.roomuserid_invitecount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?; self.userroomid_leftstate.remove(&userroom_id)?;
@ -80,7 +81,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
&userroom_id, &userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(), &serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO )?; // TODO
self.roomuserid_leftcount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?; self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?; self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?; self.userroomid_invitestate.remove(&userroom_id)?;
@ -109,11 +111,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
invitedcount += 1; invitedcount += 1;
} }
self.roomid_joinedcount.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
self.our_real_users_cache.write().unwrap().insert(room_id.to_owned(), Arc::new(real_users)); self.our_real_users_cache
.write()
.unwrap()
.insert(room_id.to_owned(), Arc::new(real_users));
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) { if !joined_servers.remove(&old_joined_server) {
@ -145,19 +152,33 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
self.serverroomids.insert(&serverroom_id, &[])?; self.serverroomids.insert(&serverroom_id, &[])?;
} }
self.appservice_in_room_cache.write().unwrap().remove(room_id); self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, room_id))] #[tracing::instrument(skip(self, room_id))]
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> { fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
let maybe = self.our_real_users_cache.read().unwrap().get(room_id).cloned(); let maybe = self
.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.cloned();
if let Some(users) = maybe { if let Some(users) = maybe {
Ok(users) Ok(users)
} else { } else {
self.update_joined_count(room_id)?; self.update_joined_count(room_id)?;
Ok(Arc::clone(self.our_real_users_cache.read().unwrap().get(room_id).unwrap())) Ok(Arc::clone(
self.our_real_users_cache
.read()
.unwrap()
.get(room_id)
.unwrap(),
))
} }
} }
@ -221,7 +242,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse( ServerName::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) .map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
@ -246,7 +271,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
@ -261,7 +290,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse( UserId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
@ -292,13 +325,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(|(key, _)| { Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse( UserId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
})) }),
)
} }
/// Returns an iterator over all invited members of a room. /// Returns an iterator over all invited members of a room.
@ -307,13 +348,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(|(key, _)| { Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse( UserId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
})) }),
)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -322,7 +371,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(user_id.as_bytes()); key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| { self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some( Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
)) ))
@ -344,13 +395,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
/// Returns an iterator over all rooms this user joined. /// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.userroomid_joined.scan_prefix(user_id.as_bytes().to_vec()).map(|(key, _)| { Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
})) }),
)
} }
/// Returns an iterator over all rooms a user was invited to. /// Returns an iterator over all rooms a user was invited to.
@ -359,9 +418,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(|(key, state)| { Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse( let room_id = RoomId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
@ -370,7 +436,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state)) Ok((room_id, state))
})) }),
)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -413,9 +480,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(|(key, state)| { Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse( let room_id = RoomId::parse(
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
) )
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
@ -424,7 +498,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state)) Ok((room_id, state))
})) }),
)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]

View file

@ -58,6 +58,7 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
} }
} }
self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value) self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
} }
} }

View file

@ -10,14 +10,22 @@ impl service::rooms::threads::Data for KeyValueDatabase {
fn threads_until<'a>( fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> { ) -> PduEventIterResult<'a> {
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
let mut current = prefix.clone(); let mut current = prefix.clone();
current.extend_from_slice(&(until - 1).to_be_bytes()); current.extend_from_slice(&(until - 1).to_be_bytes());
Ok(Box::new( Ok(Box::new(
self.threadid_userids.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( self.threadid_userids
move |(pduid, _users)| { .iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..]) let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services() let mut pdu = services()
@ -29,13 +37,16 @@ impl service::rooms::threads::Data for KeyValueDatabase {
pdu.remove_transaction_id()?; pdu.remove_transaction_id()?;
} }
Ok((count, pdu)) Ok((count, pdu))
}, }),
),
)) ))
} }
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]); let users = participants
.iter()
.map(|user| user.as_bytes())
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?; self.threadid_userids.insert(root_id, &users)?;

View file

@ -8,9 +8,16 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven
impl service::rooms::timeline::Data for KeyValueDatabase { impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) { match self
.lasttimelinecount_cache
.lock()
.unwrap()
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => { hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| { if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events // Filter out buggy events
if r.is_err() { if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r); error!("Bad pdu in pdus_since: {:?}", r);
@ -28,7 +35,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> { fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose() self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
@ -49,7 +59,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
}) })
.transpose()? .transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
@ -64,7 +76,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
}) })
.transpose()? .transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
@ -92,7 +106,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
)? )?
.map(Arc::new) .map(Arc::new)
{ {
self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu)); self.pdu_cache
.lock()
.unwrap()
.insert(event_id.to_owned(), Arc::clone(&pdu));
Ok(Some(pdu)) Ok(Some(pdu))
} else { } else {
Ok(None) Ok(None)
@ -125,7 +142,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?; )?;
self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count)); self.lasttimelinecount_cache
.lock()
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
@ -156,7 +176,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
} }
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned()); self.pdu_cache
.lock()
.unwrap()
.remove(&(*pdu.event_id).to_owned());
Ok(()) Ok(())
} }
@ -172,8 +195,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
Ok(Box::new( Ok(Box::new(
self.pduid_pdu.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( self.pduid_pdu
move |(pdu_id, v)| { .iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id { if pdu.sender != user_id {
@ -182,8 +207,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
pdu.add_age()?; pdu.add_age()?;
let count = pdu_count(&pdu_id)?; let count = pdu_count(&pdu_id)?;
Ok((count, pdu)) Ok((count, pdu))
}, }),
),
)) ))
} }
@ -195,8 +219,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
let user_id = user_id.to_owned(); let user_id = user_id.to_owned();
Ok(Box::new( Ok(Box::new(
self.pduid_pdu.iter_from(&current, false).take_while(move |(k, _)| k.starts_with(&prefix)).map( self.pduid_pdu
move |(pdu_id, v)| { .iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v) let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?; .map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id { if pdu.sender != user_id {
@ -205,8 +231,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
pdu.add_age()?; pdu.add_age()?;
let count = pdu_count(&pdu_id)?; let count = pdu_count(&pdu_id)?;
Ok((count, pdu)) Ok((count, pdu))
}, }),
),
)) ))
} }
@ -228,8 +253,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
highlights_batch.push(userroom_id); highlights_batch.push(userroom_id);
} }
self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?; self.userroomid_notificationcount
self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?; .increment_batch(&mut notifies_batch.into_iter())?;
self.userroomid_highlightcount
.increment_batch(&mut highlights_batch.into_iter())?;
Ok(()) Ok(())
} }
} }

View file

@ -11,10 +11,13 @@ impl service::rooms::user::Data for KeyValueDatabase {
roomuser_id.push(0xFF); roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?; self.userroomid_notificationcount
self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?; .insert(&userroom_id, &0_u64.to_be_bytes())?;
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
Ok(()) Ok(())
} }
@ -24,7 +27,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_notificationcount.get(&userroom_id)?.map_or(Ok(0), |bytes| { self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
}) })
} }
@ -34,7 +39,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
userroom_id.push(0xFF); userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_highlightcount.get(&userroom_id)?.map_or(Ok(0), |bytes| { self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
}) })
} }
@ -56,16 +63,25 @@ impl service::rooms::user::Data for KeyValueDatabase {
} }
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes()); key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes()) self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
} }
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec(); let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes()); key.extend_from_slice(&token.to_be_bytes());
@ -106,7 +122,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
// We use the default compare function because keys are sorted correctly (not // We use the default compare function because keys are sorted correctly (not
// reversed) // reversed)
Ok(Box::new( Ok(Box::new(
utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| { utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse( RoomId::parse(
utils::string_from_bytes(&bytes) utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,

View file

@ -73,7 +73,8 @@ impl service::sending::Data for KeyValueDatabase {
batch.push((key.clone(), value.to_owned())); batch.push((key.clone(), value.to_owned()));
keys.push(key); keys.push(key);
} }
self.servernameevent_data.insert_batch(&mut batch.into_iter())?; self.servernameevent_data
.insert_batch(&mut batch.into_iter())?;
Ok(keys) Ok(keys)
} }
@ -107,11 +108,14 @@ impl service::sending::Data for KeyValueDatabase {
} }
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes()) self.servername_educount
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
} }
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> { fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| { self.servername_educount
.get(server_name.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
}) })
} }
@ -124,7 +128,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind,
let mut parts = key[1..].splitn(2, |&b| b == 0xFF); let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server) let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
@ -146,11 +152,15 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind,
let user_id = let user_id =
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let pushkey = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let pushkey_string = utils::string_from_bytes(pushkey) let pushkey_string = utils::string_from_bytes(pushkey)
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
( (
OutgoingKind::Push(user_id, pushkey_string), OutgoingKind::Push(user_id, pushkey_string),
@ -165,7 +175,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind,
let mut parts = key.splitn(2, |&b| b == 0xFF); let mut parts = key.splitn(2, |&b| b == 0xFF);
let server = parts.next().expect("splitn always returns one element"); let server = parts.next().expect("splitn always returns one element");
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let event = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
let server = utils::string_from_bytes(server) let server = utils::string_from_bytes(server)
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;

View file

@ -9,7 +9,10 @@ impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request( fn set_uiaa_request(
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
) -> Result<()> { ) -> Result<()> {
self.userdevicesessionid_uiaarequest.write().unwrap().insert( self.userdevicesessionid_uiaarequest
.write()
.unwrap()
.insert(
(user_id.to_owned(), device_id.to_owned(), session.to_owned()), (user_id.to_owned(), device_id.to_owned(), session.to_owned()),
request.to_owned(), request.to_owned(),
); );
@ -40,7 +43,8 @@ impl service::uiaa::Data for KeyValueDatabase {
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
)?; )?;
} else { } else {
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?; self.userdevicesessionid_uiaainfo
.remove(&userdevicesessionid)?;
} }
Ok(()) Ok(())

View file

@ -34,12 +34,16 @@ impl service::users::Data for KeyValueDatabase {
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> { fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
self.token_userdeviceid.get(token.as_bytes())?.map_or(Ok(None), |bytes| { self.token_userdeviceid
.get(token.as_bytes())?
.map_or(Ok(None), |bytes| {
let mut parts = bytes.split(|&b| b == 0xFF); let mut parts = bytes.split(|&b| b == 0xFF);
let user_bytes = let user_bytes = parts
parts.next().ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; .next()
let device_bytes = .ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
parts.next().ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; let device_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
Ok(Some(( Ok(Some((
UserId::parse( UserId::parse(
@ -79,7 +83,9 @@ impl service::users::Data for KeyValueDatabase {
/// Returns the password hash for the given user. /// Returns the password hash for the given user.
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { self.userid_password
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Password hash in db is not valid string.") Error::bad_database("Password hash in db is not valid string.")
})?)) })?))
@ -90,7 +96,8 @@ impl service::users::Data for KeyValueDatabase {
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password { if let Some(password) = password {
if let Ok(hash) = utils::calculate_password_hash(password) { if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password.insert(user_id.as_bytes(), hash.as_bytes())?; self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(()) Ok(())
} else { } else {
Err(Error::BadRequest( Err(Error::BadRequest(
@ -106,9 +113,12 @@ impl service::users::Data for KeyValueDatabase {
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { self.userid_displayname
.get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some( Ok(Some(
utils::string_from_bytes(&bytes).map_err(|_| Error::bad_database("Displayname in db is invalid."))?, utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
)) ))
}) })
} }
@ -117,7 +127,8 @@ impl service::users::Data for KeyValueDatabase {
/// need to nofify all rooms of this change. /// need to nofify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname { if let Some(displayname) = displayname {
self.userid_displayname.insert(user_id.as_bytes(), displayname.as_bytes())?; self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?;
} else { } else {
self.userid_displayname.remove(user_id.as_bytes())?; self.userid_displayname.remove(user_id.as_bytes())?;
} }
@ -143,7 +154,8 @@ impl service::users::Data for KeyValueDatabase {
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> { fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
if let Some(avatar_url) = avatar_url { if let Some(avatar_url) = avatar_url {
self.userid_avatarurl.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
} else { } else {
self.userid_avatarurl.remove(user_id.as_bytes())?; self.userid_avatarurl.remove(user_id.as_bytes())?;
} }
@ -167,7 +179,8 @@ impl service::users::Data for KeyValueDatabase {
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash { if let Some(blurhash) = blurhash {
self.userid_blurhash.insert(user_id.as_bytes(), blurhash.as_bytes())?; self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
} else { } else {
self.userid_blurhash.remove(user_id.as_bytes())?; self.userid_blurhash.remove(user_id.as_bytes())?;
} }
@ -190,7 +203,8 @@ impl service::users::Data for KeyValueDatabase {
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,
@ -230,7 +244,8 @@ impl service::users::Data for KeyValueDatabase {
// TODO: Remove onetimekeys // TODO: Remove onetimekeys
self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.remove(&userdeviceid)?; self.userdeviceid_metadata.remove(&userdeviceid)?;
@ -242,7 +257,10 @@ impl service::users::Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
// All devices have metadata // All devices have metadata
Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(|(bytes, _)| { Box::new(
self.userdeviceid_metadata
.scan_prefix(prefix)
.map(|(bytes, _)| {
Ok(utils::string_from_bytes( Ok(utils::string_from_bytes(
bytes bytes
.rsplit(|&b| b == 0xFF) .rsplit(|&b| b == 0xFF)
@ -251,7 +269,8 @@ impl service::users::Data for KeyValueDatabase {
) )
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
.into()) .into())
})) }),
)
} }
/// Replaces the access token of one device. /// Replaces the access token of one device.
@ -278,8 +297,10 @@ impl service::users::Data for KeyValueDatabase {
} }
// Assign token to user device combination // Assign token to user device combination
self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?; self.userdeviceid_token
self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?; .insert(&userdeviceid, token.as_bytes())?;
self.token_userdeviceid
.insert(token.as_bytes(), &userdeviceid)?;
Ok(()) Ok(())
} }
@ -310,7 +331,9 @@ impl service::users::Data for KeyValueDatabase {
// TODO: Use DeviceKeyId::to_string when it's available (and update everything, // TODO: Use DeviceKeyId::to_string when it's available (and update everything,
// because there are no wrapping quotation marks anymore) // because there are no wrapping quotation marks anymore)
key.extend_from_slice( key.extend_from_slice(
serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works").as_bytes(), serde_json::to_string(one_time_key_key)
.expect("DeviceKeyId::to_string always works")
.as_bytes(),
); );
self.onetimekeyid_onetimekeys.insert( self.onetimekeyid_onetimekeys.insert(
@ -318,13 +341,16 @@ impl service::users::Data for KeyValueDatabase {
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
)?; )?;
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(()) Ok(())
} }
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate.get(user_id.as_bytes())?.map_or(Ok(0), |bytes| { self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes) utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
}) })
@ -341,7 +367,8 @@ impl service::users::Data for KeyValueDatabase {
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
prefix.push(b':'); prefix.push(b':');
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
.scan_prefix(prefix) .scan_prefix(prefix)
@ -372,7 +399,10 @@ impl service::users::Data for KeyValueDatabase {
let mut counts = BTreeMap::new(); let mut counts = BTreeMap::new();
for algorithm in self.onetimekeyid_onetimekeys.scan_prefix(userdeviceid).map(|(bytes, _)| { for algorithm in self
.onetimekeyid_onetimekeys
.scan_prefix(userdeviceid)
.map(|(bytes, _)| {
Ok::<_, Error>( Ok::<_, Error>(
serde_json::from_slice::<OwnedDeviceKeyId>( serde_json::from_slice::<OwnedDeviceKeyId>(
bytes bytes
@ -415,9 +445,11 @@ impl service::users::Data for KeyValueDatabase {
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
self.keyid_key.insert(&master_key_key, master_key.json().get().as_bytes())?; self.keyid_key
.insert(&master_key_key, master_key.json().get().as_bytes())?;
self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?; self.userid_masterkeyid
.insert(user_id.as_bytes(), &master_key_key)?;
// Self-signing key // Self-signing key
if let Some(self_signing_key) = self_signing_key { if let Some(self_signing_key) = self_signing_key {
@ -441,9 +473,11 @@ impl service::users::Data for KeyValueDatabase {
let mut self_signing_key_key = prefix.clone(); let mut self_signing_key_key = prefix.clone();
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
self.keyid_key.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; self.keyid_key
.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
self.userid_selfsigningkeyid.insert(user_id.as_bytes(), &self_signing_key_key)?; self.userid_selfsigningkeyid
.insert(user_id.as_bytes(), &self_signing_key_key)?;
} }
// User-signing key // User-signing key
@ -468,9 +502,11 @@ impl service::users::Data for KeyValueDatabase {
let mut user_signing_key_key = prefix; let mut user_signing_key_key = prefix;
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
self.keyid_key.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; self.keyid_key
.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
self.userid_usersigningkeyid.insert(user_id.as_bytes(), &user_signing_key_key)?; self.userid_usersigningkeyid
.insert(user_id.as_bytes(), &user_signing_key_key)?;
} }
if notify { if notify {
@ -559,9 +595,18 @@ impl service::users::Data for KeyValueDatabase {
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(Result::ok) { for room_id in services()
.rooms
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms // Don't send key updates to unencrypted rooms
if services().rooms.state_accessor.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?.is_none() if services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
{ {
continue; continue;
} }
@ -599,11 +644,13 @@ impl service::users::Data for KeyValueDatabase {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF); prefix.push(0xFF);
let master_key = let master_key = master_key
master_key.deserialize().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; .deserialize()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
let mut master_key_ids = master_key.keys.values(); let mut master_key_ids = master_key.keys.values();
let master_key_id = let master_key_id = master_key_ids
master_key_ids.next().ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; .next()
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
if master_key_ids.next().is_some() { if master_key_ids.next().is_some() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -646,7 +693,9 @@ impl service::users::Data for KeyValueDatabase {
} }
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(Ok(None), |key| { self.userid_usersigningkeyid
.get(user_id.as_bytes())?
.map_or(Ok(None), |key| {
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
Ok(Some( Ok(Some(
serde_json::from_slice(&bytes) serde_json::from_slice(&bytes)
@ -743,7 +792,8 @@ impl service::users::Data for KeyValueDatabase {
)); ));
} }
self.userid_devicelistversion.increment(user_id.as_bytes())?; self.userid_devicelistversion
.increment(user_id.as_bytes())?;
self.userdeviceid_metadata.insert( self.userdeviceid_metadata.insert(
&userdeviceid, &userdeviceid,
@ -759,7 +809,9 @@ impl service::users::Data for KeyValueDatabase {
userdeviceid.push(0xFF); userdeviceid.push(0xFF);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
self.userdeviceid_metadata.get(&userdeviceid)?.map_or(Ok(None), |bytes| { self.userdeviceid_metadata
.get(&userdeviceid)?
.map_or(Ok(None), |bytes| {
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
Error::bad_database("Metadata in userdeviceid_metadata is invalid.") Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
})?)) })?))
@ -767,8 +819,12 @@ impl service::users::Data for KeyValueDatabase {
} }
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { self.userid_devicelistversion
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid devicelistversion in db.")).map(Some) .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
.map(Some)
}) })
} }
@ -776,10 +832,14 @@ impl service::users::Data for KeyValueDatabase {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xFF); key.push(0xFF);
Box::new(self.userdeviceid_metadata.scan_prefix(key).map(|(_, bytes)| { Box::new(
self.userdeviceid_metadata
.scan_prefix(key)
.map(|(_, bytes)| {
serde_json::from_slice::<Device>(&bytes) serde_json::from_slice::<Device>(&bytes)
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
})) }),
)
} }
/// Creates a new sync filter. Returns the filter id. /// Creates a new sync filter. Returns the filter id.
@ -790,7 +850,8 @@ impl service::users::Data for KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(filter_id.as_bytes()); key.extend_from_slice(filter_id.as_bytes());
self.userfilterid_filter.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; self.userfilterid_filter
.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
Ok(filter_id) Ok(filter_id)
} }

View file

@ -375,7 +375,10 @@ impl KeyValueDatabase {
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
pdu_cache: Mutex::new(LruCache::new( pdu_cache: Mutex::new(LruCache::new(
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"), config
.pdu_cache_capacity
.try_into()
.expect("pdu cache capacity fits into usize"),
)), )),
auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)), auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)),
shorteventid_cache: Mutex::new(LruCache::new( shorteventid_cache: Mutex::new(LruCache::new(
@ -457,8 +460,11 @@ impl KeyValueDatabase {
.argon .argon
.hash_password(b"", &salt) .hash_password(b"", &salt)
.expect("our own password to be properly hashed"); .expect("our own password to be properly hashed");
let empty_hashed_password = let empty_hashed_password = services()
services().globals.argon.verify_password(&password, &empty_pass).is_ok(); .globals
.argon
.verify_password(&password, &empty_pass)
.is_ok();
if empty_hashed_password { if empty_hashed_password {
db.userid_password.insert(&userid, b"")?; db.userid_password.insert(&userid, b"")?;
@ -526,7 +532,8 @@ impl KeyValueDatabase {
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(event_type); key.extend_from_slice(event_type);
db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; db.roomusertype_roomuserdataid
.insert(&key, &roomuserdataid)?;
} }
services().globals.bump_database_version(5)?; services().globals.bump_database_version(5)?;
@ -565,16 +572,24 @@ impl KeyValueDatabase {
let states_parents = last_roomsstatehash.map_or_else( let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
|&last_roomsstatehash| { |&last_roomsstatehash| {
services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash) services()
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
}, },
)?; )?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew = let statediffnew = current_state
current_state.difference(&parent_stateinfo.1).copied().collect::<HashSet<_>>(); .difference(&parent_stateinfo.1)
.copied()
.collect::<HashSet<_>>();
let statediffremoved = let statediffremoved = parent_stateinfo
parent_stateinfo.1.difference(&current_state).copied().collect::<HashSet<_>>(); .1
.difference(&current_state)
.copied()
.collect::<HashSet<_>>();
(statediffnew, statediffremoved) (statediffnew, statediffremoved)
} else { } else {
@ -634,7 +649,12 @@ impl KeyValueDatabase {
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap(); let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap(); let pdu = services()
.rooms
.timeline
.get_pdu(event_id)
.unwrap()
.unwrap();
if Some(&pdu.room_id) != current_room.as_ref() { if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone()); current_room = Some(pdu.room_id.clone());
@ -676,7 +696,11 @@ impl KeyValueDatabase {
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap(); let count = parts.next().unwrap();
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id; let mut new_key = short_room_id;
new_key.extend_from_slice(count); new_key.extend_from_slice(count);
@ -694,7 +718,11 @@ impl KeyValueDatabase {
let room_id = parts.next().unwrap(); let room_id = parts.next().unwrap();
let count = parts.next().unwrap(); let count = parts.next().unwrap();
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); let short_room_id = db
.roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_value = short_room_id; let mut new_value = short_room_id;
new_value.extend_from_slice(count); new_value.extend_from_slice(count);
@ -724,8 +752,11 @@ impl KeyValueDatabase {
let _pdu_id_room = parts.next().unwrap(); let _pdu_id_room = parts.next().unwrap();
let pdu_id_count = parts.next().unwrap(); let pdu_id_count = parts.next().unwrap();
let short_room_id = let short_room_id = db
db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); .roomid_shortroomid
.get(room_id)
.unwrap()
.expect("shortroomid should exist");
let mut new_key = short_room_id; let mut new_key = short_room_id;
new_key.extend_from_slice(word); new_key.extend_from_slice(word);
new_key.push(0xFF); new_key.push(0xFF);
@ -765,7 +796,8 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 10 { if services().globals.database_version()? < 10 {
// Add other direction for shortstatekeys // Add other direction for shortstatekeys
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?; db.shortstatekey_statekey
.insert(&shortstatekey, &statekey)?;
} }
// Force E2EE device list updates so we can send them over federation // Force E2EE device list updates so we can send them over federation
@ -779,7 +811,9 @@ impl KeyValueDatabase {
} }
if services().globals.database_version()? < 11 { if services().globals.database_version()? < 11 {
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; db.db
.open_tree("userdevicesessionid_uiaarequest")?
.clear()?;
services().globals.bump_database_version(11)?; services().globals.bump_database_version(11)?;
warn!("Migration: 10 -> 11 finished"); warn!("Migration: 10 -> 11 finished");
@ -813,7 +847,9 @@ impl KeyValueDatabase {
if rule.is_some() { if rule.is_some() {
let mut rule = rule.unwrap().clone(); let mut rule = rule.unwrap().clone();
content_rule_transformation[1].clone_into(&mut rule.rule_id); content_rule_transformation[1].clone_into(&mut rule.rule_id);
rules_list.content.shift_remove(content_rule_transformation[0]); rules_list
.content
.shift_remove(content_rule_transformation[0]);
rules_list.content.insert(rule); rules_list.content.insert(rule);
} }
} }
@ -874,7 +910,10 @@ impl KeyValueDatabase {
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
let user_default_rules = Ruleset::server_default(&user); let user_default_rules = Ruleset::server_default(&user);
account_data.content.global.update_with_server_default(user_default_rules); account_data
.content
.global
.update_with_server_default(user_default_rules);
services().account_data.update( services().account_data.update(
None, None,
@ -936,7 +975,10 @@ impl KeyValueDatabase {
warn!( warn!(
"User {} matches the following forbidden username patterns: {}", "User {} matches the following forbidden username patterns: {}",
user_id.to_string(), user_id.to_string(),
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
); );
} }
} }
@ -957,7 +999,10 @@ impl KeyValueDatabase {
"Room with alias {} ({}) matches the following forbidden room name patterns: {}", "Room with alias {} ({}) matches the following forbidden room name patterns: {}",
room_alias, room_alias,
&room_id, &room_id,
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") matches
.into_iter()
.map(|x| &patterns.patterns()[x])
.join(", ")
); );
} }
} }
@ -971,7 +1016,9 @@ impl KeyValueDatabase {
latest_database_version latest_database_version
); );
} else { } else {
services().globals.bump_database_version(latest_database_version)?; services()
.globals
.bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run // Create the admin room and server user on first run
services().admin.create_admin_room().await?; services().admin.create_admin_room().await?;
@ -993,9 +1040,11 @@ impl KeyValueDatabase {
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \ "The Conduit account emergency password is set! Please unset it as soon as you finish admin \
account recovery!" account recovery!"
); );
services().admin.send_message(RoomMessageEventContent::text_plain( services()
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \ .admin
account recovery!", .send_message(RoomMessageEventContent::text_plain(
"The Conduit account emergency password is set! Please unset it as soon as you finish \
admin account recovery!",
)); ));
} }
}, },
@ -1047,8 +1096,13 @@ impl KeyValueDatabase {
} }
async fn try_handle_updates() -> Result<()> { async fn try_handle_updates() -> Result<()> {
let response = let response = services()
services().globals.client.default.get("https://pupbrain.dev/check-for-updates/stable").send().await?; .globals
.client
.default
.get("https://pupbrain.dev/check-for-updates/stable")
.send()
.await?;
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| { let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
error!("Bad check for updates response: {e}"); error!("Bad check for updates response: {e}");
@ -1060,13 +1114,17 @@ impl KeyValueDatabase {
last_update_id = last_update_id.max(update.id); last_update_id = last_update_id.max(update.id);
if update.id > services().globals.last_check_for_updates_id()? { if update.id > services().globals.last_check_for_updates_id()? {
error!("{}", update.message); error!("{}", update.message);
services().admin.send_message(RoomMessageEventContent::text_plain(format!( services()
.admin
.send_message(RoomMessageEventContent::text_plain(format!(
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
update.date, update.message update.date, update.message
))); )));
} }
} }
services().globals.update_check_for_updates_id(last_update_id)?; services()
.globals
.update_check_for_updates_id(last_update_id)?;
Ok(()) Ok(())
} }
@ -1138,7 +1196,9 @@ fn set_emergency_access() -> Result<bool> {
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is a valid UserId"); .expect("@conduit:server_name is a valid UserId");
services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?; services()
.users
.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, res) = match services().globals.emergency_password() { let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),

View file

@ -20,5 +20,8 @@ pub use utils::error::{Error, Result};
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None); pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
pub fn services() -> &'static Services<'static> { pub fn services() -> &'static Services<'static> {
SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called") SERVICES
.read()
.unwrap()
.expect("SERVICES should be initialized when this is called")
} }

View file

@ -111,7 +111,9 @@ async fn main() {
}, },
}; };
let subscriber = tracing_subscriber::Registry::default().with(filter_layer).with(telemetry); let subscriber = tracing_subscriber::Registry::default()
.with(filter_layer)
.with(telemetry);
tracing::subscriber::set_global_default(subscriber).unwrap(); tracing::subscriber::set_global_default(subscriber).unwrap();
} }
} else if config.tracing_flame { } else if config.tracing_flame {
@ -292,21 +294,30 @@ async fn main() {
); );
} }
if config.url_preview_domain_contains_allowlist.contains(&"*".to_owned()) { if config
.url_preview_domain_contains_allowlist
.contains(&"*".to_owned())
{
warn!( warn!(
"All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \ "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \
This opens up significant attack surface to your server. You are expected to be aware of the risks by \ This opens up significant attack surface to your server. You are expected to be aware of the risks by \
doing this." doing this."
); );
} }
if config.url_preview_domain_explicit_allowlist.contains(&"*".to_owned()) { if config
.url_preview_domain_explicit_allowlist
.contains(&"*".to_owned())
{
warn!( warn!(
"All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \ "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \
This opens up significant attack surface to your server. You are expected to be aware of the risks by \ This opens up significant attack surface to your server. You are expected to be aware of the risks by \
doing this." doing this."
); );
} }
if config.url_preview_url_contains_allowlist.contains(&"*".to_owned()) { if config
.url_preview_url_contains_allowlist
.contains(&"*".to_owned())
{
warn!( warn!(
"All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \ "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \
opens up significant attack surface to your server. You are expected to be aware of the risks by doing \ opens up significant attack surface to your server. You are expected to be aware of the risks by doing \
@ -338,9 +349,17 @@ async fn run_server() -> io::Result<()> {
// Left is only 1 value, so make a vec with 1 value only // Left is only 1 value, so make a vec with 1 value only
let port_vec = [port]; let port_vec = [port];
port_vec.iter().copied().map(|port| SocketAddr::from((config.address, *port))).collect::<Vec<_>>() port_vec
.iter()
.copied()
.map(|port| SocketAddr::from((config.address, *port)))
.collect::<Vec<_>>()
}, },
Right(ports) => ports.iter().copied().map(|port| SocketAddr::from((config.address, port))).collect::<Vec<_>>(), Right(ports) => ports
.iter()
.copied()
.map(|port| SocketAddr::from((config.address, port)))
.collect::<Vec<_>>(),
}; };
let x_requested_with = HeaderName::from_static("x-requested-with"); let x_requested_with = HeaderName::from_static("x-requested-with");
@ -385,7 +404,10 @@ async fn run_server() -> io::Result<()> {
.max_age(Duration::from_secs(86400)), .max_age(Duration::from_secs(86400)),
) )
.layer(DefaultBodyLimit::max( .layer(DefaultBodyLimit::max(
config.max_request_size.try_into().expect("failed to convert max request size"), config
.max_request_size
.try_into()
.expect("failed to convert max request size"),
)); ));
let app; let app;
@ -395,7 +417,9 @@ async fn run_server() -> io::Result<()> {
{ {
app = if cfg!(feature = "zstd_compression") && config.zstd_compression { app = if cfg!(feature = "zstd_compression") && config.zstd_compression {
debug!("zstd body compression is enabled"); debug!("zstd body compression is enabled");
routes().layer(middlewares.compression()).into_make_service() routes()
.layer(middlewares.compression())
.into_make_service()
} else { } else {
routes().layer(middlewares).into_make_service() routes().layer(middlewares).into_make_service()
} }
@ -432,7 +456,9 @@ async fn run_server() -> io::Result<()> {
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
let listener = tokio::net::UnixListener::bind(path.clone())?; let listener = tokio::net::UnixListener::bind(path.clone())?;
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)).await.unwrap(); tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms))
.await
.unwrap();
let socket = SocketIncoming::from_listener(listener); let socket = SocketIncoming::from_listener(listener);
#[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental
@ -483,7 +509,11 @@ async fn run_server() -> io::Result<()> {
} }
} else { } else {
for addr in &addrs { for addr in &addrs {
join_set.spawn(bind_rustls(*addr, conf.clone()).handle(handle.clone()).serve(app.clone())); join_set.spawn(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
);
} }
} }
@ -528,7 +558,9 @@ async fn spawn_task<B: Send + 'static>(
if services().globals.shutdown.load(atomic::Ordering::Relaxed) { if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
return Err(StatusCode::SERVICE_UNAVAILABLE); return Err(StatusCode::SERVICE_UNAVAILABLE);
} }
tokio::spawn(next.run(req)).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) tokio::spawn(next.run(req))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
} }
async fn unrecognized_method<B: Send + 'static>( async fn unrecognized_method<B: Send + 'static>(
@ -758,7 +790,9 @@ fn routes() -> Router {
async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
let ctrl_c = async { let ctrl_c = async {
signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
}; };
#[cfg(unix)] #[cfg(unix)]

View file

@ -62,7 +62,11 @@ pub(crate) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Resu
}, },
AppserviceCommand::Unregister { AppserviceCommand::Unregister {
appservice_identifier, appservice_identifier,
} => match services().appservice.unregister_appservice(&appservice_identifier).await { } => match services()
.appservice
.unregister_appservice(&appservice_identifier)
.await
{
Ok(()) => Ok(RoomMessageEventContent::text_plain("Appservice unregistered.")), Ok(()) => Ok(RoomMessageEventContent::text_plain("Appservice unregistered.")),
Err(e) => Ok(RoomMessageEventContent::text_plain(format!( Err(e) => Ok(RoomMessageEventContent::text_plain(format!(
"Failed to unregister appservice: {e}" "Failed to unregister appservice: {e}"
@ -70,7 +74,11 @@ pub(crate) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Resu
}, },
AppserviceCommand::Show { AppserviceCommand::Show {
appservice_identifier, appservice_identifier,
} => match services().appservice.get_registration(&appservice_identifier).await { } => match services()
.appservice
.get_registration(&appservice_identifier)
.await
{
Some(config) => { Some(config) => {
let config_str = serde_yaml::to_string(&config).expect("config should've been validated on register"); let config_str = serde_yaml::to_string(&config).expect("config should've been validated on register");
let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,); let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,);

View file

@ -81,7 +81,12 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro
let room_id = <&RoomId>::try_from(room_id_str) let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let start = Instant::now(); let start = Instant::now();
let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id]).await?.count(); let count = services()
.rooms
.auth_chain
.get_auth_chain(room_id, vec![event_id])
.await?
.count();
let elapsed = start.elapsed(); let elapsed = start.elapsed();
RoomMessageEventContent::text_plain(format!("Loaded auth chain with length {count} in {elapsed:?}")) RoomMessageEventContent::text_plain(format!("Loaded auth chain with length {count} in {elapsed:?}"))
} else { } else {
@ -119,7 +124,10 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro
event_id, event_id,
} => { } => {
let mut outlier = false; let mut outlier = false;
let mut pdu_json = services().rooms.timeline.get_non_outlier_pdu_json(&event_id)?; let mut pdu_json = services()
.rooms
.timeline
.get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() { if pdu_json.is_none() {
outlier = true; outlier = true;
pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?;
@ -224,7 +232,11 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result<Ro
}); });
info!("Attempting to handle event ID {event_id} as backfilled PDU"); info!("Attempting to handle event ID {event_id} as backfilled PDU");
services().rooms.timeline.backfill_pdu(&server, response.pdu, &pub_key_map).await?; services()
.rooms
.timeline
.backfill_pdu(&server, response.pdu, &pub_key_map)
.await?;
let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json"); let json_text = serde_json::to_string_pretty(&json).expect("canonical json is valid json");

View file

@ -88,7 +88,11 @@ pub(crate) async fn process(command: FederationCommand, body: Vec<&str>) -> Resu
Ok(value) => { Ok(value) => {
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pub_key_map = pub_key_map.read().await; let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) { match ruma::signatures::verify_json(&pub_key_map, &value) {

View file

@ -202,7 +202,10 @@ pub(crate) async fn process(command: MediaCommand, body: Vec<&str>) -> Result<Ro
MediaCommand::DeletePastRemoteMedia { MediaCommand::DeletePastRemoteMedia {
duration, duration,
} => { } => {
let deleted_count = services().media.delete_all_remote_media_at_after_time(duration).await?; let deleted_count = services()
.media
.delete_all_remote_media_at_after_time(duration)
.await?;
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
"Deleted {} total files.", "Deleted {} total files.",

View file

@ -169,11 +169,15 @@ impl Service {
} }
pub fn process_message(&self, room_message: String, event_id: Arc<EventId>) { pub fn process_message(&self, room_message: String, event_id: Arc<EventId>) {
self.sender.send(AdminRoomEvent::ProcessMessage(room_message, event_id)).unwrap(); self.sender
.send(AdminRoomEvent::ProcessMessage(room_message, event_id))
.unwrap();
} }
pub fn send_message(&self, message_content: RoomMessageEventContent) { pub fn send_message(&self, message_content: RoomMessageEventContent) {
self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap(); self.sender
.send(AdminRoomEvent::SendMessage(message_content))
.unwrap();
} }
// Parse and process a message from the admin room // Parse and process a message from the admin room
@ -276,7 +280,10 @@ impl Service {
if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") {
text_lines.remove(line_index); text_lines.remove(line_index);
while text_lines.get(line_index).is_some_and(|line| line.starts_with('#')) { while text_lines
.get(line_index)
.is_some_and(|line| line.starts_with('#'))
{
command_body += if text_lines[line_index].starts_with("# ") { command_body += if text_lines[line_index].starts_with("# ") {
&text_lines[line_index][2..] &text_lines[line_index][2..]
} else { } else {
@ -304,7 +311,9 @@ impl Service {
// Add HTML line-breaks // Add HTML line-breaks
text.replace("\n\n\n", "\n\n").replace('\n', "<br>\n").replace("[nobr]<br>", "") text.replace("\n\n\n", "\n\n")
.replace('\n', "<br>\n")
.replace("[nobr]<br>", "")
} }
/// Create the admin room. /// Create the admin room.
@ -316,8 +325,15 @@ impl Service {
services().rooms.short.get_or_create_shortroomid(&room_id)?; services().rooms.short.get_or_create_shortroomid(&room_id)?;
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Create a user for the server // Create a user for the server
@ -560,7 +576,10 @@ impl Service {
.try_into() .try_into()
.expect("#admins:server_name is a valid alias name"); .expect("#admins:server_name is a valid alias name");
services().rooms.alias.resolve_local_alias(&admin_room_alias) services()
.rooms
.alias
.resolve_local_alias(&admin_room_alias)
} }
/// Invite the user to the conduit admin room. /// Invite the user to the conduit admin room.
@ -568,8 +587,15 @@ impl Service {
/// In conduit, this is equivalent to granting admin privileges. /// In conduit, this is equivalent to granting admin privileges.
pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> {
if let Some(room_id) = Self::get_admin_room()? { if let Some(room_id) = Self::get_admin_room()? {
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Use the server user to grant the new admin's power level // Use the server user to grant the new admin's power level
@ -681,13 +707,29 @@ impl Service {
} }
} }
fn escape_html(s: &str) -> String { s.replace('&', "&amp;").replace('<', "&lt;").replace('>', "&gt;") } fn escape_html(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
}
fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) { fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) {
( (
id.clone(), id.clone(),
services().rooms.state_cache.room_joined_count(id).ok().flatten().unwrap_or(0), services()
services().rooms.state_accessor.get_name(id).ok().flatten().unwrap_or_else(|| id.to_string()), .rooms
.state_cache
.room_joined_count(id)
.ok()
.flatten()
.unwrap_or(0),
services()
.rooms
.state_accessor
.get_name(id)
.ok()
.flatten()
.unwrap_or_else(|| id.to_string()),
) )
} }
@ -705,7 +747,9 @@ mod test {
fn get_help_subcommand() { get_help_inner("help"); } fn get_help_subcommand() { get_help_inner("help"); }
fn get_help_inner(input: &str) { fn get_help_inner(input: &str) {
let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]).unwrap_err().to_string(); let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input])
.unwrap_err()
.to_string();
// Search for a handful of keywords that suggest the help printed properly // Search for a handful of keywords that suggest the help printed properly
assert!(error.contains("Usage:")); assert!(error.contains("Usage:"));

View file

@ -55,7 +55,11 @@ pub(crate) async fn process(command: RoomCommand, body: Vec<&str>) -> Result<Roo
rooms.sort_by_key(|r| r.1); rooms.sort_by_key(|r| r.1);
rooms.reverse(); rooms.reverse();
let rooms = rooms.into_iter().skip(page.saturating_sub(1) * PAGE_SIZE).take(PAGE_SIZE).collect::<Vec<_>>(); let rooms = rooms
.into_iter()
.skip(page.saturating_sub(1) * PAGE_SIZE)
.take(PAGE_SIZE)
.collect::<Vec<_>>();
if rooms.is_empty() { if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No more rooms.")); return Ok(RoomMessageEventContent::text_plain("No more rooms."));
@ -72,7 +76,9 @@ pub(crate) async fn process(command: RoomCommand, body: Vec<&str>) -> Result<Roo
let output_html = format!( let output_html = format!(
"<table><caption>Room list - page \ "<table><caption>Room list - page \
{page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>", {page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms.iter().fold(String::new(), |mut output, (id, members, name)| { rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!( writeln!(
output, output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>", "<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",

View file

@ -107,7 +107,11 @@ pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu
room_id, room_id,
} => { } => {
if let Some(room_id) = room_id { if let Some(room_id) = room_id {
let aliases = services().rooms.alias.local_aliases_for_room(&room_id).collect::<Result<Vec<_>, _>>(); let aliases = services()
.rooms
.alias
.local_aliases_for_room(&room_id)
.collect::<Result<Vec<_>, _>>();
match aliases { match aliases {
Ok(aliases) => { Ok(aliases) => {
let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { let plain_list = aliases.iter().fold(String::new(), |mut output, alias| {
@ -127,16 +131,24 @@ pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu
Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {}", err))), Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {}", err))),
} }
} else { } else {
let aliases = services().rooms.alias.all_local_aliases().collect::<Result<Vec<_>, _>>(); let aliases = services()
.rooms
.alias
.all_local_aliases()
.collect::<Result<Vec<_>, _>>();
match aliases { match aliases {
Ok(aliases) => { Ok(aliases) => {
let server_name = services().globals.server_name(); let server_name = services().globals.server_name();
let plain_list = aliases.iter().fold(String::new(), |mut output, (alias, id)| { let plain_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!(output, "- `{alias}` -> #{id}:{server_name}").unwrap(); writeln!(output, "- `{alias}` -> #{id}:{server_name}").unwrap();
output output
}); });
let html_list = aliases.iter().fold(String::new(), |mut output, (alias, id)| { let html_list = aliases
.iter()
.fold(String::new(), |mut output, (alias, id)| {
writeln!( writeln!(
output, output,
"<li><code>{}</code> -> #{}:{}</li>", "<li><code>{}</code> -> #{}:{}</li>",

View file

@ -58,7 +58,11 @@ pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) ->
rooms.sort_by_key(|r| r.1); rooms.sort_by_key(|r| r.1);
rooms.reverse(); rooms.reverse();
let rooms = rooms.into_iter().skip(page.saturating_sub(1) * PAGE_SIZE).take(PAGE_SIZE).collect::<Vec<_>>(); let rooms = rooms
.into_iter()
.skip(page.saturating_sub(1) * PAGE_SIZE)
.take(PAGE_SIZE)
.collect::<Vec<_>>();
if rooms.is_empty() { if rooms.is_empty() {
return Ok(RoomMessageEventContent::text_plain("No more rooms.")); return Ok(RoomMessageEventContent::text_plain("No more rooms."));
@ -75,7 +79,9 @@ pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) ->
let output_html = format!( let output_html = format!(
"<table><caption>Room directory - page \ "<table><caption>Room directory - page \
{page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>", {page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms.iter().fold(String::new(), |mut output, (id, members, name)| { rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!( writeln!(
output, output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>", "<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",

View file

@ -446,7 +446,11 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) ->
)) ))
}, },
RoomModerationCommand::ListBannedRooms => { RoomModerationCommand::ListBannedRooms => {
let rooms = services().rooms.metadata.list_banned_rooms().collect::<Result<Vec<_>, _>>(); let rooms = services()
.rooms
.metadata
.list_banned_rooms()
.collect::<Result<Vec<_>, _>>();
match rooms { match rooms {
Ok(room_ids) => { Ok(room_ids) => {

View file

@ -115,13 +115,18 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result<Roo
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix())); displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
} }
services().users.set_displayname(&user_id, Some(displayname)).await?; services()
.users
.set_displayname(&user_id, Some(displayname))
.await?;
// Initial account data // Initial account data
services().account_data.update( services().account_data.update(
None, None,
&user_id, &user_id,
ruma::events::GlobalAccountDataEventType::PushRules.to_string().into(), ruma::events::GlobalAccountDataEventType::PushRules
.to_string()
.into(),
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent { &serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
content: ruma::events::push_rules::PushRulesEventContent { content: ruma::events::push_rules::PushRulesEventContent {
global: ruma::push::Ruleset::server_default(&user_id), global: ruma::push::Ruleset::server_default(&user_id),
@ -132,7 +137,11 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result<Roo
if !services().globals.config.auto_join_rooms.is_empty() { if !services().globals.config.auto_join_rooms.is_empty() {
for room in &services().globals.config.auto_join_rooms { for room in &services().globals.config.auto_join_rooms {
if !services().rooms.state_cache.server_in_room(services().globals.server_name(), room)? { if !services()
.rooms
.state_cache
.server_in_room(services().globals.server_name(), room)?
{
warn!("Skipping room {room} to automatically join as we have never joined before."); warn!("Skipping room {room} to automatically join as we have never joined before.");
continue; continue;
} }
@ -230,7 +239,10 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result<Roo
let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
match services().users.set_password(&user_id, Some(new_password.as_str())) { match services()
.users
.set_password(&user_id, Some(new_password.as_str()))
{
Ok(()) => Ok(RoomMessageEventContent::text_plain(format!( Ok(()) => Ok(RoomMessageEventContent::text_plain(format!(
"Successfully reset the password for user {user_id}: `{new_password}`" "Successfully reset the password for user {user_id}: `{new_password}`"
))), ))),
@ -343,7 +355,9 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result<Roo
let output_html = format!( let output_html = format!(
"<table><caption>Rooms {user_id} \ "<table><caption>Rooms {user_id} \
Joined</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>", Joined</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>",
rooms.iter().fold(String::new(), |mut output, (id, members, name)| { rooms
.iter()
.fold(String::new(), |mut output, (id, members, name)| {
writeln!( writeln!(
output, output,
"<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>", "<tr><td>{}</td>\t<td>{}</td>\t<td>{}</td></tr>",

View file

@ -107,7 +107,10 @@ impl Service {
for appservice in db.all()? { for appservice in db.all()? {
registration_info.insert( registration_info.insert(
appservice.0, appservice.0,
appservice.1.try_into().expect("Should be validated on registration"), appservice
.1
.try_into()
.expect("Should be validated on registration"),
); );
} }
@ -119,7 +122,12 @@ impl Service {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?); services()
.appservice
.registration_info
.write()
.await
.insert(yaml.id.clone(), yaml.clone().try_into()?);
self.db.register_appservice(yaml) self.db.register_appservice(yaml)
} }
@ -130,19 +138,40 @@ impl Service {
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> {
services().appservice.registration_info.write().await.remove(service_name); services()
.appservice
.registration_info
.write()
.await
.remove(service_name);
self.db.unregister_appservice(service_name) self.db.unregister_appservice(service_name)
} }
pub async fn get_registration(&self, id: &str) -> Option<Registration> { pub async fn get_registration(&self, id: &str) -> Option<Registration> {
self.registration_info.read().await.get(id).cloned().map(|info| info.registration) self.registration_info
.read()
.await
.get(id)
.cloned()
.map(|info| info.registration)
} }
pub async fn iter_ids(&self) -> Vec<String> { self.registration_info.read().await.keys().cloned().collect() } pub async fn iter_ids(&self) -> Vec<String> {
self.registration_info
.read()
.await
.keys()
.cloned()
.collect()
}
pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> { pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
self.read().await.values().find(|info| info.registration.as_token == token).cloned() self.read()
.await
.values()
.find(|info| info.registration.as_token == token)
.cloned()
} }
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> { pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {

View file

@ -106,8 +106,10 @@ impl Service<'_> {
}, },
}; };
let jwt_decoding_key = let jwt_decoding_key = config
config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); .jwt_secret
.as_ref()
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let resolver = Arc::new(resolver::Resolver::new(config)); let resolver = Arc::new(resolver::Resolver::new(config));
@ -159,7 +161,10 @@ impl Service<'_> {
fs::create_dir_all(s.get_media_folder())?; fs::create_dir_all(s.get_media_folder())?;
if !s.supported_room_versions().contains(&s.config.default_room_version) { if !s
.supported_room_versions()
.contains(&s.config.default_room_version)
{
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = crate::config::default_default_room_version(); s.config.default_room_version = crate::config::default_default_room_version();
}; };

View file

@ -65,9 +65,7 @@ impl Resolver {
} }
impl Resolve for Resolver { impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving { fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
resolve_to_reqwest(self.resolver.clone(), name)
}
} }
impl Resolve for Hooked { impl Resolve for Hooked {
@ -78,7 +76,7 @@ impl Resolve for Hooked {
.get(name.as_str()) .get(name.as_str())
.map_or_else( .map_or_else(
|| resolve_to_reqwest(self.resolver.clone(), name), || resolve_to_reqwest(self.resolver.clone(), name),
|(override_name, port)| cached_to_reqwest(override_name, *port) |(override_name, port)| cached_to_reqwest(override_name, *port),
) )
} }
} }

View file

@ -44,7 +44,8 @@ impl Service {
pub fn add_key( pub fn add_key(
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>, &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
) -> Result<()> { ) -> Result<()> {
self.db.add_key(user_id, version, room_id, session_id, key_data) self.db
.add_key(user_id, version, room_id, session_id, key_data)
} }
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) } pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
@ -76,6 +77,7 @@ impl Service {
} }
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
self.db.delete_room_key(user_id, version, room_id, session_id) self.db
.delete_room_key(user_id, version, room_id, session_id)
} }
} }

View file

@ -50,9 +50,11 @@ impl Service {
) -> Result<()> { ) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail // Width, Height = 0 if it's not a thumbnail
let key = if let Some(user) = sender_user { let key = if let Some(user) = sender_user {
self.db.create_file_metadata(Some(user.as_str()), mxc, 0, 0, content_disposition, content_type)? self.db
.create_file_metadata(Some(user.as_str()), mxc, 0, 0, content_disposition, content_type)?
} else { } else {
self.db.create_file_metadata(None, mxc, 0, 0, content_disposition, content_type)? self.db
.create_file_metadata(None, mxc, 0, 0, content_disposition, content_type)?
}; };
let path; let path;
@ -118,9 +120,11 @@ impl Service {
content_type: Option<&str>, width: u32, height: u32, file: &[u8], content_type: Option<&str>, width: u32, height: u32, file: &[u8],
) -> Result<()> { ) -> Result<()> {
let key = if let Some(user) = sender_user { let key = if let Some(user) = sender_user {
self.db.create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? self.db
.create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)?
} else { } else {
self.db.create_file_metadata(None, mxc, width, height, content_disposition, content_type)? self.db
.create_file_metadata(None, mxc, width, height, content_disposition, content_type)?
}; };
let path; let path;
@ -161,7 +165,9 @@ impl Service {
}; };
let mut file = Vec::new(); let mut file = Vec::new();
BufReader::new(File::open(path).await?).read_to_end(&mut file).await?; BufReader::new(File::open(path).await?)
.read_to_end(&mut file)
.await?;
Ok(Some(FileMeta { Ok(Some(FileMeta {
content_disposition, content_disposition,
@ -308,7 +314,9 @@ impl Service {
/// For width,height <= 96 the server uses another thumbnailing algorithm /// For width,height <= 96 the server uses another thumbnailing algorithm
/// which crops the image afterwards. /// which crops the image afterwards.
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> { pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file let (width, height, crop) = self
.thumbnail_properties(width, height)
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) { if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) {
// Using saved thumbnail // Using saved thumbnail
@ -466,7 +474,9 @@ impl Service {
} }
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time"); let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.expect("valid system time");
self.db.set_url_preview(url, data, now) self.db.set_url_preview(url, data, now)
} }
} }
@ -492,9 +502,19 @@ mod tests {
key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); key.extend_from_slice(
content_disposition
.as_ref()
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF); key.push(0xFF);
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); key.extend_from_slice(
content_type
.as_ref()
.map(|c| c.as_bytes())
.unwrap_or_default(),
);
Ok(key) Ok(key)
} }

View file

@ -168,11 +168,41 @@ impl Services<'_> {
async fn memory_usage(&self) -> String { async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len(); let server_visibility_cache = self
let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len(); .rooms
let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); .state_accessor
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len(); .server_visibility_cache
let roomid_spacehierarchy_cache = self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.len(); .lock()
.unwrap()
.len();
let user_visibility_cache = self
.rooms
.state_accessor
.user_visibility_cache
.lock()
.unwrap()
.len();
let stateinfo_cache = self
.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.len();
let lasttimelinecount_cache = self
.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.len();
let roomid_spacehierarchy_cache = self
.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.len();
format!( format!(
"\ "\
@ -187,22 +217,52 @@ roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}"
async fn clear_caches(&self, amount: u32) { async fn clear_caches(&self, amount: u32) {
if amount > 0 { if amount > 0 {
self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear(); self.rooms
.lazy_loading
.lazy_load_waiting
.lock()
.await
.clear();
} }
if amount > 1 { if amount > 1 {
self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear(); self.rooms
.state_accessor
.server_visibility_cache
.lock()
.unwrap()
.clear();
} }
if amount > 2 { if amount > 2 {
self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear(); self.rooms
.state_accessor
.user_visibility_cache
.lock()
.unwrap()
.clear();
} }
if amount > 3 { if amount > 3 {
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear(); self.rooms
.state_compressor
.stateinfo_cache
.lock()
.unwrap()
.clear();
} }
if amount > 4 { if amount > 4 {
self.rooms.timeline.lasttimelinecount_cache.lock().await.clear(); self.rooms
.timeline
.lasttimelinecount_cache
.lock()
.await
.clear();
} }
if amount > 5 { if amount > 5 {
self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.clear(); self.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.clear();
} }
if amount > 6 { if amount > 6 {
self.globals.resolver.overrides.write().unwrap().clear(); self.globals.resolver.overrides.write().unwrap().clear();

View file

@ -277,11 +277,17 @@ impl PduEvent {
/// This does not return a full `Pdu` it is only to satisfy ruma's types. /// This does not return a full `Pdu` it is only to satisfy ruma's types.
#[tracing::instrument] #[tracing::instrument]
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> { pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) { if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
{
unsigned.remove("transaction_id"); unsigned.remove("transaction_id");
} }
if let Some(room_id) = pdu_json.get("room_id").and_then(|val| RoomId::parse(val.as_str()?).ok()) { if let Some(room_id) = pdu_json
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
if let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) { if let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) {
// room v3 and above removed the "event_id" field from remote PDU format // room v3 and above removed the "event_id" field from remote PDU format
match room_version_id { match room_version_id {

View file

@ -83,7 +83,12 @@ impl Service {
} }
} }
let response = services().globals.client.pusher.execute(reqwest_request).await; let response = services()
.globals
.client
.pusher
.execute(reqwest_request)
.await;
match response { match response {
Ok(mut response) => { Ok(mut response) => {
@ -108,10 +113,14 @@ impl Service {
} }
let status = response.status(); let status = response.status();
let mut http_response_builder = http::Response::builder().status(status).version(response.version()); let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap( mem::swap(
response.headers_mut(), response.headers_mut(),
http_response_builder.headers_mut().expect("http::response::Builder is usable"), http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
); );
let body = response.bytes().await.unwrap_or_else(|e| { let body = response.bytes().await.unwrap_or_else(|e| {
@ -130,7 +139,9 @@ impl Service {
} }
let response = T::IncomingResponse::try_from_http_response( let response = T::IncomingResponse::try_from_http_response(
http_response_builder.body(body).expect("reqwest body is valid http body"), http_response_builder
.body(body)
.expect("reqwest body is valid http body"),
); );
response.map_err(|_| { response.map_err(|_| {
info!("Push gateway returned invalid response bytes {}\n{}", destination, url); info!("Push gateway returned invalid response bytes {}\n{}", destination, url);
@ -202,9 +213,18 @@ impl Service {
let ctx = PushConditionRoomCtx { let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32), member_count: UInt::from(
services()
.rooms
.state_cache
.room_joined_count(room_id)?
.unwrap_or(1) as u32,
),
user_id: user.to_owned(), user_id: user.to_owned(),
user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()), user_display_name: services()
.users
.displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()),
power_levels: Some(power_levels), power_levels: Some(power_levels),
}; };
@ -242,13 +262,16 @@ impl Service {
notifi.counts = NotificationCounts::new(unread, uint!(0)); notifi.counts = NotificationCounts::new(unread, uint!(0));
if event.kind == TimelineEventType::RoomEncrypted if event.kind == TimelineEventType::RoomEncrypted
|| tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) || tweaks
.iter()
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
{ {
notifi.prio = NotificationPriority::High; notifi.prio = NotificationPriority::High;
} }
if event_id_only { if event_id_only {
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
} else { } else {
notifi.sender = Some(event.sender.clone()); notifi.sender = Some(event.sender.clone());
notifi.event_type = Some(event.kind.clone()); notifi.event_type = Some(event.kind.clone());
@ -262,7 +285,8 @@ impl Service {
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;
} }
Ok(()) Ok(())

View file

@ -53,7 +53,11 @@ impl Service {
} }
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { if let Some(cached) = services()
.rooms
.auth_chain
.get_cached_eventid_authchain(&chunk_key)?
{
hits += 1; hits += 1;
full_auth_chain.extend(cached.iter().copied()); full_auth_chain.extend(cached.iter().copied());
continue; continue;
@ -65,13 +69,20 @@ impl Service {
let mut misses2 = 0; let mut misses2 = 0;
let mut i = 0; let mut i = 0;
for (sevent_id, event_id) in chunk { for (sevent_id, event_id) in chunk {
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? { if let Some(cached) = services()
.rooms
.auth_chain
.get_cached_eventid_authchain(&[sevent_id])?
{
hits2 += 1; hits2 += 1;
chunk_cache.extend(cached.iter().copied()); chunk_cache.extend(cached.iter().copied());
} else { } else {
misses2 += 1; misses2 += 1;
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; services()
.rooms
.auth_chain
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
debug!( debug!(
event_id = ?event_id, event_id = ?event_id,
chain_length = ?auth_chain.len(), chain_length = ?auth_chain.len(),
@ -92,7 +103,10 @@ impl Service {
"Chunk missed", "Chunk missed",
); );
let chunk_cache = Arc::new(chunk_cache); let chunk_cache = Arc::new(chunk_cache);
services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; services()
.rooms
.auth_chain
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
full_auth_chain.extend(chunk_cache.iter()); full_auth_chain.extend(chunk_cache.iter());
} }
@ -103,7 +117,9 @@ impl Service {
"Auth chain stats", "Auth chain stats",
); );
Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) Ok(full_auth_chain
.into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
} }
#[tracing::instrument(skip(self, event_id))] #[tracing::instrument(skip(self, event_id))]
@ -118,7 +134,10 @@ impl Service {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
} }
for auth_event in &pdu.auth_events { for auth_event in &pdu.auth_events {
let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?; let sauthevent = services()
.rooms
.short
.get_or_create_shorteventid(auth_event)?;
if !found.contains(&sauthevent) { if !found.contains(&sauthevent) {
found.insert(sauthevent); found.insert(sauthevent);

View file

@ -91,7 +91,8 @@ impl Service {
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>, &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
last_active_ago: Option<UInt>, status_msg: Option<String>, last_active_ago: Option<UInt>, status_msg: Option<String>,
) -> Result<()> { ) -> Result<()> {
self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg) self.db
.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg)
} }
/// Removes the presence record for the given user from the database. /// Removes the presence record for the given user from the database.
@ -142,7 +143,11 @@ fn process_presence_timer(user_id: &OwnedUserId) -> Result<()> {
let mut status_msg = None; let mut status_msg = None;
for room_id in services().rooms.state_cache.rooms_joined(user_id) { for room_id in services().rooms.state_cache.rooms_joined(user_id) {
let presence_event = services().rooms.edus.presence.get_presence(&room_id?, user_id)?; let presence_event = services()
.rooms
.edus
.presence
.get_presence(&room_id?, user_id)?;
if let Some(presence_event) = presence_event { if let Some(presence_event) = presence_event {
presence_state = presence_event.content.presence; presence_state = presence_event.content.presence;

View file

@ -16,16 +16,32 @@ impl Service {
/// Sets a user as typing until the timeout timestamp is reached or /// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called. /// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
self.typing.write().await.entry(room_id.to_owned()).or_default().insert(user_id.to_owned(), timeout); self.typing
self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); .write()
.await
.entry(room_id.to_owned())
.or_default()
.insert(user_id.to_owned(), timeout);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
_ = self.typing_update_sender.send(room_id.to_owned()); _ = self.typing_update_sender.send(room_id.to_owned());
Ok(()) Ok(())
} }
/// Removes a user from typing before the timeout is reached. /// Removes a user from typing before the timeout is reached.
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.typing.write().await.entry(room_id.to_owned()).or_default().remove(user_id); self.typing
self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); .write()
.await
.entry(room_id.to_owned())
.or_default()
.remove(user_id);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
_ = self.typing_update_sender.send(room_id.to_owned()); _ = self.typing_update_sender.send(room_id.to_owned());
Ok(()) Ok(())
} }
@ -67,7 +83,10 @@ impl Service {
for user in removable { for user in removable {
room.remove(&user); room.remove(&user);
} }
self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
_ = self.typing_update_sender.send(room_id.to_owned()); _ = self.typing_update_sender.send(room_id.to_owned());
} }
@ -77,7 +96,13 @@ impl Service {
/// Returns the count of the last typing update in this room. /// Returns the count of the last typing update in this room.
pub async fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { pub async fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
self.typings_maintain(room_id).await?; self.typings_maintain(room_id).await?;
Ok(self.last_typing_update.read().await.get(room_id).copied().unwrap_or(0)) Ok(self
.last_typing_update
.read()
.await
.get(room_id)
.copied()
.unwrap_or(0))
} }
/// Returns a new typing EDU. /// Returns a new typing EDU.

View file

@ -125,8 +125,9 @@ impl Service {
.first_pdu_in_room(room_id)? .first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
let (incoming_pdu, val) = let (incoming_pdu, val) = self
self.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map).await?; .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map)
.await?;
self.check_room_id(room_id, &incoming_pdu)?; self.check_room_id(room_id, &incoming_pdu)?;
// 8. if not timeline event: stop // 8. if not timeline event: stop
@ -167,7 +168,13 @@ impl Service {
)); ));
} }
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) { if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*prev_id)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -182,7 +189,13 @@ impl Service {
if errors >= 5 { if errors >= 5 {
// Timeout other events // Timeout other events
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { match services()
.globals
.bad_event_ratelimiter
.write()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -207,12 +220,19 @@ impl Service {
.await .await
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) = if let Err(e) = self
self.upgrade_outlier_to_timeline_pdu(pdu, json, &create_event, origin, room_id, pub_key_map).await .upgrade_outlier_to_timeline_pdu(pdu, json, &create_event, origin, room_id, pub_key_map)
.await
{ {
errors += 1; errors += 1;
warn!("Prev event {} failed: {}", prev_id, e); warn!("Prev event {} failed: {}", prev_id, e);
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { match services()
.globals
.bad_event_ratelimiter
.write()
.await
.entry((*prev_id).to_owned())
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -222,7 +242,12 @@ impl Service {
} }
} }
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); services()
.globals
.roomid_federationhandletime
.write()
.await
.remove(&room_id.to_owned());
debug!( debug!(
"Handling prev event {} took {}m{}s", "Handling prev event {} took {}m{}s",
prev_id, prev_id,
@ -246,7 +271,12 @@ impl Service {
.event_handler .event_handler
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
.await; .await;
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); services()
.globals
.roomid_federationhandletime
.write()
.await
.remove(&room_id.to_owned());
r r
} }
@ -323,7 +353,11 @@ impl Service {
debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events");
self.fetch_and_handle_outliers( self.fetch_and_handle_outliers(
origin, origin,
&incoming_pdu.auth_events.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(), &incoming_pdu
.auth_events
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>(),
create_event, create_event,
room_id, room_id,
room_version_id, room_version_id,
@ -351,7 +385,10 @@ impl Service {
match auth_events.entry(( match auth_events.entry((
auth_event.kind.to_string().into(), auth_event.kind.to_string().into(),
auth_event.state_key.clone().expect("all auth events have state keys"), auth_event
.state_key
.clone()
.expect("all auth events have state keys"),
)) { )) {
hash_map::Entry::Vacant(v) => { hash_map::Entry::Vacant(v) => {
v.insert(auth_event); v.insert(auth_event);
@ -367,7 +404,9 @@ impl Service {
// The original create event must be in the auth events // The original create event must be in the auth events
if !matches!( if !matches!(
auth_events.get(&(StateEventType::RoomCreate, String::new())).map(AsRef::as_ref), auth_events
.get(&(StateEventType::RoomCreate, String::new()))
.map(AsRef::as_ref),
Some(_) | None Some(_) | None
) { ) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -390,7 +429,10 @@ impl Service {
debug!("Validation successful."); debug!("Validation successful.");
// 7. Persist the event as an outlier. // 7. Persist the event as an outlier.
services().rooms.outlier.add_pdu_outlier(&incoming_pdu.event_id, &val)?; services()
.rooms
.outlier
.add_pdu_outlier(&incoming_pdu.event_id, &val)?;
debug!("Added pdu as outlier."); debug!("Added pdu as outlier.");
@ -407,7 +449,11 @@ impl Service {
return Ok(Some(pduid)); return Ok(Some(pduid));
} }
if services().rooms.pdu_metadata.is_event_soft_failed(&incoming_pdu.event_id)? { if services()
.rooms
.pdu_metadata
.is_event_soft_failed(&incoming_pdu.event_id)?
{
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed"));
} }
@ -434,10 +480,19 @@ impl Service {
if incoming_pdu.prev_events.len() == 1 { if incoming_pdu.prev_events.len() == 1 {
let prev_event = &*incoming_pdu.prev_events[0]; let prev_event = &*incoming_pdu.prev_events[0];
let prev_event_sstatehash = services().rooms.state_accessor.pdu_shortstatehash(prev_event)?; let prev_event_sstatehash = services()
.rooms
.state_accessor
.pdu_shortstatehash(prev_event)?;
let state = if let Some(shortstatehash) = prev_event_sstatehash { let state = if let Some(shortstatehash) = prev_event_sstatehash {
Some(services().rooms.state_accessor.state_full_ids(shortstatehash).await) Some(
services()
.rooms
.state_accessor
.state_full_ids(shortstatehash)
.await,
)
} else { } else {
None None
}; };
@ -477,7 +532,11 @@ impl Service {
break; break;
}; };
let sstatehash = if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) { let sstatehash = if let Ok(Some(s)) = services()
.rooms
.state_accessor
.pdu_shortstatehash(prev_eventid)
{
s s
} else { } else {
okay = false; okay = false;
@ -492,8 +551,11 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes { for (sstatehash, prev_event) in extremity_sstatehashes {
let mut leaf_state: HashMap<_, _> = let mut leaf_state: HashMap<_, _> = services()
services().rooms.state_accessor.state_full_ids(sstatehash).await?; .rooms
.state_accessor
.state_full_ids(sstatehash)
.await?;
if let Some(state_key) = &prev_event.state_key { if let Some(state_key) = &prev_event.state_key {
let shortstatekey = services() let shortstatekey = services()
@ -518,8 +580,14 @@ impl Service {
starting_events.push(id); starting_events.push(id);
} }
auth_chain_sets auth_chain_sets.push(
.push(services().rooms.auth_chain.get_auth_chain(room_id, starting_events).await?.collect()); services()
.rooms
.auth_chain
.get_auth_chain(room_id, starting_events)
.await?
.collect(),
);
fork_states.push(state); fork_states.push(state);
} }
@ -579,7 +647,11 @@ impl Service {
Ok(res) => { Ok(res) => {
debug!("Fetching state events at event."); debug!("Fetching state events at event.");
let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(); let collect = res
.pdu_ids
.iter()
.map(|x| Arc::from(&**x))
.collect::<Vec<_>>();
let state_vec = self let state_vec = self
.fetch_and_handle_outliers( .fetch_and_handle_outliers(
@ -679,8 +751,15 @@ impl Service {
// 13. Use state resolution to find new room state // 13. Use state resolution to find new room state
// We start looking at current room state now, so lets lock the room // We start looking at current room state now, so lets lock the room
let mutex_state = let mutex_state = Arc::clone(
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Now we calculate the set of extremities this room has after the incoming // Now we calculate the set of extremities this room has after the incoming
@ -698,13 +777,26 @@ impl Service {
} }
// Only keep those extremities were not referenced yet // Only keep those extremities were not referenced yet
extremities.retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); extremities.retain(|id| {
!matches!(
services()
.rooms
.pdu_metadata
.is_event_referenced(room_id, id),
Ok(true)
)
});
debug!("Compressing state at event"); debug!("Compressing state at event");
let state_ids_compressed = Arc::new( let state_ids_compressed = Arc::new(
state_at_incoming_event state_at_incoming_event
.iter() .iter()
.map(|(shortstatekey, id)| services().rooms.state_compressor.compress_state_event(*shortstatekey, id)) .map(|(shortstatekey, id)| {
services()
.rooms
.state_compressor
.compress_state_event(*shortstatekey, id)
})
.collect::<Result<_>>()?, .collect::<Result<_>>()?,
); );
@ -722,14 +814,23 @@ impl Service {
state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
} }
let new_room_state = self.resolve_state(room_id, room_version_id, state_after).await?; let new_room_state = self
.resolve_state(room_id, room_version_id, state_after)
.await?;
// Set the new room state to the resolved state // Set the new room state to the resolved state
debug!("Forcing new room state"); debug!("Forcing new room state");
let (sstatehash, new, removed) = services().rooms.state_compressor.save_state(room_id, new_room_state)?; let (sstatehash, new, removed) = services()
.rooms
.state_compressor
.save_state(room_id, new_room_state)?;
services().rooms.state.force_state(room_id, sstatehash, new, removed, &state_lock).await?; services()
.rooms
.state
.force_state(room_id, sstatehash, new, removed, &state_lock)
.await?;
} }
// 14. Check if the event passes auth based on the "current state" of the room, // 14. Check if the event passes auth based on the "current state" of the room,
@ -752,7 +853,10 @@ impl Service {
// Soft fail, we keep the event as an outlier but don't add it to the timeline // Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu); warn!("Event was soft failed: {:?}", incoming_pdu);
services().rooms.pdu_metadata.mark_event_soft_failed(&incoming_pdu.event_id)?; services()
.rooms
.pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id)?;
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed"));
} }
@ -787,10 +891,17 @@ impl Service {
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>, &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> { ) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids"); debug!("Loading current room state ids");
let current_sstatehash = let current_sstatehash = services()
services().rooms.state.get_room_shortstatehash(room_id)?.expect("every room has state"); .rooms
.state
.get_room_shortstatehash(room_id)?
.expect("every room has state");
let current_state_ids = services().rooms.state_accessor.state_full_ids(current_sstatehash).await?; let current_state_ids = services()
.rooms
.state_accessor
.state_full_ids(current_sstatehash)
.await?;
let fork_states = [current_state_ids, incoming_state]; let fork_states = [current_state_ids, incoming_state];
@ -852,9 +963,14 @@ impl Service {
let new_room_state = state let new_room_state = state
.into_iter() .into_iter()
.map(|((event_type, state_key), event_id)| { .map(|((event_type, state_key), event_id)| {
let shortstatekey = let shortstatekey = services()
services().rooms.short.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; .rooms
services().rooms.state_compressor.compress_state_event(shortstatekey, &event_id) .short
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
services()
.rooms
.state_compressor
.compress_state_event(shortstatekey, &event_id)
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
@ -877,7 +993,13 @@ impl Service {
) -> AsyncRecursiveCanonicalJsonVec<'a> { ) -> AsyncRecursiveCanonicalJsonVec<'a> {
Box::pin(async move { Box::pin(async move {
let back_off = |id| async { let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) { match services()
.globals
.bad_event_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -904,7 +1026,13 @@ impl Service {
let mut events_all = HashSet::new(); let mut events_all = HashSet::new();
let mut i = 0; let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() { while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) { if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&*next_id)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1010,7 +1138,13 @@ impl Service {
pdus.push((local_pdu, None)); pdus.push((local_pdu, None));
} }
for (next_id, value) in events_in_reverse_order.iter().rev() { for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) { if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(&**next_id)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1087,9 +1221,14 @@ impl Service {
continue; continue;
} }
if let Some(json) = if let Some(json) = json_opt.or_else(|| {
json_opt.or_else(|| services().rooms.outlier.get_outlier_pdu_json(&prev_event_id).ok().flatten()) services()
{ .rooms
.outlier
.get_outlier_pdu_json(&prev_event_id)
.ok()
.flatten()
}) {
if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts {
amount += 1; amount += 1;
for prev_prev in &pdu.prev_events { for prev_prev in &pdu.prev_events {
@ -1122,7 +1261,9 @@ impl Service {
Ok(( Ok((
int!(0), int!(0),
MilliSecondsSinceUnixEpoch( MilliSecondsSinceUnixEpoch(
eventid_info.get(event_id).map_or_else(|| uint!(0), |info| info.0.origin_server_ts), eventid_info
.get(event_id)
.map_or_else(|| uint!(0), |info| info.0.origin_server_ts),
), ),
)) ))
}) })
@ -1172,7 +1313,11 @@ impl Service {
info!( info!(
"Fetch keys for {}", "Fetch keys for {}",
server_key_ids.keys().cloned().collect::<Vec<_>>().join(", ") server_key_ids
.keys()
.cloned()
.collect::<Vec<_>>()
.join(", ")
); );
let mut server_keys: FuturesUnordered<_> = server_key_ids let mut server_keys: FuturesUnordered<_> = server_key_ids
@ -1204,7 +1349,10 @@ impl Service {
while let Some(fetch_res) = server_keys.next().await { while let Some(fetch_res) = server_keys.next().await {
match fetch_res { match fetch_res {
Ok((signature_server, keys)) => { Ok((signature_server, keys)) => {
pub_key_map.write().await.insert(signature_server.clone(), keys); pub_key_map
.write()
.await
.insert(signature_server.clone(), keys);
}, },
Err((signature_server, e)) => { Err((signature_server, e)) => {
warn!("Failed to fetch keys for {}: {:?}", signature_server, e); warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
@ -1235,7 +1383,13 @@ impl Service {
); );
let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) { if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.await
.get(event_id)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1275,8 +1429,12 @@ impl Service {
debug!("Loading signing keys for {}", origin); debug!("Loading signing keys for {}", origin);
let result: BTreeMap<_, _> = let result: BTreeMap<_, _> = services()
services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); .globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if !contains_all_ids(&result) { if !contains_all_ids(&result) {
debug!("Signing key not loaded for {}", origin); debug!("Signing key not loaded for {}", origin);
@ -1353,7 +1511,10 @@ impl Service {
.into_keys() .into_keys()
.map(|server| async move { .map(|server| async move {
( (
services().sending.send_federation_request(&server, get_server_keys::v2::Request::new()).await, services()
.sending
.send_federation_request(&server, get_server_keys::v2::Request::new())
.await,
server, server,
) )
}) })
@ -1391,10 +1552,14 @@ impl Service {
// Try to fetch keys, failure is okay // Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers` // Servers we couldn't find in the cache will be added to `servers`
for pdu in &event.room_state.state { for pdu in &event.room_state.state {
_ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
} }
for pdu in &event.room_state.auth_chain { for pdu in &event.room_state.auth_chain {
_ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; _ = self
.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm)
.await;
} }
drop(pkm); drop(pkm);
@ -1411,7 +1576,8 @@ impl Service {
homeserver signing keys." homeserver signing keys."
); );
self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; self.batch_request_signing_keys(servers.clone(), pub_key_map)
.await?;
if servers.is_empty() { if servers.is_empty() {
info!("Trusted server supplied all signing keys, no more keys to fetch"); info!("Trusted server supplied all signing keys, no more keys to fetch");
@ -1420,11 +1586,13 @@ impl Service {
info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}");
self.request_signing_keys(servers.clone(), pub_key_map).await?; self.request_signing_keys(servers.clone(), pub_key_map)
.await?;
} else { } else {
info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); info!("query_trusted_key_servers_first is set to false, querying individual homeservers first");
self.request_signing_keys(servers.clone(), pub_key_map).await?; self.request_signing_keys(servers.clone(), pub_key_map)
.await?;
if servers.is_empty() { if servers.is_empty() {
info!("Individual homeservers supplied all signing keys, no more keys to fetch"); info!("Individual homeservers supplied all signing keys, no more keys to fetch");
@ -1433,7 +1601,8 @@ impl Service {
info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); info!("Remaining servers left the individual homeservers did not provide: {servers:?}");
self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; self.batch_request_signing_keys(servers.clone(), pub_key_map)
.await?;
} }
info!("Search for signing keys done"); info!("Search for signing keys done");
@ -1448,7 +1617,11 @@ impl Service {
/// Returns Ok if the acl allows the server /// Returns Ok if the acl allows the server
pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
let acl_event = let acl_event =
match services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomServerAcl, "")? { match services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomServerAcl, "")?
{
Some(acl) => { Some(acl) => {
debug!("ACL event found: {acl:?}"); debug!("ACL event found: {acl:?}");
acl acl
@ -1493,21 +1666,36 @@ impl Service {
) -> Result<BTreeMap<String, Base64>> { ) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = let permit = services()
services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned()); .globals
.servername_ratelimiter
.read()
.await
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = if let Some(p) = permit { let permit = if let Some(p) = permit {
p p
} else { } else {
let mut write = services().globals.servername_ratelimiter.write().await; let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1)))); let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned() s.acquire_owned()
} }
.await; .await;
let back_off = |id| async { let back_off = |id| async {
match services().globals.bad_signature_ratelimiter.write().await.entry(id) { match services()
.globals
.bad_signature_ratelimiter
.write()
.await
.entry(id)
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
}, },
@ -1515,7 +1703,13 @@ impl Service {
} }
}; };
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) { if let Some((time, tries)) = services()
.globals
.bad_signature_ratelimiter
.read()
.await
.get(&signature_ids)
{
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1530,8 +1724,12 @@ impl Service {
debug!("Loading signing keys for {origin} from our database if available"); debug!("Loading signing keys for {origin} from our database if available");
let mut result: BTreeMap<_, _> = let mut result: BTreeMap<_, _> = services()
services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); .globals
.signing_keys_for(origin)?
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
if contains_all_ids(&result) { if contains_all_ids(&result) {
debug!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); debug!("We have all homeserver signing keys locally for {origin}, not fetching any remotely");
@ -1554,20 +1752,34 @@ impl Service {
get_remote_server_keys::v2::Request::new( get_remote_server_keys::v2::Request::new(
origin.to_owned(), origin.to_owned(),
MilliSecondsSinceUnixEpoch::from_system_time( MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now().checked_add(Duration::from_secs(3600)).expect("SystemTime too large"), SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime too large"),
) )
.expect("time is valid"), .expect("time is valid"),
), ),
) )
.await .await
.ok() .ok()
.map(|resp| resp.server_keys.into_iter().filter_map(|e| e.deserialize().ok()).collect::<Vec<_>>()) .map(|resp| {
{ resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
}) {
debug!("Got signing keys: {:?}", server_keys); debug!("Got signing keys: {:?}", server_keys);
for k in server_keys { for k in server_keys {
services().globals.add_signing_key(origin, k.clone())?; services().globals.add_signing_key(origin, k.clone())?;
result.extend(k.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); result.extend(
result.extend(k.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
} }
if contains_all_ids(&result) { if contains_all_ids(&result) {
@ -1584,10 +1796,22 @@ impl Service {
.ok() .ok()
.and_then(|resp| resp.server_key.deserialize().ok()) .and_then(|resp| resp.server_key.deserialize().ok())
{ {
services().globals.add_signing_key(origin, server_key.clone())?; services()
.globals
.add_signing_key(origin, server_key.clone())?;
result.extend(server_key.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); result.extend(
result.extend(server_key.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) { if contains_all_ids(&result) {
return Ok(result); return Ok(result);
@ -1604,10 +1828,22 @@ impl Service {
.ok() .ok()
.and_then(|resp| resp.server_key.deserialize().ok()) .and_then(|resp| resp.server_key.deserialize().ok())
{ {
services().globals.add_signing_key(origin, server_key.clone())?; services()
.globals
.add_signing_key(origin, server_key.clone())?;
result.extend(server_key.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); result.extend(
result.extend(server_key.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); server_key
.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
server_key
.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
if contains_all_ids(&result) { if contains_all_ids(&result) {
return Ok(result); return Ok(result);
@ -1623,20 +1859,34 @@ impl Service {
get_remote_server_keys::v2::Request::new( get_remote_server_keys::v2::Request::new(
origin.to_owned(), origin.to_owned(),
MilliSecondsSinceUnixEpoch::from_system_time( MilliSecondsSinceUnixEpoch::from_system_time(
SystemTime::now().checked_add(Duration::from_secs(3600)).expect("SystemTime too large"), SystemTime::now()
.checked_add(Duration::from_secs(3600))
.expect("SystemTime too large"),
) )
.expect("time is valid"), .expect("time is valid"),
), ),
) )
.await .await
.ok() .ok()
.map(|resp| resp.server_keys.into_iter().filter_map(|e| e.deserialize().ok()).collect::<Vec<_>>()) .map(|resp| {
{ resp.server_keys
.into_iter()
.filter_map(|e| e.deserialize().ok())
.collect::<Vec<_>>()
}) {
debug!("Got signing keys: {:?}", server_keys); debug!("Got signing keys: {:?}", server_keys);
for k in server_keys { for k in server_keys {
services().globals.add_signing_key(origin, k.clone())?; services().globals.add_signing_key(origin, k.clone())?;
result.extend(k.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); result.extend(
result.extend(k.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); k.verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
result.extend(
k.old_verify_keys
.into_iter()
.map(|(k, v)| (k.to_string(), v.key)),
);
} }
if contains_all_ids(&result) { if contains_all_ids(&result) {

View file

@ -20,7 +20,8 @@ impl Service {
pub fn lazy_load_was_sent_before( pub fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> { ) -> Result<bool> {
self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) self.db
.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -44,7 +45,8 @@ impl Service {
room_id.to_owned(), room_id.to_owned(),
since, since,
)) { )) {
self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; self.db
.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?;
} else { } else {
// Ignore // Ignore
} }

View file

@ -126,8 +126,10 @@ impl Service {
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
let events_before: Vec<_> = let events_before: Vec<_> = events_before
events_before.into_iter().map(|(_, pdu)| pdu.to_message_like_event()).collect(); .into_iter()
.map(|(_, pdu)| pdu.to_message_like_event())
.collect();
Ok(get_relating_events::v1::Response { Ok(get_relating_events::v1::Response {
chunk: events_before, chunk: events_before,

View file

@ -148,7 +148,13 @@ impl Arena {
)]; )];
while let Some(parent) = self.parent(parents.last().expect("Has at least one value, as above").0) { while let Some(parent) = self.parent(parents.last().expect("Has at least one value, as above").0) {
parents.push((parent, self.get(parent).expect("It is some, as above").room_id.clone())); parents.push((
parent,
self.get(parent)
.expect("It is some, as above")
.room_id
.clone(),
));
} }
// If at max_depth, don't add new rooms // If at max_depth, don't add new rooms
@ -178,7 +184,10 @@ impl Arena {
} }
if self.first_untraversed.is_none() if self.first_untraversed.is_none()
|| parent >= self.first_untraversed.expect("Should have already continued if none") || parent
>= self
.first_untraversed
.expect("Should have already continued if none")
{ {
self.first_untraversed = next_id; self.first_untraversed = next_id;
} }
@ -187,7 +196,9 @@ impl Arena {
// This is done as if we use an if-let above, we cannot reference self.nodes // This is done as if we use an if-let above, we cannot reference self.nodes
// above as then we would have multiple mutable references // above as then we would have multiple mutable references
let node = self.get_mut(parent).expect("Must be some, as inside this block"); let node = self
.get_mut(parent)
.expect("Must be some, as inside this block");
node.first_child = next_id; node.first_child = next_id;
} }
@ -324,7 +335,10 @@ impl Service {
pub async fn get_federation_hierarchy( pub async fn get_federation_hierarchy(
&self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool,
) -> Result<federation::space::get_hierarchy::v1::Response> { ) -> Result<federation::space::get_hierarchy::v1::Response> {
match self.get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None).await? { match self
.get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None)
.await?
{
Some(SummaryAccessibility::Accessible(room)) => { Some(SummaryAccessibility::Accessible(room)) => {
let mut children = Vec::new(); let mut children = Vec::new();
let mut inaccessible_children = Vec::new(); let mut inaccessible_children = Vec::new();
@ -360,7 +374,13 @@ impl Service {
async fn get_summary_and_children( async fn get_summary_and_children(
&self, current_room: &OwnedRoomId, suggested_only: bool, identifier: Identifier<'_>, &self, current_room: &OwnedRoomId, suggested_only: bool, identifier: Identifier<'_>,
) -> Result<Option<SummaryAccessibility>> { ) -> Result<Option<SummaryAccessibility>> {
if let Some(cached) = self.roomid_spacehierarchy_cache.lock().await.get_mut(&current_room.to_owned()).as_ref() { if let Some(cached) = self
.roomid_spacehierarchy_cache
.lock()
.await
.get_mut(&current_room.to_owned())
.as_ref()
{
return Ok(if let Some(cached) = cached { return Ok(if let Some(cached) = cached {
if is_accessable_child( if is_accessable_child(
current_room, current_room,
@ -395,7 +415,9 @@ impl Service {
// Federation requests should not request information from other // Federation requests should not request information from other
// servers // servers
} else if let Identifier::UserId(_) = identifier { } else if let Identifier::UserId(_) = identifier {
let server = current_room.server_name().expect("Room IDs should always have a server name"); let server = current_room
.server_name()
.expect("Room IDs should always have a server name");
if server == services().globals.server_name() { if server == services().globals.server_name() {
return Ok(None); return Ok(None);
} }
@ -473,7 +495,10 @@ impl Service {
Some(SummaryAccessibility::Inaccessible) Some(SummaryAccessibility::Inaccessible)
} }
} else { } else {
self.roomid_spacehierarchy_cache.lock().await.insert(current_room.clone(), None); self.roomid_spacehierarchy_cache
.lock()
.await
.insert(current_room.clone(), None);
None None
} }
@ -494,7 +519,9 @@ impl Service {
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| { .map(|s| {
serde_json::from_str(s.content.get()).map(|c: RoomJoinRulesEventContent| c.join_rule).map_err(|e| { serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| c.join_rule)
.map_err(|e| {
error!("Invalid room join rule event in database: {}", e); error!("Invalid room join rule event in database: {}", e);
Error::BadDatabase("Invalid room join rule event in database.") Error::BadDatabase("Invalid room join rule event in database.")
}) })
@ -534,15 +561,18 @@ impl Service {
.try_into() .try_into()
.expect("user count should not be that big"), .expect("user count should not be that big"),
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
topic: services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomTopic, "")?.map_or( topic: services()
Ok(None), .rooms
|s| { .state_accessor
serde_json::from_str(s.content.get()).map(|c: RoomTopicEventContent| Some(c.topic)).map_err(|_| { .room_state_get(room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|_| {
error!("Invalid room topic event in database for room {}", room_id); error!("Invalid room topic event in database for room {}", room_id);
Error::bad_database("Invalid room topic event in database.") Error::bad_database("Invalid room topic event in database.")
}) })
}, })?,
)?,
world_readable: world_readable(room_id)?, world_readable: world_readable(room_id)?,
guest_can_join: guest_can_join(room_id)?, guest_can_join: guest_can_join(room_id)?,
avatar_url: services() avatar_url: services()
@ -588,7 +618,9 @@ impl Service {
let mut arena = Arena::new(summary.room_id.clone(), max_depth); let mut arena = Arena::new(summary.room_id.clone(), max_depth);
let mut results = Vec::new(); let mut results = Vec::new();
let root = arena.first_untraversed().expect("The node just added is not traversed"); let root = arena
.first_untraversed()
.expect("The node just added is not traversed");
arena.push(root, get_parent_children(&summary.clone(), suggested_only)); arena.push(root, get_parent_children(&summary.clone(), suggested_only));
results.push(summary_to_chunk(*summary.clone())); results.push(summary_to_chunk(*summary.clone()));
@ -597,7 +629,10 @@ impl Service {
if limit > results.len() { if limit > results.len() {
if let Some(SummaryAccessibility::Accessible(summary)) = self if let Some(SummaryAccessibility::Accessible(summary)) = self
.get_summary_and_children( .get_summary_and_children(
&arena.get(current_room).expect("We added this node, it must exist").room_id, &arena
.get(current_room)
.expect("We added this node, it must exist")
.room_id,
suggested_only, suggested_only,
Identifier::UserId(sender_user), Identifier::UserId(sender_user),
) )
@ -651,7 +686,11 @@ async fn get_stripped_space_child_events(
room_id: &RoomId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> { ) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
let state = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; let state = services()
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut children_pdus = Vec::new(); let mut children_pdus = Vec::new();
for (key, id) in state { for (key, id) in state {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?;
@ -702,13 +741,24 @@ fn is_accessable_child_recurse(
let room_id: &RoomId = current_room; let room_id: &RoomId = current_room;
// Checks if ACLs allow for the server to participate // Checks if ACLs allow for the server to participate
if services().rooms.event_handler.acl_check(server_name, room_id).is_err() { if services()
.rooms
.event_handler
.acl_check(server_name, room_id)
.is_err()
{
return Ok(false); return Ok(false);
} }
}, },
Identifier::UserId(user_id) => { Identifier::UserId(user_id) => {
if services().rooms.state_cache.is_joined(user_id, current_room)? if services()
|| services().rooms.state_cache.is_invited(user_id, current_room)? .rooms
.state_cache
.is_joined(user_id, current_room)?
|| services()
.rooms
.state_cache
.is_invited(user_id, current_room)?
{ {
return Ok(true); return Ok(true);
} }
@ -746,26 +796,28 @@ fn is_accessable_child_recurse(
/// Checks if guests are able to join a given room /// Checks if guests are able to join a given room
fn guest_can_join(room_id: &RoomId) -> Result<bool, Error> { fn guest_can_join(room_id: &RoomId) -> Result<bool, Error> {
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?.map_or( services()
Ok(false), .rooms
|s| { .state_accessor
.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.map_err(|_| Error::bad_database("Invalid room guest access event in database.")) .map_err(|_| Error::bad_database("Invalid room guest access event in database."))
}, })
)
} }
/// Checks if guests are able to view room content without joining /// Checks if guests are able to view room content without joining
fn world_readable(room_id: &RoomId) -> Result<bool, Error> { fn world_readable(room_id: &RoomId) -> Result<bool, Error> {
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( services()
Ok(false), .rooms
|s| { .state_accessor
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
.map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) .map_err(|_| Error::bad_database("Invalid room history visibility event in database."))
}, })
)
} }
/// Returns the join rule for a given room /// Returns the join rule for a given room

View file

@ -36,7 +36,12 @@ impl Service {
state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| { for event_id in statediffnew.iter().filter_map(|new| {
services().rooms.state_compressor.parse_compressed_state_event(new).ok().map(|(_, id)| id) services()
.rooms
.state_compressor
.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) { }) {
let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? { let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
@ -74,7 +79,13 @@ impl Service {
.await?; .await?;
}, },
TimelineEventType::SpaceChild => { TimelineEventType::SpaceChild => {
services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); services()
.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.remove(&pdu.room_id);
}, },
_ => continue, _ => continue,
} }
@ -82,7 +93,8 @@ impl Service {
services().rooms.state_cache.update_joined_count(room_id)?; services().rooms.state_cache.update_joined_count(room_id)?;
self.db.set_room_state(room_id, shortstatehash, state_lock)?; self.db
.set_room_state(room_id, shortstatehash, state_lock)?;
Ok(()) Ok(())
} }
@ -95,25 +107,47 @@ impl Service {
pub fn set_event_state( pub fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<u64> { ) -> Result<u64> {
let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?; let shorteventid = services()
.rooms
.short
.get_or_create_shorteventid(event_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
let state_hash = calculate_hash(&state_ids_compressed.iter().map(|s| &s[..]).collect::<Vec<_>>()); let state_hash = calculate_hash(
&state_ids_compressed
.iter()
.map(|s| &s[..])
.collect::<Vec<_>>(),
);
let (shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; let (shortstatehash, already_existed) = services()
.rooms
.short
.get_or_create_shortstatehash(&state_hash)?;
if !already_existed { if !already_existed {
let states_parents = previous_shortstatehash.map_or_else( let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
|p| services().rooms.state_compressor.load_shortstatehash_info(p), |p| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
)?; )?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); let statediffnew: HashSet<_> = state_ids_compressed
.difference(&parent_stateinfo.1)
.copied()
.collect();
let statediffremoved: HashSet<_> = let statediffremoved: HashSet<_> = parent_stateinfo
parent_stateinfo.1.difference(&state_ids_compressed).copied().collect(); .1
.difference(&state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved)) (Arc::new(statediffnew), Arc::new(statediffremoved))
} else { } else {
@ -139,7 +173,10 @@ impl Service {
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu))] #[tracing::instrument(skip(self, new_pdu))]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> { pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?; let shorteventid = services()
.rooms
.short
.get_or_create_shorteventid(&new_pdu.event_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
@ -150,17 +187,31 @@ impl Service {
if let Some(state_key) = &new_pdu.state_key { if let Some(state_key) = &new_pdu.state_key {
let states_parents = previous_shortstatehash.map_or_else( let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()), || Ok(Vec::new()),
|p| services().rooms.state_compressor.load_shortstatehash_info(p), |p| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
)?; )?;
let shortstatekey = let shortstatekey = services()
services().rooms.short.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; .rooms
.short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?;
let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?; let new = services()
.rooms
.state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?;
let replaces = states_parents let replaces = states_parents
.last() .last()
.map(|info| info.1.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))) .map(|info| {
info.1
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
})
.unwrap_or_default(); .unwrap_or_default();
if Some(&new) == replaces { if Some(&new) == replaces {
@ -197,12 +248,18 @@ impl Service {
let mut state = Vec::new(); let mut state = Vec::new();
// Add recommended events // Add recommended events
if let Some(e) = if let Some(e) =
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? services()
.rooms
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? services()
.rooms
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
@ -214,12 +271,18 @@ impl Service {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? services()
.rooms
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
if let Some(e) = if let Some(e) =
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? services()
.rooms
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
{ {
state.push(e.to_stripped_state_event()); state.push(e.to_stripped_state_event());
} }
@ -249,7 +312,10 @@ impl Service {
/// Returns the room's version. /// Returns the room's version.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event = services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: RoomCreateEventContent = create_event let create_event_content: RoomCreateEventContent = create_event
.as_ref() .as_ref()
@ -279,7 +345,8 @@ impl Service {
event_ids: Vec<OwnedEventId>, event_ids: Vec<OwnedEventId>,
state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
self.db.set_forward_extremities(room_id, event_ids, state_lock) self.db
.set_forward_extremities(room_id, event_ids, state_lock)
} }
/// This fetches auth events from the current state. /// This fetches auth events from the current state.
@ -321,9 +388,23 @@ impl Service {
Ok(full_state Ok(full_state
.iter() .iter()
.filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok()) .filter_map(|compressed| {
services()
.rooms
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
})
.filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id)))
.filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) .filter_map(|(k, event_id)| {
services()
.rooms
.timeline
.get_pdu(&event_id)
.ok()
.flatten()
.map(|pdu| (k, pdu))
})
.collect()) .collect())
} }
} }

View file

@ -59,19 +59,18 @@ impl Service {
/// Get membership for given user in state /// Get membership for given user in state
fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> { fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> {
self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?.map_or( self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?
Ok(MembershipState::Leave), .map_or(Ok(MembershipState::Leave), |s| {
|s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomMemberEventContent| c.membership) .map(|c: RoomMemberEventContent| c.membership)
.map_err(|_| Error::bad_database("Invalid room membership event in database.")) .map_err(|_| Error::bad_database("Invalid room membership event in database."))
}, })
)
} }
/// The user was a joined member at this state (potentially in the past) /// The user was a joined member at this state (potentially in the past)
fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).is_ok_and(|s| s == MembershipState::Join) self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join)
// Return sensible default, i.e. // Return sensible default, i.e.
// false // false
} }
@ -92,20 +91,22 @@ impl Service {
return Ok(true); return Ok(true);
}; };
if let Some(visibility) = if let Some(visibility) = self
self.server_visibility_cache.lock().unwrap().get_mut(&(origin.to_owned(), shortstatehash)) .server_visibility_cache
.lock()
.unwrap()
.get_mut(&(origin.to_owned(), shortstatehash))
{ {
return Ok(*visibility); return Ok(*visibility);
} }
let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( let history_visibility = self
Ok(HistoryVisibility::Shared), .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
|s| { .map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|_| Error::bad_database("Invalid history visibility event in database.")) .map_err(|_| Error::bad_database("Invalid history visibility event in database."))
}, })?;
)?;
let mut current_server_members = services() let mut current_server_members = services()
.rooms .rooms
@ -130,7 +131,10 @@ impl Service {
}, },
}; };
self.server_visibility_cache.lock().unwrap().insert((origin.to_owned(), shortstatehash), visibility); self.server_visibility_cache
.lock()
.unwrap()
.insert((origin.to_owned(), shortstatehash), visibility);
Ok(visibility) Ok(visibility)
} }
@ -144,22 +148,24 @@ impl Service {
None => return Ok(true), None => return Ok(true),
}; };
if let Some(visibility) = if let Some(visibility) = self
self.user_visibility_cache.lock().unwrap().get_mut(&(user_id.to_owned(), shortstatehash)) .user_visibility_cache
.lock()
.unwrap()
.get_mut(&(user_id.to_owned(), shortstatehash))
{ {
return Ok(*visibility); return Ok(*visibility);
} }
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( let history_visibility = self
Ok(HistoryVisibility::Shared), .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
|s| { .map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|_| Error::bad_database("Invalid history visibility event in database.")) .map_err(|_| Error::bad_database("Invalid history visibility event in database."))
}, })?;
)?;
let visibility = match history_visibility { let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true, HistoryVisibility::WorldReadable => true,
@ -178,7 +184,10 @@ impl Service {
}, },
}; };
self.user_visibility_cache.lock().unwrap().insert((user_id.to_owned(), shortstatehash), visibility); self.user_visibility_cache
.lock()
.unwrap()
.insert((user_id.to_owned(), shortstatehash), visibility);
Ok(visibility) Ok(visibility)
} }
@ -189,17 +198,16 @@ impl Service {
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( let history_visibility = self
Ok(HistoryVisibility::Shared), .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
|s| { .map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|e| { .map_err(|e| {
error!("Invalid history visibility event in database for room {}: {e}", &room_id); error!("Invalid history visibility event in database for room {}: {e}", &room_id);
Error::bad_database("Invalid history visibility event in database.") Error::bad_database("Invalid history visibility event in database.")
}) })
}, })?;
)?;
Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable)
} }
@ -232,28 +240,34 @@ impl Service {
} }
pub fn get_name(&self, room_id: &RoomId) -> Result<Option<String>> { pub fn get_name(&self, room_id: &RoomId) -> Result<Option<String>> {
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomName, "")?.map_or(Ok(None), |s| { services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomName, "")?
.map_or(Ok(None), |s| {
Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name)))
}) })
} }
pub fn get_avatar(&self, room_id: &RoomId) -> Result<ruma::JsOption<RoomAvatarEventContent>> { pub fn get_avatar(&self, room_id: &RoomId) -> Result<ruma::JsOption<RoomAvatarEventContent>> {
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomAvatar, "")?.map_or( services()
Ok(ruma::JsOption::Undefined), .rooms
|s| { .state_accessor
.room_state_get(room_id, &StateEventType::RoomAvatar, "")?
.map_or(Ok(ruma::JsOption::Undefined), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room avatar event in database.")) .map_err(|_| Error::bad_database("Invalid room avatar event in database."))
}, })
)
} }
pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<RoomMemberEventContent>> { pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<RoomMemberEventContent>> {
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?.map_or( services()
Ok(None), .rooms
|s| { .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get()) serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room member event in database.")) .map_err(|_| Error::bad_database("Invalid room member event in database."))
}, })
)
} }
} }

View file

@ -165,7 +165,9 @@ impl Service {
.get( .get(
None, // Ignored users are in global account data None, // Ignored users are in global account data
user_id, // Receiver user_id, // Receiver
GlobalAccountDataEventType::IgnoredUserList.to_string().into(), GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)? )?
.map(|event| { .map(|event| {
serde_json::from_str::<IgnoredUserListEvent>(event.get()).map_err(|e| { serde_json::from_str::<IgnoredUserListEvent>(event.get()).map_err(|e| {
@ -175,7 +177,11 @@ impl Service {
}) })
.transpose()? .transpose()?
.map_or(false, |ignored| { .map_or(false, |ignored| {
ignored.content.ignored_users.iter().any(|(user, _details)| user == sender) ignored
.content
.ignored_users
.iter()
.any(|(user, _details)| user == sender)
}); });
if is_ignored { if is_ignored {

View file

@ -55,7 +55,12 @@ impl Service {
/// removed diff for the selected shortstatehash and each parent layer. /// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult {
if let Some(r) = self.stateinfo_cache.lock().unwrap().get_mut(&shortstatehash) { if let Some(r) = self
.stateinfo_cache
.lock()
.unwrap()
.get_mut(&shortstatehash)
{
return Ok(r.clone()); return Ok(r.clone());
} }
@ -76,19 +81,31 @@ impl Service {
response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); response.push((shortstatehash, Arc::new(state), added, Arc::new(removed)));
self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); self.stateinfo_cache
.lock()
.unwrap()
.insert(shortstatehash, response.clone());
Ok(response) Ok(response)
} else { } else {
let response = vec![(shortstatehash, added.clone(), added, removed)]; let response = vec![(shortstatehash, added.clone(), added, removed)];
self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); self.stateinfo_cache
.lock()
.unwrap()
.insert(shortstatehash, response.clone());
Ok(response) Ok(response)
} }
} }
pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> { pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec(); let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(&services().rooms.short.get_or_create_shorteventid(event_id)?.to_be_bytes()); v.extend_from_slice(
&services()
.rooms
.short
.get_or_create_shorteventid(event_id)?
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above")) Ok(v.try_into().expect("we checked the size above"))
} }
@ -238,10 +255,17 @@ impl Service {
) -> HashSetCompressStateEvent { ) -> HashSetCompressStateEvent {
let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?;
let state_hash = let state_hash = utils::calculate_hash(
utils::calculate_hash(&new_state_ids_compressed.iter().map(|bytes| &bytes[..]).collect::<Vec<_>>()); &new_state_ids_compressed
.iter()
.map(|bytes| &bytes[..])
.collect::<Vec<_>>(),
);
let (new_shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; let (new_shortstatehash, already_existed) = services()
.rooms
.short
.get_or_create_shortstatehash(&state_hash)?;
if Some(new_shortstatehash) == previous_shortstatehash { if Some(new_shortstatehash) == previous_shortstatehash {
return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new())));
@ -251,10 +275,16 @@ impl Service {
previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); let statediffnew: HashSet<_> = new_state_ids_compressed
.difference(&parent_stateinfo.1)
.copied()
.collect();
let statediffremoved: HashSet<_> = let statediffremoved: HashSet<_> = parent_stateinfo
parent_stateinfo.1.difference(&new_state_ids_compressed).copied().collect(); .1
.difference(&new_state_ids_compressed)
.copied()
.collect();
(Arc::new(statediffnew), Arc::new(statediffremoved)) (Arc::new(statediffnew), Arc::new(statediffremoved))
} else { } else {

View file

@ -60,7 +60,9 @@ impl Service {
unsigned.insert( unsigned.insert(
"m.relations".to_owned(), "m.relations".to_owned(),
json!({ "m.thread": content }).try_into().expect("thread is valid json"), json!({ "m.thread": content })
.try_into()
.expect("thread is valid json"),
); );
} else { } else {
// New thread // New thread
@ -74,11 +76,16 @@ impl Service {
unsigned.insert( unsigned.insert(
"m.relations".to_owned(), "m.relations".to_owned(),
json!({ "m.thread": content }).try_into().expect("thread is valid json"), json!({ "m.thread": content })
.try_into()
.expect("thread is valid json"),
); );
} }
services().rooms.timeline.replace_pdu(root_id, &root_pdu_json, &root_pdu)?; services()
.rooms
.timeline
.replace_pdu(root_id, &root_pdu_json, &root_pdu)?;
} }
let mut users = Vec::new(); let mut users = Vec::new();

View file

@ -152,7 +152,10 @@ impl Service {
/// Returns the version of a room, if known /// Returns the version of a room, if known
pub fn get_room_version(&self, room_id: &RoomId) -> Result<Option<RoomVersionId>> { pub fn get_room_version(&self, room_id: &RoomId) -> Result<Option<RoomVersionId>> {
let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event = services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: Option<RoomCreateEventContent> = create_event let create_event_content: Option<RoomCreateEventContent> = create_event
.as_ref() .as_ref()
@ -225,16 +228,25 @@ impl Service {
// Coalesce database writes for the remainder of this scope. // Coalesce database writes for the remainder of this scope.
let _cork = services().globals.db.cork_and_flush()?; let _cork = services().globals.db.cork_and_flush()?;
let shortroomid = services().rooms.short.get_shortroomid(&pdu.room_id)?.expect("room exists"); let shortroomid = services()
.rooms
.short
.get_shortroomid(&pdu.room_id)?
.expect("room exists");
// Make unsigned fields correct. This is not properly documented in the spec, // Make unsigned fields correct. This is not properly documented in the spec,
// but state events need to have previous content in the unsigned field, so // but state events need to have previous content in the unsigned field, so
// clients can easily interpret things like membership changes // clients can easily interpret things like membership changes
if let Some(state_key) = &pdu.state_key { if let Some(state_key) = &pdu.state_key {
if let CanonicalJsonValue::Object(unsigned) = if let CanonicalJsonValue::Object(unsigned) = pdu_json
pdu_json.entry("unsigned".to_owned()).or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) .entry("unsigned".to_owned())
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
{ {
if let Some(shortstatehash) = services().rooms.state_accessor.pdu_shortstatehash(&pdu.event_id).unwrap() if let Some(shortstatehash) = services()
.rooms
.state_accessor
.pdu_shortstatehash(&pdu.event_id)
.unwrap()
{ {
if let Some(prev_state) = services() if let Some(prev_state) = services()
.rooms .rooms
@ -259,18 +271,38 @@ impl Service {
} }
// We must keep track of all events that have been referenced. // We must keep track of all events that have been referenced.
services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; services()
services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?; .rooms
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services()
.rooms
.state
.set_forward_extremities(&pdu.room_id, leaves, state_lock)?;
let mutex_insert = let mutex_insert = Arc::clone(
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(pdu.room_id.clone()).or_default()); services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(pdu.room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
let count1 = services().globals.next_count()?; let count1 = services().globals.next_count()?;
// Mark as read first so the sending client doesn't get a notification even if // Mark as read first so the sending client doesn't get a notification even if
// appending fails // appending fails
services().rooms.edus.read_receipt.private_read_set(&pdu.room_id, &pdu.sender, count1)?; services()
services().rooms.user.reset_notification_counts(&pdu.sender, &pdu.room_id)?; .rooms
.edus
.read_receipt
.private_read_set(&pdu.room_id, &pdu.sender, count1)?;
services()
.rooms
.user
.reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = services().globals.next_count()?; let count2 = services().globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec(); let mut pdu_id = shortroomid.to_be_bytes().to_vec();
@ -318,7 +350,10 @@ impl Service {
let mut notifies = Vec::new(); let mut notifies = Vec::new();
let mut highlights = Vec::new(); let mut highlights = Vec::new();
let mut push_target = services().rooms.state_cache.get_our_real_users(&pdu.room_id)?; let mut push_target = services()
.rooms
.state_cache
.get_our_real_users(&pdu.room_id)?;
if pdu.kind == TimelineEventType::RoomMember { if pdu.kind == TimelineEventType::RoomMember {
if let Some(state_key) = &pdu.state_key { if let Some(state_key) = &pdu.state_key {
@ -354,7 +389,9 @@ impl Service {
let mut notify = false; let mut notify = false;
for action in for action in
services().pusher.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? services()
.pusher
.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)?
{ {
match action { match action {
Action::Notify => notify = true, Action::Notify => notify = true,
@ -378,7 +415,8 @@ impl Service {
} }
} }
self.db.increment_notification_counts(&pdu.room_id, notifies, highlights)?; self.db
.increment_notification_counts(&pdu.room_id, notifies, highlights)?;
match pdu.kind { match pdu.kind {
TimelineEventType::RoomRedaction => { TimelineEventType::RoomRedaction => {
@ -422,7 +460,13 @@ impl Service {
}, },
TimelineEventType::SpaceChild => { TimelineEventType::SpaceChild => {
if let Some(_state_key) = &pdu.state_key { if let Some(_state_key) = &pdu.state_key {
services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); services()
.rooms
.spaces
.roomid_spacehierarchy_cache
.lock()
.await
.remove(&pdu.room_id);
} }
}, },
TimelineEventType::RoomMember => { TimelineEventType::RoomMember => {
@ -463,7 +507,10 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body { if let Some(body) = content.body {
services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; services()
.rooms
.search
.index_pdu(shortroomid, &pdu_id, &body)?;
let server_user = format!("@conduit:{}", services().globals.server_name()); let server_user = format!("@conduit:{}", services().globals.server_name());
@ -488,8 +535,15 @@ impl Service {
} }
if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) { if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) {
if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&content.relates_to.event_id)? { if let Some(related_pducount) = services()
services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; .rooms
.timeline
.get_pdu_count(&content.relates_to.event_id)?
{
services()
.rooms
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?;
} }
} }
@ -500,32 +554,52 @@ impl Service {
} => { } => {
// We need to do it again here, because replies don't have // We need to do it again here, because replies don't have
// event_id as a top level field // event_id as a top level field
if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&in_reply_to.event_id)? { if let Some(related_pducount) = services()
services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; .rooms
.timeline
.get_pdu_count(&in_reply_to.event_id)?
{
services()
.rooms
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?;
} }
}, },
Relation::Thread(thread) => { Relation::Thread(thread) => {
services().rooms.threads.add_to_thread(&thread.event_id, pdu)?; services()
.rooms
.threads
.add_to_thread(&thread.event_id, pdu)?;
}, },
_ => {}, // TODO: Aggregate other types _ => {}, // TODO: Aggregate other types
} }
} }
for appservice in services().appservice.read().await.values() { for appservice in services().appservice.read().await.values() {
if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? { if services()
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; .rooms
.state_cache
.appservice_in_room(&pdu.room_id, appservice)?
{
services()
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue; continue;
} }
// If the RoomMember event has a non-empty state_key, it is targeted at someone. // If the RoomMember event has a non-empty state_key, it is targeted at someone.
// If it is our appservice user, we send this PDU to it. // If it is our appservice user, we send this PDU to it.
if pdu.kind == TimelineEventType::RoomMember { if pdu.kind == TimelineEventType::RoomMember {
if let Some(state_key_uid) = if let Some(state_key_uid) = &pdu
&pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) .state_key
.as_ref()
.and_then(|state_key| UserId::parse(state_key.as_str()).ok())
{ {
let appservice_uid = appservice.registration.sender_localpart.as_str(); let appservice_uid = appservice.registration.sender_localpart.as_str();
if state_key_uid == appservice_uid { if state_key_uid == appservice_uid {
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; services()
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue; continue;
} }
} }
@ -534,7 +608,10 @@ impl Service {
let matching_users = |users: &NamespaceRegex| { let matching_users = |users: &NamespaceRegex| {
appservice.users.is_match(pdu.sender.as_str()) appservice.users.is_match(pdu.sender.as_str())
|| pdu.kind == TimelineEventType::RoomMember || pdu.kind == TimelineEventType::RoomMember
&& pdu.state_key.as_ref().map_or(false, |state_key| users.is_match(state_key)) && pdu
.state_key
.as_ref()
.map_or(false, |state_key| users.is_match(state_key))
}; };
let matching_aliases = |aliases: &NamespaceRegex| { let matching_aliases = |aliases: &NamespaceRegex| {
services() services()
@ -549,7 +626,9 @@ impl Service {
|| appservice.rooms.is_match(pdu.room_id.as_str()) || appservice.rooms.is_match(pdu.room_id.as_str())
|| matching_users(&appservice.users) || matching_users(&appservice.users)
{ {
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; services()
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
} }
} }
@ -571,11 +650,20 @@ impl Service {
redacts, redacts,
} = pdu_builder; } = pdu_builder;
let prev_events: Vec<_> = let prev_events: Vec<_> = services()
services().rooms.state.get_forward_extremities(room_id)?.into_iter().take(20).collect(); .rooms
.state
.get_forward_extremities(room_id)?
.into_iter()
.take(20)
.collect();
// If there was no create event yet, assume we are creating a room // If there was no create event yet, assume we are creating a room
let room_version_id = services().rooms.state.get_room_version(room_id).or_else(|_| { let room_version_id = services()
.rooms
.state
.get_room_version(room_id)
.or_else(|_| {
if event_type == TimelineEventType::RoomCreate { if event_type == TimelineEventType::RoomCreate {
#[derive(Deserialize)] #[derive(Deserialize)]
struct RoomCreate { struct RoomCreate {
@ -595,7 +683,10 @@ impl Service {
let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); let room_version = RoomVersion::new(&room_version_id).expect("room version is supported");
let auth_events = let auth_events =
services().rooms.state.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; services()
.rooms
.state
.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?;
// Our depth is the maximum depth of prev_events + 1 // Our depth is the maximum depth of prev_events + 1
let depth = prev_events let depth = prev_events
@ -609,7 +700,10 @@ impl Service {
if let Some(state_key) = &state_key { if let Some(state_key) = &state_key {
if let Some(prev_pdu) = if let Some(prev_pdu) =
services().rooms.state_accessor.room_state_get(room_id, &event_type.to_string().into(), state_key)? services()
.rooms
.state_accessor
.room_state_get(room_id, &event_type.to_string().into(), state_key)?
{ {
unsigned.insert( unsigned.insert(
"prev_content".to_owned(), "prev_content".to_owned(),
@ -626,13 +720,18 @@ impl Service {
event_id: ruma::event_id!("$thiswillbefilledinlater").into(), event_id: ruma::event_id!("$thiswillbefilledinlater").into(),
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
sender: sender.to_owned(), sender: sender.to_owned(),
origin_server_ts: utils::millis_since_unix_epoch().try_into().expect("time is valid"), origin_server_ts: utils::millis_since_unix_epoch()
.try_into()
.expect("time is valid"),
kind: event_type, kind: event_type,
content, content,
state_key, state_key,
prev_events, prev_events,
depth, depth,
auth_events: auth_events.values().map(|pdu| pdu.event_id.clone()).collect(), auth_events: auth_events
.values()
.map(|pdu| pdu.event_id.clone())
.collect(),
redacts, redacts,
unsigned: if unsigned.is_empty() { unsigned: if unsigned.is_empty() {
None None
@ -710,7 +809,10 @@ impl Service {
); );
// Generate short event id // Generate short event id
let _shorteventid = services().rooms.short.get_or_create_shorteventid(&pdu.event_id)?; let _shorteventid = services()
.rooms
.short
.get_or_create_shorteventid(&pdu.event_id)?;
Ok((pdu, pdu_json)) Ok((pdu, pdu_json))
} }
@ -744,7 +846,10 @@ impl Service {
membership: MembershipState, membership: MembershipState,
} }
let target = pdu.state_key().filter(|v| v.starts_with('@')).unwrap_or(sender.as_str()); let target = pdu
.state_key()
.filter(|v| v.starts_with('@'))
.unwrap_or(sender.as_str());
let server_name = services().globals.server_name(); let server_name = services().globals.server_name();
let server_user = format!("@conduit:{server_name}"); let server_user = format!("@conduit:{server_name}");
let content = serde_json::from_str::<ExtractMembership>(pdu.content.get()) let content = serde_json::from_str::<ExtractMembership>(pdu.content.get())
@ -825,16 +930,25 @@ impl Service {
// We set the room state after inserting the pdu, so that we never have a moment // We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist // in time where events in the current room state do not exist
services().rooms.state.set_room_state(room_id, statehashid, state_lock)?; services()
.rooms
.state
.set_room_state(room_id, statehashid, state_lock)?;
let mut servers: HashSet<OwnedServerName> = let mut servers: HashSet<OwnedServerName> = services()
services().rooms.state_cache.room_servers(room_id).filter_map(Result::ok).collect(); .rooms
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.collect();
// In case we are kicking or banning a user, we need to inform their server of // In case we are kicking or banning a user, we need to inform their server of
// the change // the change
if pdu.kind == TimelineEventType::RoomMember { if pdu.kind == TimelineEventType::RoomMember {
if let Some(state_key_uid) = if let Some(state_key_uid) = &pdu
&pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) .state_key
.as_ref()
.and_then(|state_key| UserId::parse(state_key.as_str()).ok())
{ {
servers.insert(state_key_uid.server_name().to_owned()); servers.insert(state_key_uid.server_name().to_owned());
} }
@ -864,15 +978,28 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in // We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't // time with the pdu without it's state. This is okay because append_pdu can't
// fail. // fail.
services().rooms.state.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; services()
.rooms
.state
.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?;
if soft_fail { if soft_fail {
services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; services()
services().rooms.state.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; .rooms
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services()
.rooms
.state
.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?;
return Ok(None); return Ok(None);
} }
let pdu_id = services().rooms.timeline.append_pdu(pdu, pdu_json, new_room_leaves, state_lock).await?; let pdu_id = services()
.rooms
.timeline
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
.await?;
Ok(Some(pdu_id)) Ok(Some(pdu_id))
} }
@ -908,8 +1035,9 @@ impl Service {
pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> {
// TODO: Don't reserialize, keep original json // TODO: Don't reserialize, keep original json
if let Some(pdu_id) = self.get_pdu_id(event_id)? { if let Some(pdu_id) = self.get_pdu_id(event_id)? {
let mut pdu = let mut pdu = self
self.get_pdu_from_id(&pdu_id)?.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; .get_pdu_from_id(&pdu_id)?
.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
pdu.redact(room_version_id, reason)?; pdu.redact(room_version_id, reason)?;
self.replace_pdu( self.replace_pdu(
@ -927,8 +1055,10 @@ impl Service {
#[tracing::instrument(skip(self, room_id))] #[tracing::instrument(skip(self, room_id))]
pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> {
let first_pdu = let first_pdu = self
self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?.next().expect("Room is not empty")?; .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?
.next()
.expect("Room is not empty")?;
if first_pdu.0 < from { if first_pdu.0 < from {
// No backfill required, there are still events between them // No backfill required, there are still events between them
@ -938,7 +1068,11 @@ impl Service {
let mut servers: Vec<&ServerName> = vec![]; let mut servers: Vec<&ServerName> = vec![];
// add server names from room aliases on the room ID // add server names from room aliases on the room ID
let room_aliases = services().rooms.alias.local_aliases_for_room(room_id).collect::<Result<Vec<_>, _>>(); let room_aliases = services()
.rooms
.alias
.local_aliases_for_room(room_id)
.collect::<Result<Vec<_>, _>>();
if let Ok(aliases) = &room_aliases { if let Ok(aliases) = &room_aliases {
for alias in aliases { for alias in aliases {
if alias.server_name() != services().globals.server_name() { if alias.server_name() != services().globals.server_name() {
@ -979,8 +1113,10 @@ impl Service {
} }
// don't backfill from ourselves (might be noop if we checked it above already) // don't backfill from ourselves (might be noop if we checked it above already)
if let Some(server_index) = if let Some(server_index) = servers
servers.clone().into_iter().position(|server| server == services().globals.server_name()) .clone()
.into_iter()
.position(|server| server == services().globals.server_name())
{ {
servers.remove(server_index); servers.remove(server_index);
} }
@ -1030,8 +1166,15 @@ impl Service {
let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?;
// Lock so we cannot backfill the same pdu twice at the same time // Lock so we cannot backfill the same pdu twice at the same time
let mutex = let mutex = Arc::clone(
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let mutex_lock = mutex.lock().await; let mutex_lock = mutex.lock().await;
// Skip the PDU if we already have it as a timeline event // Skip the PDU if we already have it as a timeline event
@ -1040,7 +1183,11 @@ impl Service {
return Ok(()); return Ok(());
} }
services().rooms.event_handler.fetch_required_signing_keys([&value], pub_key_map).await?; services()
.rooms
.event_handler
.fetch_required_signing_keys([&value], pub_key_map)
.await?;
services() services()
.rooms .rooms
@ -1051,10 +1198,21 @@ impl Service {
let value = self.get_pdu_json(&event_id)?.expect("We just created it"); let value = self.get_pdu_json(&event_id)?.expect("We just created it");
let pdu = self.get_pdu(&event_id)?.expect("We just created it"); let pdu = self.get_pdu(&event_id)?.expect("We just created it");
let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists"); let shortroomid = services()
.rooms
.short
.get_shortroomid(&room_id)?
.expect("room exists");
let mutex_insert = let mutex_insert = Arc::clone(
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await; let insert_lock = mutex_insert.lock().await;
let count = services().globals.next_count()?; let count = services().globals.next_count()?;
@ -1077,7 +1235,10 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?; .map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body { if let Some(body) = content.body {
services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; services()
.rooms
.search
.index_pdu(shortroomid, &pdu_id, &body)?;
} }
} }
drop(mutex_lock); drop(mutex_lock);

View file

@ -27,7 +27,8 @@ impl Service {
} }
pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
self.db.associate_token_shortstatehash(room_id, token, shortstatehash) self.db
.associate_token_shortstatehash(room_id, token, shortstatehash)
} }
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> { pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {

View file

@ -127,7 +127,9 @@ impl Service {
let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new(); let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new();
for (key, outgoing_kind, event) in self.db.active_requests().filter_map(Result::ok) { for (key, outgoing_kind, event) in self.db.active_requests().filter_map(Result::ok) {
let entry = initial_transactions.entry(outgoing_kind.clone()).or_default(); let entry = initial_transactions
.entry(outgoing_kind.clone())
.or_default();
if entry.len() > 30 { if entry.len() > 30 {
warn!("Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event); warn!("Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event);
@ -236,7 +238,11 @@ impl Service {
if retry { if retry {
// We retry the previous transaction // We retry the previous transaction
for (_, e) in self.db.active_requests_for(outgoing_kind).filter_map(Result::ok) { for (_, e) in self
.db
.active_requests_for(outgoing_kind)
.filter_map(Result::ok)
{
events.push(e); events.push(e);
} }
} else { } else {
@ -282,7 +288,12 @@ impl Service {
// Look for presence updates in this room // Look for presence updates in this room
let mut presence_updates = Vec::new(); let mut presence_updates = Vec::new();
for (user_id, count, presence_event) in services().rooms.edus.presence.presence_since(&room_id, since) { for (user_id, count, presence_event) in services()
.rooms
.edus
.presence
.presence_since(&room_id, since)
{
if count > max_edu_count { if count > max_edu_count {
max_edu_count = count; max_edu_count = count;
} }
@ -295,7 +306,10 @@ impl Service {
user_id, user_id,
presence: presence_event.content.presence, presence: presence_event.content.presence,
currently_active: presence_event.content.currently_active.unwrap_or(false), currently_active: presence_event.content.currently_active.unwrap_or(false),
last_active_ago: presence_event.content.last_active_ago.unwrap_or_else(|| uint!(0)), last_active_ago: presence_event
.content
.last_active_ago
.unwrap_or_else(|| uint!(0)),
status_msg: presence_event.content.status_msg, status_msg: presence_event.content.status_msg,
}); });
} }
@ -305,7 +319,12 @@ impl Service {
} }
// Look for read receipts in this room // Look for read receipts in this room
for r in services().rooms.edus.read_receipt.readreceipts_since(&room_id, since) { for r in services()
.rooms
.edus
.read_receipt
.readreceipts_since(&room_id, since)
{
let (user_id, count, read_receipt) = r?; let (user_id, count, read_receipt) = r?;
if count > max_edu_count { if count > max_edu_count {
@ -321,8 +340,12 @@ impl Service {
let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event { let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
let mut read = BTreeMap::new(); let mut read = BTreeMap::new();
let (event_id, mut receipt) = let (event_id, mut receipt) = r
r.content.0.into_iter().next().expect("we only use one event per read receipt"); .content
.0
.into_iter()
.next()
.expect("we only use one event per read receipt");
let receipt = receipt let receipt = receipt
.remove(&ReceiptType::Read) .remove(&ReceiptType::Read)
.expect("our read receipts always set this") .expect("our read receipts always set this")
@ -385,7 +408,9 @@ impl Service {
let event = SendingEventType::Pdu(pdu_id.to_owned()); let event = SendingEventType::Pdu(pdu_id.to_owned());
let _cork = services().globals.db.cork()?; let _cork = services().globals.db.cork()?;
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap();
Ok(()) Ok(())
} }
@ -397,9 +422,16 @@ impl Service {
.map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned())))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let _cork = services().globals.db.cork()?; let _cork = services().globals.db.cork()?;
let keys = self.db.queue_requests(&requests.iter().map(|(o, e)| (o, e.clone())).collect::<Vec<_>>())?; let keys = self.db.queue_requests(
&requests
.iter()
.map(|(o, e)| (o, e.clone()))
.collect::<Vec<_>>(),
)?;
for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) {
self.sender.send((outgoing_kind.clone(), event, key)).unwrap(); self.sender
.send((outgoing_kind.clone(), event, key))
.unwrap();
} }
Ok(()) Ok(())
@ -411,7 +443,9 @@ impl Service {
let event = SendingEventType::Edu(serialized); let event = SendingEventType::Edu(serialized);
let _cork = services().globals.db.cork()?; let _cork = services().globals.db.cork()?;
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap();
Ok(()) Ok(())
} }
@ -422,15 +456,21 @@ impl Service {
let event = SendingEventType::Pdu(pdu_id); let event = SendingEventType::Pdu(pdu_id);
let _cork = services().globals.db.cork()?; let _cork = services().globals.db.cork()?;
let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); self.sender
.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
.unwrap();
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, room_id))] #[tracing::instrument(skip(self, room_id))]
pub fn flush_room(&self, room_id: &RoomId) -> Result<()> { pub fn flush_room(&self, room_id: &RoomId) -> Result<()> {
let servers: HashSet<OwnedServerName> = let servers: HashSet<OwnedServerName> = services()
services().rooms.state_cache.room_servers(room_id).filter_map(Result::ok).collect(); .rooms
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
.collect();
self.flush_servers(servers.into_iter()) self.flush_servers(servers.into_iter())
} }
@ -444,7 +484,9 @@ impl Service {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for outgoing_kind in requests { for outgoing_kind in requests {
self.sender.send((outgoing_kind, SendingEventType::Flush, Vec::<u8>::new())).unwrap(); self.sender
.send((outgoing_kind, SendingEventType::Flush, Vec::<u8>::new()))
.unwrap();
} }
Ok(()) Ok(())
@ -454,7 +496,8 @@ impl Service {
/// Used for instance after we remove an appservice registration /// Used for instance after we remove an appservice registration
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { pub fn cleanup_events(&self, appservice_id: String) -> Result<()> {
self.db.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; self.db
.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?;
Ok(()) Ok(())
} }
@ -497,7 +540,11 @@ impl Service {
let permit = services().sending.maximum_requests.acquire().await; let permit = services().sending.maximum_requests.acquire().await;
let response = match appservice_server::send_request( let response = match appservice_server::send_request(
services().appservice.get_registration(id).await.ok_or_else(|| { services()
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
( (
kind.clone(), kind.clone(),
Error::bad_database("[Appservice] Could not load registration from db."), Error::bad_database("[Appservice] Could not load registration from db."),
@ -520,7 +567,9 @@ impl Service {
.await .await
{ {
None => Ok(kind.clone()), None => Ok(kind.clone()),
Some(op_resp) => op_resp.map(|_response| kind.clone()).map_err(|e| (kind.clone(), e)), Some(op_resp) => op_resp
.map(|_response| kind.clone())
.map_err(|e| (kind.clone(), e)),
}; };
drop(permit); drop(permit);

View file

@ -44,7 +44,10 @@ impl Service {
pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) } pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) }
pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) {
self.connections.lock().unwrap().remove(&(user_id, device_id, conn_id)); self.connections
.lock()
.unwrap()
.remove(&(user_id, device_id, conn_id));
} }
pub fn update_sync_request_with_cache( pub fn update_sync_request_with_cache(
@ -55,14 +58,18 @@ impl Service {
}; };
let mut cache = self.connections.lock().unwrap(); let mut cache = self.connections.lock().unwrap();
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }),
);
let cached = &mut cached.lock().unwrap(); let cached = &mut cached.lock().unwrap();
drop(cache); drop(cache);
@ -72,12 +79,18 @@ impl Service {
list.sort.clone_from(&cached_list.sort); list.sort.clone_from(&cached_list.sort);
}; };
if list.room_details.required_state.is_empty() { if list.room_details.required_state.is_empty() {
list.room_details.required_state.clone_from(&cached_list.room_details.required_state); list.room_details
.required_state
.clone_from(&cached_list.room_details.required_state);
}; };
list.room_details.timeline_limit = list.room_details.timeline_limit = list
list.room_details.timeline_limit.or(cached_list.room_details.timeline_limit); .room_details
list.include_old_rooms = .timeline_limit
list.include_old_rooms.clone().or_else(|| cached_list.include_old_rooms.clone()); .or(cached_list.room_details.timeline_limit);
list.include_old_rooms = list
.include_old_rooms
.clone()
.or_else(|| cached_list.include_old_rooms.clone());
match (&mut list.filters, cached_list.filters.clone()) { match (&mut list.filters, cached_list.filters.clone()) {
(Some(list_filters), Some(cached_filters)) => { (Some(list_filters), Some(cached_filters)) => {
list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
@ -92,8 +105,10 @@ impl Service {
if list_filters.not_room_types.is_empty() { if list_filters.not_room_types.is_empty() {
list_filters.not_room_types = cached_filters.not_room_types; list_filters.not_room_types = cached_filters.not_room_types;
} }
list_filters.room_name_like = list_filters.room_name_like = list_filters
list_filters.room_name_like.clone().or(cached_filters.room_name_like); .room_name_like
.clone()
.or(cached_filters.room_name_like);
if list_filters.tags.is_empty() { if list_filters.tags.is_empty() {
list_filters.tags = cached_filters.tags; list_filters.tags = cached_filters.tags;
} }
@ -106,26 +121,49 @@ impl Service {
(..) => {}, (..) => {},
} }
if list.bump_event_types.is_empty() { if list.bump_event_types.is_empty() {
list.bump_event_types.clone_from(&cached_list.bump_event_types); list.bump_event_types
.clone_from(&cached_list.bump_event_types);
}; };
} }
cached.lists.insert(list_id.clone(), list.clone()); cached.lists.insert(list_id.clone(), list.clone());
} }
cached.subscriptions.extend(request.room_subscriptions.clone()); cached
request.room_subscriptions.extend(cached.subscriptions.clone()); .subscriptions
.extend(request.room_subscriptions.clone());
request
.room_subscriptions
.extend(cached.subscriptions.clone());
request.extensions.e2ee.enabled = request.extensions.e2ee.enabled.or(cached.extensions.e2ee.enabled); request.extensions.e2ee.enabled = request
.extensions
.e2ee
.enabled
.or(cached.extensions.e2ee.enabled);
request.extensions.to_device.enabled = request.extensions.to_device.enabled = request
request.extensions.to_device.enabled.or(cached.extensions.to_device.enabled); .extensions
.to_device
.enabled
.or(cached.extensions.to_device.enabled);
request.extensions.account_data.enabled = request.extensions.account_data.enabled = request
request.extensions.account_data.enabled.or(cached.extensions.account_data.enabled); .extensions
request.extensions.account_data.lists = .account_data
request.extensions.account_data.lists.clone().or_else(|| cached.extensions.account_data.lists.clone()); .enabled
request.extensions.account_data.rooms = .or(cached.extensions.account_data.enabled);
request.extensions.account_data.rooms.clone().or_else(|| cached.extensions.account_data.rooms.clone()); request.extensions.account_data.lists = request
.extensions
.account_data
.lists
.clone()
.or_else(|| cached.extensions.account_data.lists.clone());
request.extensions.account_data.rooms = request
.extensions
.account_data
.rooms
.clone()
.or_else(|| cached.extensions.account_data.rooms.clone());
cached.extensions = request.extensions.clone(); cached.extensions = request.extensions.clone();
@ -137,14 +175,18 @@ impl Service {
subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
) { ) {
let mut cache = self.connections.lock().unwrap(); let mut cache = self.connections.lock().unwrap();
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }),
);
let cached = &mut cached.lock().unwrap(); let cached = &mut cached.lock().unwrap();
drop(cache); drop(cache);
@ -156,18 +198,27 @@ impl Service {
new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64, new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64,
) { ) {
let mut cache = self.connections.lock().unwrap(); let mut cache = self.connections.lock().unwrap();
let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { let cached = Arc::clone(
cache
.entry((user_id, device_id, conn_id))
.or_insert_with(|| {
Arc::new(Mutex::new(SlidingSyncCache { Arc::new(Mutex::new(SlidingSyncCache {
lists: BTreeMap::new(), lists: BTreeMap::new(),
subscriptions: BTreeMap::new(), subscriptions: BTreeMap::new(),
known_rooms: BTreeMap::new(), known_rooms: BTreeMap::new(),
extensions: ExtensionsConfig::default(), extensions: ExtensionsConfig::default(),
})) }))
})); }),
);
let cached = &mut cached.lock().unwrap(); let cached = &mut cached.lock().unwrap();
drop(cache); drop(cache);
for (roomid, lastsince) in cached.known_rooms.entry(list_id.clone()).or_default().iter_mut() { for (roomid, lastsince) in cached
.known_rooms
.entry(list_id.clone())
.or_default()
.iter_mut()
{
if !new_cached_rooms.contains(roomid) { if !new_cached_rooms.contains(roomid) {
*lastsince = 0; *lastsince = 0;
} }
@ -185,9 +236,16 @@ impl Service {
pub fn is_admin(&self, user_id: &UserId) -> Result<bool> { pub fn is_admin(&self, user_id: &UserId) -> Result<bool> {
let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name()))
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
let admin_room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias_id)?.unwrap(); let admin_room_id = services()
.rooms
.alias
.resolve_local_alias(&admin_room_alias_id)?
.unwrap();
services().rooms.state_cache.is_joined(user_id, &admin_room_id) services()
.rooms
.state_cache
.is_joined(user_id, &admin_room_id)
} }
/// Create a new user account on this homeserver. /// Create a new user account on this homeserver.
@ -250,7 +308,8 @@ impl Service {
pub fn create_device( pub fn create_device(
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>, &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
) -> Result<()> { ) -> Result<()> {
self.db.create_device(user_id, device_id, token, initial_device_display_name) self.db
.create_device(user_id, device_id, token, initial_device_display_name)
} }
/// Removes a device from a user. /// Removes a device from a user.
@ -272,7 +331,8 @@ impl Service {
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>, one_time_key_value: &Raw<OneTimeKey>,
) -> Result<()> { ) -> Result<()> {
self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) self.db
.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
} }
pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
@ -299,7 +359,8 @@ impl Service {
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>, &self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool, user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
) -> Result<()> { ) -> Result<()> {
self.db.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify) self.db
.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify)
} }
pub fn sign_key( pub fn sign_key(
@ -329,19 +390,22 @@ impl Service {
pub fn get_key( pub fn get_key(
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.db.get_key(key, sender_user, user_id, allowed_signatures) self.db
.get_key(key, sender_user, user_id, allowed_signatures)
} }
pub fn get_master_key( pub fn get_master_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.db.get_master_key(sender_user, user_id, allowed_signatures) self.db
.get_master_key(sender_user, user_id, allowed_signatures)
} }
pub fn get_self_signing_key( pub fn get_self_signing_key(
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> { ) -> Result<Option<Raw<CrossSigningKey>>> {
self.db.get_self_signing_key(sender_user, user_id, allowed_signatures) self.db
.get_self_signing_key(sender_user, user_id, allowed_signatures)
} }
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
@ -352,7 +416,8 @@ impl Service {
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
content: serde_json::Value, content: serde_json::Value,
) -> Result<()> { ) -> Result<()> {
self.db.add_to_device_event(sender, target_user_id, target_device_id, event_type, content) self.db
.add_to_device_event(sender, target_user_id, target_device_id, event_type, content)
} }
pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> { pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
@ -411,7 +476,10 @@ impl Service {
pub fn clean_signatures<F: Fn(&UserId) -> bool>( pub fn clean_signatures<F: Fn(&UserId) -> bool>(
cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F,
) -> Result<(), Error> { ) -> Result<(), Error> {
if let Some(signatures) = cross_signing_key.get_mut("signatures").and_then(|v| v.as_object_mut()) { if let Some(signatures) = cross_signing_key
.get_mut("signatures")
.and_then(|v| v.as_object_mut())
{
// Don't allocate for the full size of the current signatures, but require // Don't allocate for the full size of the current signatures, but require
// at most one resize if nothing is dropped // at most one resize if nothing is dropped
let new_capacity = signatures.len() / 2; let new_capacity = signatures.len() / 2;

View file

@ -15,7 +15,10 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO
use crate::{services, Error, Result}; use crate::{services, Error, Result};
pub(crate) fn millis_since_unix_epoch() -> u64 { pub(crate) fn millis_since_unix_epoch() -> u64 {
SystemTime::now().duration_since(UNIX_EPOCH).expect("time is valid").as_millis() as u64 SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time is valid")
.as_millis() as u64
} }
pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> { pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> {
@ -59,13 +62,21 @@ pub fn user_id_from_bytes(bytes: &[u8]) -> Result<OwnedUserId> {
} }
pub fn random_string(length: usize) -> String { pub fn random_string(length: usize) -> String {
thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(length).map(char::from).collect() thread_rng()
.sample_iter(&rand::distributions::Alphanumeric)
.take(length)
.map(char::from)
.collect()
} }
/// Calculate a new hash for the given password /// Calculate a new hash for the given password
pub fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> { pub fn calculate_password_hash(password: &str) -> Result<String, argon2::password_hash::Error> {
let salt = SaltString::generate(thread_rng()); let salt = SaltString::generate(thread_rng());
services().globals.argon.hash_password(password.as_bytes(), &salt).map(|it| it.to_string()) services()
.globals
.argon
.hash_password(password.as_bytes(), &salt)
.map(|it| it.to_string())
} }
#[tracing::instrument(skip(keys))] #[tracing::instrument(skip(keys))]