From 868976a14976602e967e9289050c77a76016462e Mon Sep 17 00:00:00 2001
From: strawberry
Date: Mon, 25 Mar 2024 17:05:11 -0400
Subject: [PATCH] use chain_width 60

Signed-off-by: strawberry
---
 rustfmt.toml | 3 +-
 src/api/appservice_server.rs | 26 +-
 src/api/client_server/account.rs | 100 +-
 src/api/client_server/alias.rs | 60 +-
 src/api/client_server/backup.rs | 135 ++-
 src/api/client_server/context.rs | 51 +-
 src/api/client_server/device.rs | 24 +-
 src/api/client_server/directory.rs | 18 +-
 src/api/client_server/keys.rs | 111 ++-
 src/api/client_server/media.rs | 90 +-
 src/api/client_server/membership.rs | 285 +++++-
 src/api/client_server/message.rs | 47 +-
 src/api/client_server/presence.rs | 18 +-
 src/api/client_server/profile.rs | 106 +-
 src/api/client_server/push.rs | 34 +-
 src/api/client_server/read_marker.rs | 22 +-
 src/api/client_server/redact.rs | 11 +-
 src/api/client_server/relations.rs | 93 +-
 src/api/client_server/report.rs | 28 +-
 src/api/client_server/room.rs | 230 +++--
 src/api/client_server/search.rs | 36 +-
 src/api/client_server/session.rs | 26 +-
 src/api/client_server/space.rs | 15 +-
 src/api/client_server/state.rs | 63 +-
 src/api/client_server/sync.rs | 938 +++++++++++-------
 src/api/client_server/tag.rs | 17 +-
 src/api/client_server/threads.rs | 14 +-
 src/api/client_server/to_device.rs | 10 +-
 src/api/client_server/typing.rs | 13 +-
 src/api/client_server/user_directory.rs | 39 +-
 src/api/ruma_wrapper/axum.rs | 31 +-
 src/api/server_server.rs | 438 ++++++--
 src/config/mod.rs | 12 +-
 src/database/abstraction/rocksdb.rs | 21 +-
 src/database/abstraction/sqlite.rs | 27 +-
 src/database/key_value/account_data.rs | 27 +-
 src/database/key_value/appservice.rs | 9 +-
 src/database/key_value/globals.rs | 61 +-
 src/database/key_value/key_backups.rs | 87 +-
 src/database/key_value/media.rs | 70 +-
 src/database/key_value/pusher.rs | 4 +-
 src/database/key_value/rooms/alias.rs | 25 +-
 src/database/key_value/rooms/auth_chain.rs | 30 +-
 src/database/key_value/rooms/edus/presence.rs | 26 +-
 .../key_value/rooms/edus/read_receipt.rs | 23 +-
 src/database/key_value/rooms/metadata.rs | 7 +-
 src/database/key_value/rooms/outlier.rs | 16 +-
 src/database/key_value/rooms/pdu_metadata.rs | 13 +-
 src/database/key_value/rooms/search.rs | 8 +-
 src/database/key_value/rooms/short.rs | 76 +-
 src/database/key_value/rooms/state.rs | 18 +-
 .../key_value/rooms/state_accessor.rs | 53 +-
 src/database/key_value/rooms/state_cache.rs | 195 ++--
 .../key_value/rooms/state_compressor.rs | 3 +-
 src/database/key_value/rooms/threads.rs | 23 +-
 src/database/key_value/rooms/timeline.rs | 75 +-
 src/database/key_value/rooms/user.rs | 56 +-
 src/database/key_value/sending.rs | 30 +-
 src/database/key_value/uiaa.rs | 14 +-
 src/database/key_value/users.rs | 251 +++--
 src/database/mod.rs | 126 ++-
 src/lib.rs | 5 +-
 src/main.rs | 58 +-
 src/service/admin/appservice.rs | 12 +-
 src/service/admin/debug.rs | 18 +-
 src/service/admin/federation.rs | 6 +-
 src/service/admin/media.rs | 5 +-
 src/service/admin/mod.rs | 70 +-
 src/service/admin/room.rs | 30 +-
 src/service/admin/room_alias.rs | 46 +-
 src/service/admin/room_directory.rs | 30 +-
 src/service/admin/room_moderation.rs | 6 +-
 src/service/admin/user.rs | 44 +-
 src/service/appservice/mod.rs | 41 +-
 src/service/globals/mod.rs | 11 +-
 src/service/globals/resolver.rs | 6 +-
 src/service/key_backups/mod.rs | 6 +-
 src/service/media/mod.rs | 38 +-
 src/service/mod.rs | 82 +-
 src/service/pdu.rs | 10 +-
 src/service/pusher/mod.rs | 42 +-
 src/service/rooms/auth_chain/mod.rs | 31 +-
 src/service/rooms/edus/presence/mod.rs | 9 +-
 src/service/rooms/edus/typing/mod.rs | 37 +-
 src/service/rooms/event_handler/mod.rs | 408 ++++++--
 src/service/rooms/lazy_loading/mod.rs | 6 +-
 src/service/rooms/pdu_metadata/mod.rs | 6 +-
 src/service/rooms/spaces/mod.rs | 124 ++-
 src/service/rooms/state/mod.rs | 129 ++-
 src/service/rooms/state_accessor/mod.rs | 94 +-
 src/service/rooms/state_cache/mod.rs | 10 +-
 src/service/rooms/state_compressor/mod.rs | 50 +-
 src/service/rooms/threads/mod.rs | 13 +-
 src/service/rooms/timeline/mod.rs | 309 ++++--
 src/service/rooms/user/mod.rs | 3 +-
 src/service/sending/mod.rs | 95 +-
 src/service/users/mod.rs | 178 +++-
 src/utils/mod.rs | 17 +-
 98 files changed, 4836 insertions(+), 1767 deletions(-)

diff --git a/rustfmt.toml b/rustfmt.toml
index 3216a9aa..114677d4 100644
--- a/rustfmt.toml
+++ b/rustfmt.toml
@@ -24,4 +24,5 @@ group_imports = "StdExternalCrate"
 newline_style = "Unix"
 use_field_init_shorthand = true
 use_small_heuristics = "Off"
-use_try_shorthand = true
\ No newline at end of file
+use_try_shorthand = true
+chain_width = 60
diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs
index 927d0612..ddc755b0 100644
--- a/src/api/appservice_server.rs
+++ b/src/api/appservice_server.rs
@@ -36,7 +36,11 @@ where
 		"?"
 	};

-	parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
+	parts.path_and_query = Some(
+		(old_path_and_query + symbol + "access_token=" + hs_token)
+			.parse()
+			.unwrap(),
+	);

 	*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");

 	let mut reqwest_request =
@@ -64,7 +68,13 @@ where
 		}
 	}

-	let mut response = match services().globals.client.appservice.execute(reqwest_request).await {
+	let mut response = match services()
+		.globals
+		.client
+		.appservice
+		.execute(reqwest_request)
+		.await
+	{
 		Ok(r) => r,
 		Err(e) => {
 			warn!(
@@ -77,10 +87,14 @@ where
 	// reqwest::Response -> http::Response conversion
 	let status = response.status();
-	let mut http_response_builder = http::Response::builder().status(status).version(response.version());
+	let mut http_response_builder = http::Response::builder()
+		.status(status)
+		.version(response.version());
 	mem::swap(
 		response.headers_mut(),
-		http_response_builder.headers_mut().expect("http::response::Builder is usable"),
+		http_response_builder
+			.headers_mut()
+			.expect("http::response::Builder is usable"),
 	);

 	let body = response.bytes().await.unwrap_or_else(|e| {
@@ -99,7 +113,9 @@ where
 	}

 	let response = T::IncomingResponse::try_from_http_response(
-		http_response_builder.body(body).expect("reqwest body is valid http body"),
+		http_response_builder
+			.body(body)
+			.expect("reqwest body is valid http body"),
 	);

 	Some(response.map_err(|_| {
diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs
index 8792522f..4b172fd6 100644
--- a/src/api/client_server/account.rs
+++ b/src/api/client_server/account.rs
@@ -47,7 +47,11 @@ pub async fn get_register_available_route(
 		return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
 	}

-	if services().globals.forbidden_usernames().is_match(user_id.localpart()) {
+	if services()
+		.globals
+		.forbidden_usernames()
+		.is_match(user_id.localpart())
+	{
 		return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
 	}

@@ -129,7 +133,11 @@ pub async fn register_route(body: Ruma) -> Result) -> Result) -> Result) -> Result) ->
 	};

 	if let Some(auth) = &body.auth {
-		let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
+		let (worked, uiaainfo) = services()
+			.uiaa
+			.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
 		if !worked {
 			return Err(Error::Uiaa(uiaainfo));
 		}
 	// Success!
 	} else if let Some(json) = body.json_body {
 		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
-		services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
+		services()
+			.uiaa
+			.create(sender_user, sender_device, &uiaainfo, &json)?;
 		return Err(Error::Uiaa(uiaainfo));
 	} else {
 		return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
 	}

-	services().users.set_password(sender_user, Some(&body.new_password))?;
+	services()
+		.users
+		.set_password(sender_user, Some(&body.new_password))?;

 	if body.logout_devices {
 		// Logout all devices except the current one
-		for id in services().users.all_device_ids(sender_user).filter_map(Result::ok).filter(|id| id != sender_device) {
+		for id in services()
+			.users
+			.all_device_ids(sender_user)
+			.filter_map(Result::ok)
+			.filter(|id| id != sender_device)
+		{
 			services().users.remove_device(sender_user, &id)?;
 		}
 	}

 	info!("User {} changed their password.", sender_user);
-	services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
-		"User {sender_user} changed their password."
-	)));
+	services()
+		.admin
+		.send_message(RoomMessageEventContent::notice_plain(format!(
+			"User {sender_user} changed their password."
+		)));

 	Ok(change_password::v3::Response {})
 }
@@ -431,14 +473,18 @@ pub async fn deactivate_route(body: Ruma) -> Result) -> Result) -> Result
 		return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
 	}

-	if services().globals.forbidden_alias_names().is_match(body.room_alias.alias()) {
+	if services()
+		.globals
+		.forbidden_alias_names()
+		.is_match(body.room_alias.alias())
+	{
 		return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
 	}

-	if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
+	if services()
+		.rooms
+		.alias
+		.resolve_local_alias(&body.room_alias)?
+		.is_some()
+	{
 		return Err(Error::Conflict("Alias already exists."));
 	}

-	if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() {
+	if services()
+		.rooms
+		.alias
+		.set_alias(&body.room_alias, &body.room_id)
+		.is_err()
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
 			"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
@@ -50,11 +64,21 @@ pub async fn delete_alias_route(body: Ruma) -> Result
 		return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
 	}

-	if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() {
+	if services()
+		.rooms
+		.alias
+		.resolve_local_alias(&body.room_alias)?
+		.is_none()
+	{
 		return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
 	}

-	if services().rooms.alias.remove_alias(&body.room_alias).is_err() {
+	if services()
+		.rooms
+		.alias
+		.remove_alias(&body.room_alias)
+		.is_err()
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
 			"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
@@ -90,13 +114,20 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result Result = Vec::new();
 			// find active servers in room state cache to suggest
-			for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(Result::ok) {
+			for extra_servers in services()
+				.rooms
+				.state_cache
+				.room_servers(&room_id)
+				.filter_map(Result::ok)
+			{
 				servers.push(extra_servers);
 			}

 			// insert our server as the very first choice if in list
-			if let Some(server_index) =
-				servers.clone().into_iter().position(|server| server == services().globals.server_name())
+			if let Some(server_index) = servers
+				.clone()
+				.into_iter()
+				.position(|server| server == services().globals.server_name())
 			{
 				servers.remove(server_index);
 				servers.insert(0, services().globals.server_name().to_owned());
diff --git a/src/api/client_server/backup.rs b/src/api/client_server/backup.rs
index 3e35da4f..517a2f60 100644
--- a/src/api/client_server/backup.rs
+++ b/src/api/client_server/backup.rs
@@ -17,7 +17,9 @@ pub async fn create_backup_version_route(
 	body: Ruma,
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-	let version = services().key_backups.create_backup(sender_user, &body.algorithm)?;
+	let version = services()
+		.key_backups
+		.create_backup(sender_user, &body.algorithm)?;

 	Ok(create_backup_version::v3::Response {
 		version,
@@ -32,7 +34,9 @@ pub async fn update_backup_version_route(
 	body: Ruma,
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
-	services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?;
+	services()
+		.key_backups
+		.update_backup(sender_user, &body.version, &body.algorithm)?;

 	Ok(update_backup_version::v3::Response {})
 }
@@ -70,8 +74,13 @@ pub async fn get_backup_info_route(body: Ruma) ->
 	Ok(get_backup_info::v3::Response {
 		algorithm,
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 		version: body.version.clone(),
 	})
 }
@@ -87,7 +96,9 @@ pub async fn delete_backup_version_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().key_backups.delete_backup(sender_user, &body.version)?;
+	services()
+		.key_backups
+		.delete_backup(sender_user, &body.version)?;

 	Ok(delete_backup_version::v3::Response {})
 }
@@ -103,7 +114,12 @@ pub async fn delete_backup_version_route(
 pub async fn add_backup_keys_route(body: Ruma) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
+	if Some(&body.version)
+		!= services()
+			.key_backups
+			.get_latest_backup_version(sender_user)?
+			.as_ref()
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
 			"You may only manipulate the most recently created version of the backup.",
@@ -112,13 +128,20 @@ pub async fn add_backup_keys_route(body: Ruma) ->

 	for (room_id, room) in &body.rooms {
 		for (session_id, key_data) in &room.sessions {
-			services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
+			services()
+				.key_backups
+				.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
 		}
 	}

 	Ok(add_backup_keys::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
@@ -135,7 +158,12 @@ pub async fn add_backup_keys_for_room_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
+	if Some(&body.version)
+		!= services()
+			.key_backups
+			.get_latest_backup_version(sender_user)?
+			.as_ref()
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
 			"You may only manipulate the most recently created version of the backup.",
@@ -143,12 +171,19 @@ pub async fn add_backup_keys_for_room_route(
 	}

 	for (session_id, key_data) in &body.sessions {
-		services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
+		services()
+			.key_backups
+			.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
 	}

 	Ok(add_backup_keys_for_room::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
@@ -165,18 +200,30 @@ pub async fn add_backup_keys_for_session_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
+	if Some(&body.version)
+		!= services()
+			.key_backups
+			.get_latest_backup_version(sender_user)?
+			.as_ref()
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::InvalidParam,
 			"You may only manipulate the most recently created version of the backup.",
 		));
 	}

-	services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
+	services()
+		.key_backups
+		.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;

 	Ok(add_backup_keys_for_session::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
@@ -201,7 +248,9 @@ pub async fn get_backup_keys_for_room_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?;
+	let sessions = services()
+		.key_backups
+		.get_room(sender_user, &body.version, &body.room_id)?;

 	Ok(get_backup_keys_for_room::v3::Response {
 		sessions,
@@ -216,10 +265,13 @@ pub async fn get_backup_keys_for_session_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	let key_data =
-		services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or(
-			Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."),
-		)?;
+	let key_data = services()
+		.key_backups
+		.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
+		.ok_or(Error::BadRequest(
+			ErrorKind::NotFound,
+			"Backup key not found for this user's session.",
+		))?;

 	Ok(get_backup_keys_for_session::v3::Response {
 		key_data,
@@ -234,11 +286,18 @@ pub async fn delete_backup_keys_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().key_backups.delete_all_keys(sender_user, &body.version)?;
+	services()
+		.key_backups
+		.delete_all_keys(sender_user, &body.version)?;

 	Ok(delete_backup_keys::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
@@ -250,11 +309,18 @@ pub async fn delete_backup_keys_for_room_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?;
+	services()
+		.key_backups
+		.delete_room_keys(sender_user, &body.version, &body.room_id)?;

 	Ok(delete_backup_keys_for_room::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
@@ -266,10 +332,17 @@ pub async fn delete_backup_keys_for_session_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
+	services()
+		.key_backups
+		.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;

 	Ok(delete_backup_keys_for_session::v3::Response {
-		count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
-		etag: services().key_backups.get_etag(sender_user, &body.version)?,
+		count: (services()
+			.key_backups
+			.count_keys(sender_user, &body.version)? as u32)
+			.into(),
+		etag: services()
+			.key_backups
+			.get_etag(sender_user, &body.version)?,
 	})
 }
diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs
index 8e3429fd..f39f009b 100644
--- a/src/api/client_server/context.rs
+++ b/src/api/client_server/context.rs
@@ -42,7 +42,11 @@ pub async fn get_context_route(body: Ruma) -> Result) -> Result = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
+	let events_before: Vec<_> = events_before
+		.into_iter()
+		.map(|(_, pdu)| pdu.to_room_event())
+		.collect();

 	let events_after: Vec<_> = services()
 		.rooms
@@ -122,25 +131,41 @@ pub async fn get_context_route(body: Ruma) -> Result s,
-		None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"),
+		None => services()
+			.rooms
+			.state
+			.get_room_shortstatehash(&room_id)?
+			.expect("All rooms have state"),
 	};

-	let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
+	let state_ids = services()
+		.rooms
+		.state_accessor
+		.state_full_ids(shortstatehash)
+		.await?;

-	let end_token = events_after.last().map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());
+	let end_token = events_after
+		.last()
+		.map_or_else(|| base_token.stringify(), |(count, _)| count.stringify());

-	let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
+	let events_after: Vec<_> = events_after
+		.into_iter()
+		.map(|(_, pdu)| pdu.to_room_event())
+		.collect();

 	let mut state = Vec::new();

 	for (shortstatekey, id) in state_ids {
-		let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?;
+		let (event_type, state_key) = services()
+			.rooms
+			.short
+			.get_statekey_from_short(shortstatekey)?;

 		if event_type != StateEventType::RoomMember {
 			let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {
diff --git a/src/api/client_server/device.rs b/src/api/client_server/device.rs
index 8291b803..10a38a73 100644
--- a/src/api/client_server/device.rs
+++ b/src/api/client_server/device.rs
@@ -53,7 +53,9 @@ pub async fn update_device_route(body: Ruma) -> Resu

 	device.display_name.clone_from(&body.display_name);

-	services().users.update_device_metadata(sender_user, &body.device_id, &device)?;
+	services()
+		.users
+		.update_device_metadata(sender_user, &body.device_id, &device)?;

 	Ok(update_device::v3::Response {})
 }
@@ -84,20 +86,26 @@ pub async fn delete_device_route(body: Ruma) -> Resu
 	};

 	if let Some(auth) = &body.auth {
-		let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
+		let (worked, uiaainfo) = services()
+			.uiaa
+			.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
 		if !worked {
 			return Err(Error::Uiaa(uiaainfo));
 		}
 	// Success!
 	} else if let Some(json) = body.json_body {
 		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
-		services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
+		services()
+			.uiaa
+			.create(sender_user, sender_device, &uiaainfo, &json)?;
 		return Err(Error::Uiaa(uiaainfo));
 	} else {
 		return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
 	}

-	services().users.remove_device(sender_user, &body.device_id)?;
+	services()
+		.users
+		.remove_device(sender_user, &body.device_id)?;

 	Ok(delete_device::v3::Response {})
 }
@@ -130,14 +138,18 @@ pub async fn delete_devices_route(body: Ruma) -> Re
 	};

 	if let Some(auth) = &body.auth {
-		let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
+		let (worked, uiaainfo) = services()
+			.uiaa
+			.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
 		if !worked {
 			return Err(Error::Uiaa(uiaainfo));
 		}
 	// Success!
 	} else if let Some(json) = body.json_body {
 		uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
-		services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
+		services()
+			.uiaa
+			.create(sender_user, sender_device, &uiaainfo, &json)?;
 		return Err(Error::Uiaa(uiaainfo));
 	} else {
 		return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs
index c2921d12..010693cc 100644
--- a/src/api/client_server/directory.rs
+++ b/src/api/client_server/directory.rs
@@ -34,7 +34,11 @@ use crate::{services, Error, Result, Ruma};
 pub async fn get_public_rooms_filtered_route(
 	body: Ruma,
 ) -> Result {
-	if !services().globals.config.allow_public_room_directory_without_auth {
+	if !services()
+		.globals
+		.config
+		.allow_public_room_directory_without_auth
+	{
 		let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
 	}

@@ -56,7 +60,11 @@ pub async fn get_public_rooms_route(
 	body: Ruma,
 ) -> Result {
-	if !services().globals.config.allow_public_room_directory_without_auth {
+	if !services()
+		.globals
+		.config
+		.allow_public_room_directory_without_auth
+	{
 		let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
 	}

@@ -325,7 +333,11 @@ pub(crate) async fn get_public_rooms_filtered_helper(

 	let total_room_count_estimate = (all_rooms.len() as u32).into();

-	let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect();
+	let chunk: Vec<_> = all_rooms
+		.into_iter()
+		.skip(num_since as usize)
+		.take(limit as usize)
+		.collect();

 	let prev_batch = if num_since == 0 {
 		None
diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs
index 97fe13f5..203b675a 100644
--- a/src/api/client_server/keys.rs
+++ b/src/api/client_server/keys.rs
@@ -34,19 +34,29 @@ pub async fn upload_keys_route(body: Ruma) -> Result) ->
 		.users
 		.keys_changed(
 			sender_user.as_str(),
-			body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
-			Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
+			body.from
+				.parse()
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
+			Some(
+				body.to
+					.parse()
+					.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
+			),
 		)
 		.filter_map(Result::ok),
 	);

-	for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok) {
+	for room_id in services()
+		.rooms
+		.state_cache
+		.rooms_joined(sender_user)
+		.filter_map(Result::ok)
+	{
 		device_list_updates.extend(
 			services()
 				.users
 				.keys_changed(
 					room_id.as_ref(),
-					body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
-					Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
+					body.from
+						.parse()
+						.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
+					Some(
+						body.to
+							.parse()
+							.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
+					),
 				)
 				.filter_map(Result::ok),
 		);
@@ -226,7 +259,10 @@ pub(crate) async fn get_keys_helper bool>(
 		let user_id: &UserId = user_id;

 		if user_id.server_name() != services().globals.server_name() {
-			get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids));
+			get_over_federation
+				.entry(user_id.server_name())
+				.or_insert_with(Vec::new)
+				.push((user_id, device_ids));
 			continue;
 		}

@@ -251,9 +287,13 @@ pub(crate) async fn get_keys_helper bool>(
 			for device_id in device_ids {
 				let mut container = BTreeMap::new();
 				if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
-					let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
-						Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."),
-					)?;
+					let metadata = services()
+						.users
+						.get_device_metadata(user_id, device_id)?
+						.ok_or(Error::BadRequest(
+							ErrorKind::InvalidParam,
+							"Tried to get keys for nonexistent device.",
+						))?;

 					add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
 						.map_err(|_| Error::bad_database("invalid device keys in database"))?;
@@ -263,11 +303,16 @@ pub(crate) async fn get_keys_helper bool>(
 			}
 		}

-		if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? {
+		if let Some(master_key) = services()
+			.users
+			.get_master_key(sender_user, user_id, &allowed_signatures)?
+		{
 			master_keys.insert(user_id.to_owned(), master_key);
 		}
 		if let Some(self_signing_key) =
-			services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
+			services()
+				.users
+				.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
 		{
 			self_signing_keys.insert(user_id.to_owned(), self_signing_key);
 		}
@@ -281,7 +326,13 @@ pub(crate) async fn get_keys_helper bool>(
 	let mut failures = BTreeMap::new();

 	let back_off = |id| async {
-		match services().globals.bad_query_ratelimiter.write().await.entry(id) {
+		match services()
+			.globals
+			.bad_query_ratelimiter
+			.write()
+			.await
+			.entry(id)
+		{
 			hash_map::Entry::Vacant(e) => {
 				e.insert((Instant::now(), 1));
 			},
@@ -292,7 +343,13 @@ pub(crate) async fn get_keys_helper bool>(
 	let mut futures: FuturesUnordered<_> = get_over_federation
 		.into_iter()
 		.map(|(server, vec)| async move {
-			if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().await.get(server) {
+			if let Some((time, tries)) = services()
+				.globals
+				.bad_query_ratelimiter
+				.read()
+				.await
+				.get(server)
+			{
 				// Exponential backoff
 				let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 				if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -336,7 +393,9 @@ pub(crate) async fn get_keys_helper bool>(
 					let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;

 					if let Some(our_master_key) =
-						services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
+						services()
+							.users
+							.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
 					{
 						let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
 						master_key.signatures.extend(our_master_key.signatures);
@@ -404,12 +463,18 @@ pub(crate) async fn claim_keys_helper(

 	for (user_id, map) in one_time_keys_input {
 		if user_id.server_name() != services().globals.server_name() {
-			get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map));
+			get_over_federation
+				.entry(user_id.server_name())
+				.or_insert_with(Vec::new)
+				.push((user_id, map));
 		}

 		let mut container = BTreeMap::new();
 		for (device_id, key_algorithm) in map {
-			if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? {
+			if let Some(one_time_keys) = services()
+				.users
+				.take_one_time_key(user_id, device_id, key_algorithm)?
+			{
 				let mut c = BTreeMap::new();
 				c.insert(one_time_keys.0, one_time_keys.1);
 				container.insert(device_id.clone(), c);
diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs
index f6bbc69b..3b7ec3c4 100644
--- a/src/api/client_server/media.rs
+++ b/src/api/client_server/media.rs
@@ -152,7 +152,10 @@ pub async fn create_content_route(body: Ruma) -> Re
 		.create(
 			Some(sender_user.clone()),
 			mxc.clone(),
-			body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
+			body.filename
+				.as_ref()
+				.map(|filename| "inline; filename=".to_owned() + filename)
+				.as_deref(),
 			body.content_type.as_deref(),
 			&body.file,
 		)
@@ -192,7 +195,10 @@ pub async fn create_content_v1_route(
 		.create(
 			Some(sender_user.clone()),
 			mxc.clone(),
-			body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
+			body.filename
+				.as_ref()
+				.map(|filename| "inline; filename=".to_owned() + filename)
+				.as_deref(),
 			body.content_type.as_deref(),
 			&body.file,
 		)
@@ -213,7 +219,11 @@ pub async fn get_remote_content(
 ) -> Result {
 	// we'll lie to the client and say the blocked server's media was not found and
 	// log. the client has no way of telling anyways so this is a security bonus.
-	if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) {
+	if services()
+		.globals
+		.prevent_media_downloads_from()
+		.contains(&server_name.to_owned())
+	{
 		info!(
 			"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
 			mxc
@@ -451,8 +461,12 @@ pub async fn get_content_thumbnail_route(
 		.media
 		.get_thumbnail(
 			mxc.clone(),
-			body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
-			body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
+			body.width
+				.try_into()
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
+			body.height
+				.try_into()
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
 		)
 		.await?
 	{
@@ -464,7 +478,11 @@ pub async fn get_content_thumbnail_route(
 	} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
 		// we'll lie to the client and say the blocked server's media was not found and
 		// log. the client has no way of telling anyways so this is a security bonus.
-		if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
+		if services()
+			.globals
+			.prevent_media_downloads_from()
+			.contains(&body.server_name.clone())
+		{
 			info!(
 				"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
 				mxc
@@ -533,8 +551,12 @@ pub async fn get_content_thumbnail_v1_route(
 		.media
 		.get_thumbnail(
 			mxc.clone(),
-			body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
-			body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
+			body.width
+				.try_into()
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
+			body.height
+				.try_into()
+				.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
 		)
 		.await?
 	{
@@ -547,7 +569,11 @@ pub async fn get_content_thumbnail_v1_route(
 	} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
 		// we'll lie to the client and say the blocked server's media was not found and
 		// log. the client has no way of telling anyways so this is a security bonus.
-		if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
+		if services()
+			.globals
+			.prevent_media_downloads_from()
+			.contains(&body.server_name.clone())
+		{
 			info!(
 				"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
 				mxc
@@ -599,7 +625,10 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result (None, None),
@@ -749,14 +778,21 @@ async fn request_url_preview(url: &str) -> Result {
 		}
 	}

-	if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) {
+	if !response
+		.remote_addr()
+		.map_or(false, |a| url_request_allowed(&a.ip()))
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::Forbidden,
 			"Requesting from this address is forbidden",
 		));
 	}

-	let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) {
+	let content_type = match response
+		.headers()
+		.get(reqwest::header::CONTENT_TYPE)
+		.and_then(|x| x.to_str().ok())
+	{
 		Some(ct) => ct,
 		None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
 	};
@@ -777,7 +813,15 @@ async fn get_url_preview(url: &str) -> Result {
 	}

 	// ensure that only one request is made per URL
-	let mutex_request = Arc::clone(services().media.url_preview_mutex.write().await.entry(url.to_owned()).or_default());
+	let mutex_request = Arc::clone(
+		services()
+			.media
+			.url_preview_mutex
+			.write()
+			.await
+			.entry(url.to_owned())
+			.or_default(),
+	);
 	let _request_lock = mutex_request.lock().await;

 	match services().media.get_url_preview(url).await {
@@ -795,7 +839,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
 		},
 	};

-	if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
+	if ["http", "https"]
+		.iter()
+		.all(|&scheme| scheme != url.scheme().to_lowercase())
+	{
 		debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
 		return false;
 	}
@@ -826,12 +873,18 @@ fn url_preview_allowed(url_str: &str) -> bool {
 			return true;
 		}

-		if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) {
+		if allowlist_domain_contains
+			.iter()
+			.any(|domain_s| domain_s.contains(&host.clone()))
+		{
 			debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
 			return true;
 		}

-		if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) {
+		if allowlist_url_contains
+			.iter()
+			.any(|url_s| url.to_string().contains(&url_s.to_string()))
+		{
 			debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
 			return true;
 		}
@@ -850,7 +903,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
 				return true;
 			}

-			if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) {
+			if allowlist_domain_contains
+				.iter()
+				.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
+			{
 				debug!(
 					"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
 					&root_domain
diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs
index 6bc57759..35060abe 100644
--- a/src/api/client_server/membership.rs
+++ b/src/api/client_server/membership.rs
@@ -242,8 +242,15 @@ pub async fn kick_user_route(body: Ruma) -> Result) -> Result) -> Result) -> Result) -> Result) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().rooms.state_cache.forget(&body.room_id, sender_user)?;
+	services()
+		.rooms
+		.state_cache
+		.forget(&body.room_id, sender_user)?;

 	Ok(forget_room::v3::Response::new())
 }
@@ -399,7 +429,12 @@ pub async fn joined_rooms_route(body: Ruma) -> Result
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

 	Ok(joined_rooms::v3::Response {
-		joined_rooms: services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok).collect(),
+		joined_rooms: services()
+			.rooms
+			.state_cache
+			.rooms_joined(sender_user)
+			.filter_map(Result::ok)
+			.collect(),
 	})
 }
@@ -414,7 +449,11 @@ pub async fn get_member_events_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
+	if !services()
+		.rooms
+		.state_accessor
+		.user_can_see_state_events(sender_user, &body.room_id)?
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::Forbidden,
 			"You don't have permission to view this room.",
@@ -443,7 +482,11 @@ pub async fn get_member_events_route(
 pub async fn joined_members_route(body: Ruma) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
+	if !services()
+		.rooms
+		.state_accessor
+		.user_can_see_state_events(sender_user, &body.room_id)?
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::Forbidden,
 			"You don't have permission to view this room.",
@@ -451,7 +494,12 @@ pub async fn joined_members_route(body: Ruma) -> Re
 	}

 	let mut joined = BTreeMap::new();
-	for user_id in services().rooms.state_cache.room_members(&body.room_id).filter_map(Result::ok) {
+	for user_id in services()
+		.rooms
+		.state_cache
+		.room_members(&body.room_id)
+		.filter_map(Result::ok)
+	{
 		let display_name = services().users.displayname(&user_id)?;
 		let avatar_url = services().users.avatar_url(&user_id)?;

@@ -475,12 +523,23 @@ pub(crate) async fn join_room_by_id_helper(
 ) -> Result {
 	let sender_user = sender_user.expect("user is authenticated");

-	let mutex_state =
-		Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
+	let mutex_state = Arc::clone(
+		services()
+			.globals
+			.roomid_mutex_state
+			.write()
+			.await
+			.entry(room_id.to_owned())
+			.or_default(),
+	);
 	let state_lock = mutex_state.lock().await;

 	// Ask a remote server if we are not participating in this room
-	if !services().rooms.state_cache.server_in_room(services().globals.server_name(), room_id)? {
+	if !services()
+		.rooms
+		.state_cache
+		.server_in_room(services().globals.server_name(), room_id)?
+	{
 		info!("Joining {room_id} over federation.");

 		let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?;
@@ -488,7 +547,14 @@ pub(crate) async fn join_room_by_id_helper(
 		info!("make_join finished");

 		let room_version_id = match make_join_response.room_version {
-			Some(room_version) if services().globals.supported_room_versions().contains(&room_version) => room_version,
+			Some(room_version)
+				if services()
+					.globals
+					.supported_room_versions()
+					.contains(&room_version) =>
+			{
+				room_version
+			},
 			_ => return Err(Error::BadServerResponse("Room version is not supported")),
 		};

@@ -497,7 +563,11 @@ pub(crate) async fn join_room_by_id_helper(

 		let join_authorized_via_users_server = join_event_stub
 			.get("content")
-			.map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str())
+			.map(|s| {
+				s.as_object()?
+					.get("join_authorised_via_users_server")?
+					.as_str()
+			})
 			.and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok());

 		// TODO: Is origin needed?
@@ -508,7 +578,9 @@ pub(crate) async fn join_room_by_id_helper(
 		join_event_stub.insert(
 			"origin_server_ts".to_owned(),
 			CanonicalJsonValue::Integer(
-				utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"),
+				utils::millis_since_unix_epoch()
+					.try_into()
+					.expect("Timestamp is valid js_int value"),
 			),
 		);
 		join_event_stub.insert(
@@ -691,10 +763,15 @@ pub(crate) async fn join_room_by_id_helper(
 				Error::BadServerResponse("Invalid PDU in send_join response.")
 			})?;

-			services().rooms.outlier.add_pdu_outlier(&event_id, &value)?;
+			services()
+				.rooms
+				.outlier
+				.add_pdu_outlier(&event_id, &value)?;
 			if let Some(state_key) = &pdu.state_key {
-				let shortstatekey =
-					services().rooms.short.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
+				let shortstatekey = services()
+					.rooms
+					.short
+					.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?;
 				state.insert(shortstatekey, pdu.event_id.clone());
 			}
 		}
@@ -711,7 +788,10 @@ pub(crate) async fn join_room_by_id_helper(
 				Err(_) => continue,
 			};

-			services().rooms.outlier.add_pdu_outlier(&event_id, &value)?;
+			services()
+				.rooms
+				.outlier
+				.add_pdu_outlier(&event_id, &value)?;
 		}

 		info!("Running send_join auth check");
@@ -725,8 +805,13 @@ pub(crate) async fn join_room_by_id_helper(
 					.rooms
 					.timeline
 					.get_pdu(
-						state
-							.get(&services().rooms.short.get_or_create_shortstatekey(&k.to_string().into(), s).ok()?)?,
+						state.get(
+							&services()
+								.rooms
+								.short
+								.get_or_create_shortstatekey(&k.to_string().into(), s)
+								.ok()?,
+						)?,
 					)
 					.ok()?
 			},
@@ -746,12 +831,21 @@ pub(crate) async fn join_room_by_id_helper(
 			Arc::new(
 				state
 					.into_iter()
-					.map(|(k, id)| services().rooms.state_compressor.compress_state_event(k, &id))
+					.map(|(k, id)| {
+						services()
+							.rooms
+							.state_compressor
+							.compress_state_event(k, &id)
+					})
 					.collect::>()?,
 			),
 		)?;

-		services().rooms.state.force_state(room_id, statehash_before_join, new, removed, &state_lock).await?;
+		services()
+			.rooms
+			.state
+			.force_state(room_id, statehash_before_join, new, removed, &state_lock)
+			.await?;

 		info!("Updating joined counts for new room");
 		services().rooms.state_cache.update_joined_count(room_id)?;
@@ -776,14 +870,23 @@ pub(crate) async fn join_room_by_id_helper(
 		info!("Setting final room state for new room");
 		// We set the room state after inserting the pdu, so that we never have a moment
 		// in time where events in the current room state do not exist
-		services().rooms.state.set_room_state(room_id, statehash_after_join, &state_lock)?;
+		services()
+			.rooms
+			.state
+			.set_room_state(room_id, statehash_after_join, &state_lock)?;
 	} else {
 		info!("We can join locally");

 		let join_rules_event =
-			services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
+			services()
+				.rooms
+				.state_accessor
+				.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
 		let power_levels_event =
-			services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?;
+			services()
+				.rooms
+				.state_accessor
+				.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?;

 		let join_rules_event_content: Option = join_rules_event
 			.as_ref()
@@ -821,7 +924,12 @@ pub(crate) async fn join_room_by_id_helper(
 			let authorized_user = restriction_rooms
 				.iter()
 				.find_map(|restriction_room_id| {
-					if !services().rooms.state_cache.is_joined(sender_user, restriction_room_id).ok()? {
+					if !services()
+						.rooms
+						.state_cache
+						.is_joined(sender_user, restriction_room_id)
+						.ok()?
+					{
 						return None;
 					}
 					let authorized_user = power_levels_event_content
@@ -888,7 +996,10 @@ pub(crate) async fn join_room_by_id_helper(
 		};

 		if !restriction_rooms.is_empty()
-			&& servers.iter().filter(|s| *s != services().globals.server_name()).count() > 0
+			&& servers
+				.iter()
+				.filter(|s| *s != services().globals.server_name())
+				.count() > 0
 		{
 			info!(
 				"We couldn't do the join locally, maybe federation can help to satisfy the restricted join \
@@ -897,7 +1008,12 @@ pub(crate) async fn join_room_by_id_helper(
 			let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?;

 			let room_version_id = match make_join_response.room_version {
-				Some(room_version_id) if services().globals.supported_room_versions().contains(&room_version_id) => {
+				Some(room_version_id)
+					if services()
+						.globals
+						.supported_room_versions()
+						.contains(&room_version_id) =>
+				{
 					room_version_id
 				},
 				_ => return Err(Error::BadServerResponse("Room version is not supported")),
			};
				.map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?;
 			let join_authorized_via_users_server = join_event_stub
 				.get("content")
-				.map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str())
+				.map(|s| {
+					s.as_object()?
+						.get("join_authorised_via_users_server")?
+						.as_str()
+				})
 				.and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok());
 			// TODO: Is origin needed?
 			join_event_stub.insert(
@@ -916,7 +1036,9 @@ pub(crate) async fn join_room_by_id_helper(
 			join_event_stub.insert(
 				"origin_server_ts".to_owned(),
 				CanonicalJsonValue::Integer(
-					utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"),
+					utils::millis_since_unix_epoch()
+						.try_into()
+						.expect("Timestamp is valid js_int value"),
 				),
 			);
 			join_event_stub.insert(
@@ -1002,7 +1124,11 @@ pub(crate) async fn join_room_by_id_helper(
 			drop(state_lock);
 			let pub_key_map = RwLock::new(BTreeMap::new());
-			services().rooms.event_handler.fetch_required_signing_keys([&signed_value], &pub_key_map).await?;
+			services()
+				.rooms
+				.event_handler
+				.fetch_required_signing_keys([&signed_value], &pub_key_map)
+				.await?;
 			services()
 				.rooms
 				.event_handler
@@ -1065,7 +1191,13 @@ async fn validate_and_add_event_id(
 		.expect("ruma's reference hashes are valid event ids");

 	let back_off = |id| async {
-		match services().globals.bad_event_ratelimiter.write().await.entry(id) {
+		match services()
+			.globals
+			.bad_event_ratelimiter
+			.write()
+			.await
+			.entry(id)
+		{
 			Entry::Vacant(e) => {
 				e.insert((Instant::now(), 1));
 			},
@@ -1073,7 +1205,13 @@ async fn validate_and_add_event_id(
 		}
 	};

-	if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&event_id) {
+	if let Some((time, tries)) = services()
+		.globals
+		.bad_event_ratelimiter
+		.read()
+		.await
+		.get(&event_id)
+	{
 		// Exponential backoff
 		let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 		if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -1110,8 +1248,15 @@ pub(crate) async fn invite_helper(

 	if user_id.server_name() != services().globals.server_name() {
 		let (pdu, pdu_json, invite_room_state) = {
-			let mutex_state =
-				Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
+			let mutex_state = Arc::clone(
+				services()
+					.globals
+					.roomid_mutex_state
+					.write()
+					.await
+					.entry(room_id.to_owned())
+					.or_default(),
+			);
 			let state_lock = mutex_state.lock().await;

 			let content = to_raw_value(&RoomMemberEventContent {
@@ -1196,7 +1341,11 @@ pub(crate) async fn invite_helper(
 		)
 		.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?;

-		services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?;
+		services()
+			.rooms
+			.event_handler
+			.fetch_required_signing_keys([&value], &pub_key_map)
+			.await?;

 		let pdu_id: Vec = services()
 			.rooms
@@ -1221,15 +1370,26 @@ pub(crate) async fn invite_helper(
 		return Ok(());
 	}

-	if !services().rooms.state_cache.is_joined(sender_user, room_id)? {
+	if !services()
+		.rooms
+		.state_cache
+		.is_joined(sender_user, room_id)?
+	{
 		return Err(Error::BadRequest(
 			ErrorKind::Forbidden,
 			"You don't have permission to view this room.",
 		));
 	}

-	let mutex_state =
-		Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
+	let mutex_state = Arc::clone(
+		services()
+			.globals
+			.roomid_mutex_state
+			.write()
+			.await
+			.entry(room_id.to_owned())
+			.or_default(),
+	);
 	let state_lock = mutex_state.lock().await;

 	services()
@@ -1270,7 +1430,13 @@ pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> {
 		.rooms
 		.state_cache
 		.rooms_joined(user_id)
-		.chain(services().rooms.state_cache.rooms_invited(user_id).map(|t| t.map(|(r, _)| r)))
+		.chain(
+			services()
+				.rooms
+				.state_cache
+				.rooms_invited(user_id)
+				.map(|t| t.map(|(r, _)| r)),
+		)
 		.collect::>();

 	for room_id in all_rooms {
@@ -1313,12 +1479,22 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option Result<()> {
 	let (make_leave_response, remote_server) = make_leave_response_and_server?;

 	let room_version_id = match make_leave_response.room_version {
-		Some(version) if services().globals.supported_room_versions().contains(&version) => version,
+		Some(version)
+			if services()
+				.globals
+				.supported_room_versions()
+				.contains(&version) =>
+		{
+			version
+		},
 		_ => return Err(Error::BadServerResponse("Room version is not supported")),
 	};

@@ -1426,7 +1609,9 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
 	leave_event_stub.insert(
 		"origin_server_ts".to_owned(),
 		CanonicalJsonValue::Integer(
-			utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"),
+			utils::millis_since_unix_epoch()
+				.try_into()
+				.expect("Timestamp is valid js_int value"),
 		),
 	);

diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs
index d724e46a..b5d7bb3e 100644
--- a/src/api/client_server/message.rs
+++ b/src/api/client_server/message.rs
@@ -32,8 +32,15 @@ pub async fn send_message_event_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 	let sender_device = body.sender_device.as_deref();

-	let mutex_state =
-		Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
+	let mutex_state = Arc::clone(
+		services()
+			.globals
+			.roomid_mutex_state
+			.write()
+			.await
+			.entry(body.room_id.clone())
+			.or_default(),
+	);
 	let state_lock = mutex_state.lock().await;

 	// Forbid m.room.encrypted if encryption is disabled
@@ -90,7 +97,10 @@ pub async fn send_message_event_route(
 	};

 	// Check if this is a new transaction id
-	if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? {
+	if let Some(response) = services()
+		.transaction_ids
+		.existing_txnid(sender_user, sender_device, &body.txn_id)?
+	{
 		// The client might have sent a txnid of the /sendToDevice endpoint
 		// This txnid has no response associated with it
 		if response.is_empty() {
@@ -130,7 +140,9 @@ pub async fn send_message_event_route(
 		)
 		.await?;

-	services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
+	services()
+		.transaction_ids
+		.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;

 	drop(state_lock);

@@ -158,9 +170,16 @@ pub async fn get_message_events_route(
 		},
 	};

-	let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
+	let to = body
+		.to
+		.as_ref()
+		.and_then(|t| PduCount::try_from_string(t).ok());

-	services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from).await?;
+	services()
+		.rooms
+		.lazy_loading
+		.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
+		.await?;

 	let limit = u64::from(body.limit).min(100) as usize;

@@ -208,14 +227,21 @@ pub async fn get_message_events_route(

 			next_token = events_after.last().map(|(count, _)| count).copied();

-			let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
+			let events_after: Vec<_> = events_after
+				.into_iter()
+				.map(|(_, pdu)| pdu.to_room_event())
+				.collect();

 			resp.start = from.stringify();
 			resp.end = next_token.map(|count| count.stringify());
 			resp.chunk = events_after;
 		},
 		ruma::api::Direction::Backward => {
-			services().rooms.timeline.backfill_if_required(&body.room_id, from).await?;
+			services()
+				.rooms
+				.timeline
+				.backfill_if_required(&body.room_id, from)
+				.await?;
 			let events_before: Vec<_> = services()
 				.rooms
 				.timeline
@@ -252,7 +278,10 @@ pub async fn get_message_events_route(

 			next_token = events_before.last().map(|(count, _)| count).copied();

-			let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
+			let events_before: Vec<_> = events_before
+				.into_iter()
+				.map(|(_, pdu)| pdu.to_room_event())
+				.collect();

 			resp.start = from.stringify();
 			resp.end = next_token.map(|count| count.stringify());
diff --git a/src/api/client_server/presence.rs b/src/api/client_server/presence.rs
index 201b8ad8..bde8d2bd 100644
--- a/src/api/client_server/presence.rs
+++ b/src/api/client_server/presence.rs
@@ -46,10 +46,19 @@ pub async fn get_presence_route(body: Ruma) -> Result

 	let mut presence_event = None;

-	for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? {
+	for room_id in services()
+		.rooms
+		.user
+		.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
+	{
 		let room_id = room_id?;

-		if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? {
+		if let Some(presence) = services()
+			.rooms
+			.edus
+			.presence
+			.get_presence(&room_id, sender_user)?
+		{
 			presence_event = Some(presence);
 			break;
 		}
@@ -60,7 +69,10 @@ pub async fn get_presence_route(body: Ruma) -> Result
 			// TODO: Should ruma just use the presenceeventcontent type here?
 			status_msg: presence.content.status_msg,
 			currently_active: presence.content.currently_active,
-			last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())),
+			last_active_ago: presence
+				.content
+				.last_active_ago
+				.map(|millis| Duration::from_millis(millis.into())),
 			presence: presence.content.presence,
 		})
 	} else {
diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs
index 8b7f5a37..4a51c520 100644
--- a/src/api/client_server/profile.rs
+++ b/src/api/client_server/profile.rs
@@ -25,7 +25,10 @@ pub async fn set_displayname_route(
 ) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().users.set_displayname(sender_user, body.displayname.clone()).await?;
+	services()
+		.users
+		.set_displayname(sender_user, body.displayname.clone())
+		.await?;

 	// Send a new membership event and presence update into all joined rooms
 	let all_rooms_joined: Vec<_> = services()
@@ -64,16 +67,31 @@ pub async fn set_displayname_route(
 		.collect();

 	for (pdu_builder, room_id) in all_rooms_joined {
-		let mutex_state =
-			Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
+		let mutex_state = Arc::clone(
+			services()
+				.globals
+				.roomid_mutex_state
+				.write()
+				.await
+				.entry(room_id.clone())
+				.or_default(),
+		);
 		let state_lock = mutex_state.lock().await;

-		_ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
+		_ = services()
+			.rooms
+			.timeline
+			.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
+			.await;
 	}

 	if services().globals.allow_local_presence() {
 		// Presence update
-		services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
+		services()
+			.rooms
+			.edus
+			.presence
+			.ping_presence(sender_user, PresenceState::Online)?;
 	}

 	Ok(set_display_name::v3::Response {})
@@ -105,9 +123,18 @@ pub async fn get_displayname_route(
 			services().users.create(&body.user_id, None)?;
 		}

-		services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
-		services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
-		services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
+		services()
+			.users
+			.set_displayname(&body.user_id, response.displayname.clone())
+			.await?;
+		services()
+			.users
+			.set_avatar_url(&body.user_id, response.avatar_url.clone())
+			.await?;
+		services()
+			.users
+			.set_blurhash(&body.user_id, response.blurhash.clone())
+			.await?;

 		return Ok(get_display_name::v3::Response {
 			displayname: response.displayname,
@@ -134,9 +161,15 @@ pub async fn get_displayname_route(
 pub async fn set_avatar_url_route(body: Ruma) -> Result {
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");

-	services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?;
+	services()
+		.users
+		.set_avatar_url(sender_user, body.avatar_url.clone())
+		.await?;

-	services().users.set_blurhash(sender_user, body.blurhash.clone()).await?;
+	services()
+		.users
+		.set_blurhash(sender_user, body.blurhash.clone())
+		.await?;

 	// Send a new membership event and presence update into all joined rooms
 	let all_joined_rooms: Vec<_> = services()
@@ -175,16 +208,31 @@ pub async fn set_avatar_url_route(body: Ruma) -> Re
 		.collect();

 	for (pdu_builder, room_id) in all_joined_rooms {
-		let mutex_state =
-			Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
+		let mutex_state = Arc::clone(
let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; - _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; + _ = services() + .rooms + .timeline + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) + .await; } if services().globals.allow_local_presence() { // Presence update - services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?; + services() + .rooms + .edus + .presence + .ping_presence(sender_user, PresenceState::Online)?; } Ok(set_avatar_url::v3::Response {}) @@ -214,9 +262,18 @@ pub async fn get_avatar_url_route(body: Ruma) -> Re services().users.create(&body.user_id, None)?; } - services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; - services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; - services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; + services() + .users + .set_displayname(&body.user_id, response.displayname.clone()) + .await?; + services() + .users + .set_avatar_url(&body.user_id, response.avatar_url.clone()) + .await?; + services() + .users + .set_blurhash(&body.user_id, response.blurhash.clone()) + .await?; return Ok(get_avatar_url::v3::Response { avatar_url: response.avatar_url, @@ -261,9 +318,18 @@ pub async fn get_profile_route(body: Ruma) -> Result) -> Result .map_err(|_| Error::bad_database("Invalid account data event in db."))? .content; - let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into); + let rule = account_data + .global + .get(body.kind.clone(), &body.rule_id) + .map(Into::into); if let Some(rule) = rule { Ok(get_pushrule::v3::Response { @@ -83,7 +86,10 @@ pub async fn set_pushrule_route(body: Ruma) -> Result .map_err(|_| Error::bad_database("Invalid account data event in db."))?; if let Err(error) = - account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref()) + account_data + .content + .global + .insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref()) { let err = match error { InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( @@ -178,7 +184,12 @@ pub async fn set_pushrule_actions_route( let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() { + if account_data + .content + .global + .set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()) + .is_err() + { return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } @@ -249,7 +260,12 @@ pub async fn set_pushrule_enabled_route( let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() { + if account_data + .content + .global + .set_enabled(body.kind.clone(), &body.rule_id, body.enabled) + .is_err() + { return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); } @@ -284,7 +300,11 @@ pub async fn delete_pushrule_route(body: Ruma) -> let mut account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if let Err(error) = 
account_data.content.global.remove(body.kind.clone(), &body.rule_id) { + if let Err(error) = account_data + .content + .global + .remove(body.kind.clone(), &body.rule_id) + { let err = match error { RemovePushRuleError::ServerDefault => { Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.") @@ -325,7 +345,9 @@ pub async fn get_pushers_route(body: Ruma) -> Result) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services().pusher.set_pusher(sender_user, body.action.clone())?; + services() + .pusher + .set_pusher(sender_user, body.action.clone())?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index a6097c17..210a6c55 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -36,7 +36,10 @@ pub async fn set_read_marker_route(body: Ruma) -> } if body.private_read_receipt.is_some() || body.read_receipt.is_some() { - services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; + services() + .rooms + .user + .reset_notification_counts(sender_user, &body.room_id)?; } if let Some(event) = &body.private_read_receipt { @@ -54,7 +57,11 @@ pub async fn set_read_marker_route(body: Ruma) -> }, PduCount::Normal(c) => c, }; - services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; + services() + .rooms + .edus + .read_receipt + .private_read_set(&body.room_id, sender_user, count)?; } if let Some(event) = &body.read_receipt { @@ -98,7 +105,10 @@ pub async fn create_receipt_route(body: Ruma) -> Re &body.receipt_type, create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate ) { - services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; + services() + .rooms + .user + .reset_notification_counts(sender_user, &body.room_id)?; } match body.receipt_type { @@ -156,7 +166,11 @@ pub async fn create_receipt_route(body: Ruma) -> Re }, PduCount::Normal(c) => c, }; - services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; + services() + .rooms + .edus + .read_receipt + .private_read_set(&body.room_id, sender_user, count)?; }, _ => return Err(Error::bad_database("Unsupported receipt type")), } diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 44e42f08..8e71bd3b 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -17,8 +17,15 @@ pub async fn redact_event_route(body: Ruma) -> Result let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(body.room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; let event_id = services() diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index c65b544e..2b40833a 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -19,21 +19,31 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route( }, }; - let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = body + .to + .as_ref() + .and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 - let limit = 
body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); + let limit = body + .limit + .and_then(|u| u32::try_from(u).ok()) + .map_or(10_usize, |u| u as usize) + .min(100); - let res = services().rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &Some(body.event_type.clone()), - &Some(body.rel_type.clone()), - from, - to, - limit, - )?; + let res = services() + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + &Some(body.event_type.clone()), + &Some(body.rel_type.clone()), + from, + to, + limit, + )?; Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, @@ -57,21 +67,31 @@ pub async fn get_relating_events_with_rel_type_route( }, }; - let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = body + .to + .as_ref() + .and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 - let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); + let limit = body + .limit + .and_then(|u| u32::try_from(u).ok()) + .map_or(10_usize, |u| u as usize) + .min(100); - let res = services().rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &Some(body.rel_type.clone()), - from, - to, - limit, - )?; + let res = services() + .rooms + .pdu_metadata + .paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + &None, + &Some(body.rel_type.clone()), + from, + to, + limit, + )?; Ok(get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, @@ -95,19 +115,20 @@ pub async fn get_relating_events_route( }, }; - let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); + let to = body + .to + .as_ref() + .and_then(|t| PduCount::try_from_string(t).ok()); // Use limit or else 10, with maximum 100 - let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); + let limit = body + .limit + .and_then(|u| u32::try_from(u).ok()) + .map_or(10_usize, |u| u as usize) + .min(100); - services().rooms.pdu_metadata.paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - &None, - &None, - from, - to, - limit, - ) + services() + .rooms + .pdu_metadata + .paginate_relations_with_filter(sender_user, &body.room_id, &body.event_id, &None, &None, from, to, limit) } diff --git a/src/api/client_server/report.rs b/src/api/client_server/report.rs index 9580524e..08ee6157 100644 --- a/src/api/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -71,18 +71,20 @@ pub async fn report_event_route(body: Ruma) -> Resu // send admin room message that we received the report with an @room ping for // urgency - services().admin.send_message(message::RoomMessageEventContent::text_html( - format!( - "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ - Reason: {}", - sender_user.to_owned(), - pdu.event_id, - pdu.room_id, - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - body.reason.as_deref().unwrap_or("") - ), - format!( + services() + .admin + .send_message(message::RoomMessageEventContent::text_html( + format!( + "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ + Reason: {}", + sender_user.to_owned(), + pdu.event_id, + pdu.room_id, + pdu.sender.clone(), + 
body.score.unwrap_or_else(|| ruma::Int::from(0)),
+				body.reason.as_deref().unwrap_or("")
+			),
+			format!(
+				"@room Report received from: {0}\
+				  • Event Info\
+				    • Event ID: {1} 🔗\
    • Room ID: {2}\ @@ -96,7 +98,7 @@ pub async fn report_event_route(body: Ruma) -> Resu body.score.unwrap_or_else(|| ruma::Int::from(0)), HtmlEscape(body.reason.as_deref().unwrap_or("")) ), - )); + )); // even though this is kinda security by obscurity, let's still make a small // random delay sending a successful response per spec suggestion regarding diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 0d551dca..232cf209 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -80,7 +80,11 @@ pub async fn create_room_route(body: Ruma) -> Result) -> Result = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| { - // Basic checks on the room alias validity - if localpart.contains(':') { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the \ - full room alias.", - )); - } else if localpart.contains(char::is_whitespace) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained spaces which is not a valid room alias.", - )); - } else if localpart.len() > 255 { - // there is nothing spec-wise saying to check the limit of this, - // however absurdly long room aliases are guaranteed to be unreadable or done - // maliciously. there is no reason a room alias should even exceed 100 - // characters as is. generally in spec, 255 is matrix's fav number - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias is excessively long, clients may not be able to handle this. Please shorten it.", - )); - } else if localpart.contains('"') { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained `\"` which is not allowed.", - )); - } + let alias: Option = body + .room_alias_name + .as_ref() + .map_or(Ok(None), |localpart| { + // Basic checks on the room alias validity + if localpart.contains(':') { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained `:` which is not allowed. Please note that this expects a localpart, not \ + the full room alias.", + )); + } else if localpart.contains(char::is_whitespace) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained spaces which is not a valid room alias.", + )); + } else if localpart.len() > 255 { + // there is nothing spec-wise saying to check the limit of this, + // however absurdly long room aliases are guaranteed to be unreadable or done + // maliciously. there is no reason a room alias should even exceed 100 + // characters as is. generally in spec, 255 is matrix's fav number + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias is excessively long, clients may not be able to handle this. 
Please shorten it.", + )); + } else if localpart.contains('"') { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained `\"` which is not allowed.", + )); + } - // check if room alias is forbidden - if services().globals.forbidden_alias_names().is_match(localpart) { - return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); - } + // check if room alias is forbidden + if services() + .globals + .forbidden_alias_names() + .is_match(localpart) + { + return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); + } - let alias = - RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())).map_err(|e| { - warn!("Failed to parse room alias for room ID {}: {e}", room_id); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") - })?; + let alias = + RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())).map_err(|e| { + warn!("Failed to parse room alias for room ID {}: {e}", room_id); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") + })?; - if services().rooms.alias.resolve_local_alias(&alias)?.is_some() { - Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")) - } else { - Ok(Some(alias)) - } - })?; + if services() + .rooms + .alias + .resolve_local_alias(&alias)? + .is_some() + { + Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")) + } else { + Ok(Some(alias)) + } + })?; let room_version = match body.room_version.clone() { Some(room_version) => { - if services().globals.supported_room_versions().contains(&room_version) { + if services() + .globals + .supported_room_versions() + .contains(&room_version) + { room_version } else { return Err(Error::BadRequest( @@ -177,10 +204,12 @@ pub async fn create_room_route(body: Ruma) -> Result { - let mut content = content.deserialize_as::().map_err(|e| { - error!("Failed to deserialise content as canonical JSON: {}", e); - Error::bad_database("Failed to deserialise content as canonical JSON.") - })?; + let mut content = content + .deserialize_as::() + .map_err(|e| { + error!("Failed to deserialise content as canonical JSON: {}", e); + Error::bad_database("Failed to deserialise content as canonical JSON.") + })?; match room_version { RoomVersionId::V1 | RoomVersionId::V2 @@ -256,8 +285,11 @@ pub async fn create_room_route(body: Ruma) -> Result(to_raw_value(&content).expect("Invalid creation content").get()); + let de_result = serde_json::from_str::( + to_raw_value(&content) + .expect("Invalid creation content") + .get(), + ); if de_result.is_err() { return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")); @@ -463,7 +495,11 @@ pub async fn create_room_route(body: Ruma) -> Result) -> Result) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services() + .rooms + .timeline + .get_pdu(&body.event_id)? + .ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; - if !services().rooms.state_accessor.user_can_see_event(sender_user, &event.room_id, &body.event_id)? { + if !services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &event.room_id, &body.event_id)? 
+ { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this event.", @@ -567,7 +611,11 @@ pub async fn get_room_event_route(body: Ruma) -> Re pub async fn get_room_aliases_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + if !services() + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -575,7 +623,12 @@ pub async fn get_room_aliases_route(body: Ruma) -> Result< } Ok(aliases::v3::Response { - aliases: services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(Result::ok).collect(), + aliases: services() + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .filter_map(Result::ok) + .collect(), }) } @@ -592,7 +645,11 @@ pub async fn get_room_aliases_route(body: Ruma) -> Result< pub async fn upgrade_room_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().globals.supported_room_versions().contains(&body.new_version) { + if !services() + .globals + .supported_room_versions() + .contains(&body.new_version) + { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, "This server does not support that room version.", @@ -601,10 +658,20 @@ pub async fn upgrade_room_route(body: Ruma) -> Result // Create a replacement room let replacement_room = RoomId::new(services().globals.server_name()); - services().rooms.short.get_or_create_shortroomid(&replacement_room)?; + services() + .rooms + .short + .get_or_create_shortroomid(&replacement_room)?; - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(body.room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // Send a m.room.tombstone event to the old room to indicate that it is not @@ -633,8 +700,15 @@ pub async fn upgrade_room_route(body: Ruma) -> Result // Change lock to replacement room drop(state_lock); - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(replacement_room.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(replacement_room.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // Get the old room creation event @@ -704,7 +778,9 @@ pub async fn upgrade_room_route(body: Ruma) -> Result // Validate creation event content let de_result = serde_json::from_str::( - to_raw_value(&create_event_content).expect("Error forming creation event").get(), + to_raw_value(&create_event_content) + .expect("Error forming creation event") + .get(), ); if de_result.is_err() { @@ -771,7 +847,11 @@ pub async fn upgrade_room_route(body: Ruma) -> Result // Replicate transferable state events to the new room for event_type in transferable_state_events { - let event_content = match services().rooms.state_accessor.room_state_get(&body.room_id, &event_type, "")? { + let event_content = match services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &event_type, "")? + { Some(v) => v.content.clone(), None => continue, // Skipping missing events. 
}; @@ -795,8 +875,16 @@ pub async fn upgrade_room_route(body: Ruma) -> Result } // Moves any local aliases to the new room - for alias in services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(Result::ok) { - services().rooms.alias.set_alias(&alias, &replacement_room)?; + for alias in services() + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .filter_map(Result::ok) + { + services() + .rooms + .alias + .set_alias(&alias, &replacement_room)?; } // Get the old room power levels diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index e8c324b9..e96a1712 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -29,10 +29,14 @@ pub async fn search_events_route(body: Ruma) -> Resu let filter = &search_criteria.filter; let include_state = &search_criteria.include_state; - let room_ids = filter - .rooms - .clone() - .unwrap_or_else(|| services().rooms.state_cache.rooms_joined(sender_user).filter_map(Result::ok).collect()); + let room_ids = filter.rooms.clone().unwrap_or_else(|| { + services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(Result::ok) + .collect() + }); // Use limit or else 10, with maximum 100 let limit = filter.limit.map_or(10, u64::from).min(100) as usize; @@ -41,7 +45,11 @@ pub async fn search_events_route(body: Ruma) -> Resu if include_state.is_some_and(|include_state| include_state) { for room_id in &room_ids { - if !services().rooms.state_cache.is_joined(sender_user, room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -49,7 +57,11 @@ pub async fn search_events_route(body: Ruma) -> Resu } // check if sender_user can see state events - if services().rooms.state_accessor.user_can_see_state_events(sender_user, room_id)? { + if services() + .rooms + .state_accessor + .user_can_see_state_events(sender_user, room_id)? + { let room_state = services() .rooms .state_accessor @@ -74,14 +86,22 @@ pub async fn search_events_route(body: Ruma) -> Resu let mut searches = Vec::new(); for room_id in &room_ids { - if !services().rooms.state_cache.is_joined(sender_user, room_id)? { + if !services() + .rooms + .state_cache + .is_joined(sender_user, room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if let Some(search) = services().rooms.search.search_pdus(room_id, &search_criteria.search_term)? { + if let Some(search) = services() + .rooms + .search + .search_pdus(room_id, &search_criteria.search_term)? 
+ { searches.push(search.0.peekable()); } } diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 381f08ea..ad135d0e 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -101,7 +101,11 @@ pub async fn login_route(body: Ruma) -> Result) -> Result) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let limit = body.limit.unwrap_or_else(|| UInt::from(10_u32)).min(UInt::from(100_u32)); + let limit = body + .limit + .unwrap_or_else(|| UInt::from(10_u32)) + .min(UInt::from(100_u32)); - let max_depth = body.max_depth.unwrap_or_else(|| UInt::from(3_u32)).min(UInt::from(10_u32)); + let max_depth = body + .max_depth + .unwrap_or_else(|| UInt::from(3_u32)) + .min(UInt::from(10_u32)); - let key = body.from.as_ref().and_then(|s| PagnationToken::from_str(s).ok()); + let key = body + .from + .as_ref() + .and_then(|s| PagnationToken::from_str(s).ok()); // Should prevent unexpeded behaviour in (bad) clients if let Some(ref token) = key { diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 1720295a..f436cce6 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -86,7 +86,11 @@ pub async fn get_state_events_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + if !services() + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view the room state.", @@ -118,21 +122,30 @@ pub async fn get_state_events_for_key_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + if !services() + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id)? + { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view the room state.", )); } - let event = - services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else( - || { - warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") - }, - )?; - if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { + let event = services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &body.event_type, &body.state_key)? + .ok_or_else(|| { + warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); + Error::BadRequest(ErrorKind::NotFound, "State event not found.") + })?; + if body + .format + .as_ref() + .is_some_and(|f| f.to_lowercase().eq("event")) + { Ok(get_state_events_for_key::v3::Response { content: None, event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { @@ -164,20 +177,31 @@ pub async fn get_state_events_for_empty_key_route( ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + if !services() + .rooms + .state_accessor + .user_can_see_state_events(sender_user, &body.room_id)? 
+ { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view the room state.", )); } - let event = - services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| { + let event = services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &body.event_type, "")? + .ok_or_else(|| { warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); Error::BadRequest(ErrorKind::NotFound, "State event not found.") })?; - if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { + if body + .format + .as_ref() + .is_some_and(|f| f.to_lowercase().eq("event")) + { Ok(get_state_events_for_key::v3::Response { content: None, event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { @@ -229,8 +253,15 @@ async fn send_state_event_for_key_helper( } } - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); let state_lock = mutex_state.lock().await; let event_id = services() diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index f72da248..8d1f39f2 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -81,33 +81,38 @@ pub async fn sync_events_route( let sender_device = body.sender_device.expect("user is authenticated"); let body = body.body; - let mut rx = - match services().globals.sync_receivers.write().await.entry((sender_user.clone(), sender_device.clone())) { - Entry::Vacant(v) => { + let mut rx = match services() + .globals + .sync_receivers + .write() + .await + .entry((sender_user.clone(), sender_device.clone())) + { + Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::watch::channel(None); + + v.insert((body.since.clone(), rx.clone())); + + tokio::spawn(sync_helper_wrapper(sender_user.clone(), sender_device.clone(), body, tx)); + + rx + }, + Entry::Occupied(mut o) => { + if o.get().0 != body.since { let (tx, rx) = tokio::sync::watch::channel(None); - v.insert((body.since.clone(), rx.clone())); + o.insert((body.since.clone(), rx.clone())); + + debug!("Sync started for {sender_user}"); tokio::spawn(sync_helper_wrapper(sender_user.clone(), sender_device.clone(), body, tx)); rx - }, - Entry::Occupied(mut o) => { - if o.get().0 != body.since { - let (tx, rx) = tokio::sync::watch::channel(None); - - o.insert((body.since.clone(), rx.clone())); - - debug!("Sync started for {sender_user}"); - - tokio::spawn(sync_helper_wrapper(sender_user.clone(), sender_device.clone(), body, tx)); - - rx - } else { - o.get().1.clone() - } - }, - }; + } else { + o.get().1.clone() + } + }, + }; let we_have_to_wait = rx.borrow().is_none(); if we_have_to_wait { @@ -116,7 +121,11 @@ pub async fn sync_events_route( } } - let result = match rx.borrow().as_ref().expect("When sync channel changes it's always set to some") { + let result = match rx + .borrow() + .as_ref() + .expect("When sync channel changes it's always set to some") + { Ok(response) => Ok(response.clone()), Err(error) => Err(error.to_response()), }; @@ -134,7 +143,13 @@ async fn sync_helper_wrapper( if let Ok((_, caching_allowed)) = r { if !caching_allowed { - match services().globals.sync_receivers.write().await.entry((sender_user, sender_device)) { + match services() + .globals + .sync_receivers + .write() + .await + .entry((sender_user, sender_device)) + { 
Entry::Occupied(o) => { // Only remove if the device didn't start a different /sync already if o.get().0 == since { @@ -157,7 +172,11 @@ async fn sync_helper( ) -> Result<(sync_events::v3::Response, bool), Error> { // Presence update if services().globals.allow_local_presence() { - services().rooms.edus.presence.ping_presence(&sender_user, body.set_presence)?; + services() + .rooms + .edus + .presence + .ping_presence(&sender_user, body.set_presence)?; } // Setup watchers, so if there's no response, we can wait for them @@ -171,7 +190,10 @@ async fn sync_helper( let filter = match body.filter { None => FilterDefinition::default(), Some(Filter::FilterDefinition(filter)) => filter, - Some(Filter::FilterId(filter_id)) => services().users.get_filter(&sender_user, &filter_id)?.unwrap_or_default(), + Some(Filter::FilterId(filter_id)) => services() + .users + .get_filter(&sender_user, &filter_id)? + .unwrap_or_default(), }; let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { @@ -184,7 +206,11 @@ async fn sync_helper( let full_state = body.full_state; let mut joined_rooms = BTreeMap::new(); - let since = body.since.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); + let since = body + .since + .as_ref() + .and_then(|string| string.parse().ok()) + .unwrap_or(0); let sincecount = PduCount::Normal(since); let mut presence_updates = HashMap::new(); @@ -193,9 +219,18 @@ async fn sync_helper( let mut device_list_left = HashSet::new(); // Look for device list updates of this account - device_list_updates.extend(services().users.keys_changed(sender_user.as_ref(), since, None).filter_map(Result::ok)); + device_list_updates.extend( + services() + .users + .keys_changed(sender_user.as_ref(), since, None) + .filter_map(Result::ok), + ); - let all_joined_rooms = services().rooms.state_cache.rooms_joined(&sender_user).collect::>(); + let all_joined_rooms = services() + .rooms + .state_cache + .rooms_joined(&sender_user) + .collect::>(); // Coalesce database writes for the remainder of this scope. 
let _cork = services().globals.db.cork_and_flush()?; @@ -229,7 +264,11 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services().rooms.state_cache.rooms_left(&sender_user).collect(); + let all_left_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_left(&sender_user) + .collect(); for result in all_left_rooms { let (room_id, _) = result?; @@ -237,13 +276,23 @@ async fn sync_helper( { // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; - let left_count = services().rooms.state_cache.get_left_count(&room_id, &sender_user)?; + let left_count = services() + .rooms + .state_cache + .get_left_count(&room_id, &sender_user)?; // Left before last sync if Some(since) >= left_count { @@ -255,7 +304,10 @@ async fn sync_helper( continue; } - let since_shortstatehash = services().rooms.user.get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = services() + .rooms + .user + .get_token_shortstatehash(&room_id, since)?; let since_state_ids = match since_shortstatehash { Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, @@ -274,7 +326,11 @@ async fn sync_helper( }, }; - let left_shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(&left_event_id)? { + let left_shortstatehash = match services() + .rooms + .state_accessor + .pdu_shortstatehash(&left_event_id)? + { Some(s) => s, None => { error!("Leave event has no state"); @@ -282,10 +338,16 @@ async fn sync_helper( }, }; - let mut left_state_ids = services().rooms.state_accessor.state_full_ids(left_shortstatehash).await?; + let mut left_state_ids = services() + .rooms + .state_accessor + .state_full_ids(left_shortstatehash) + .await?; - let leave_shortstatekey = - services().rooms.short.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + let leave_shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; left_state_ids.insert(leave_shortstatekey, left_event_id); @@ -334,19 +396,33 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = services().rooms.state_cache.rooms_invited(&sender_user).collect(); + let all_invited_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_invited(&sender_user) + .collect(); for result in all_invited_rooms { let (room_id, invite_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; - let invite_count = services().rooms.state_cache.get_invite_count(&room_id, &sender_user)?; + let invite_count = services() + .rooms + .state_cache + .get_invite_count(&room_id, &sender_user)?; // Invited before last sync if Some(since) >= invite_count { @@ -388,7 +464,9 @@ async fn sync_helper( } // Remove all to-device events the device received *last 
time* - services().users.remove_to_device_events(&sender_user, &sender_device, since)?; + services() + .users + .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::v3::Response { next_batch: next_batch_string, @@ -420,9 +498,13 @@ async fn sync_helper( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, + device_one_time_keys_count: services() + .users + .count_one_time_keys(&sender_user, &sender_device)?, to_device: ToDevice { - events: services().users.get_to_device_events(&sender_user, &sender_device)?, + events: services() + .users + .get_to_device_events(&sender_user, &sender_device)?, }, // Fallback keys are not yet supported device_unused_fallback_key_types: None, @@ -453,7 +535,12 @@ async fn process_room_presence_updates( presence_updates: &mut HashMap, room_id: &RoomId, since: u64, ) -> Result<()> { // Take presence updates from this room - for (user_id, _, presence_event) in services().rooms.edus.presence.presence_since(room_id, since) { + for (user_id, _, presence_event) in services() + .rooms + .edus + .presence + .presence_since(room_id, since) + { match presence_updates.entry(user_id) { Entry::Vacant(slot) => { slot.insert(presence_event); @@ -465,11 +552,19 @@ async fn process_room_presence_updates( // Update existing presence event with more info curr_content.presence = new_content.presence; - curr_content.status_msg = new_content.status_msg.or_else(|| curr_content.status_msg.take()); + curr_content.status_msg = new_content + .status_msg + .or_else(|| curr_content.status_msg.take()); curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago); - curr_content.displayname = new_content.displayname.or_else(|| curr_content.displayname.take()); - curr_content.avatar_url = new_content.avatar_url.or_else(|| curr_content.avatar_url.take()); - curr_content.currently_active = new_content.currently_active.or(curr_content.currently_active); + curr_content.displayname = new_content + .displayname + .or_else(|| curr_content.displayname.take()); + curr_content.avatar_url = new_content + .avatar_url + .or_else(|| curr_content.avatar_url.take()); + curr_content.currently_active = new_content + .currently_active + .or(curr_content.currently_active); }, } } @@ -486,23 +581,38 @@ async fn load_joined_room( { // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch - let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.to_owned()).or_default()); + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; - let send_notification_counts = - !timeline_pdus.is_empty() || services().rooms.user.last_notification_read(sender_user, room_id)? > since; + let send_notification_counts = !timeline_pdus.is_empty() + || services() + .rooms + .user + .last_notification_read(sender_user, room_id)? 
+ > since; let mut timeline_users = HashSet::new(); for (_, event) in &timeline_pdus { timeline_users.insert(event.sender.as_str().to_owned()); } - services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount).await?; + services() + .rooms + .lazy_loading + .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) + .await?; // Database queries: @@ -513,28 +623,37 @@ async fn load_joined_room( return Err(Error::BadDatabase("Room has no state")); }; - let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, since)?; + let since_shortstatehash = services() + .rooms + .user + .get_token_shortstatehash(room_id, since)?; - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus - .is_empty() - && since_shortstatehash == Some(current_shortstatehash) - { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0); - let invited_member_count = services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0); + let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = + if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || { + let joined_member_count = services() + .rooms + .state_cache + .room_joined_count(room_id)? + .unwrap_or(0); + let invited_member_count = services() + .rooms + .state_cache + .room_invited_count(room_id)? + .unwrap_or(0); - // Recalculate heroes (first 5 members) - let mut heroes = Vec::new(); + // Recalculate heroes (first 5 members) + let mut heroes = Vec::new(); - if joined_member_count + invited_member_count <= 5 { - // Go through all PDUs and for each member event, check if the user is still - // joined or invited until we have 5 or we reach the end + if joined_member_count + invited_member_count <= 5 { + // Go through all PDUs and for each member event, check if the user is still + // joined or invited until we have 5 or we reach the end - for hero in services() + for hero in services() .rooms .timeline .all_pdus(sender_user, room_id)? @@ -565,116 +684,77 @@ async fn load_joined_room( .filter_map(Result::ok) // Filter for possible heroes .flatten() - { - if heroes.contains(&hero) || hero == sender_user.as_str() { - continue; + { + if heroes.contains(&hero) || hero == sender_user.as_str() { + continue; + } + + heroes.push(hero); } - - heroes.push(hero); } - } - Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) - }; + Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) + }; - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services() + let since_sender_member: Option = since_shortstatehash + .and_then(|shortstatehash| { + services() + .rooms + .state_accessor + .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .transpose() + }) + .transpose()? 
+ .and_then(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); + + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync + + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + + let current_state_ids = services() .rooms .state_accessor - .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + .state_full_ids(current_shortstatehash) + .await?; - let joined_since_last_sync = - since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync + let mut i = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services() + .rooms + .short + .get_statekey_from_short(shortstatekey)?; - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + if event_type != StateEventType::RoomMember { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + state_events.push(pdu); - let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; - - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - let mut i = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; - - if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - }, - }; - state_events.push(pdu); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled || full_state || timeline_users.contains(&state_key) // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || (cfg!(feature = "element_hacks") && *sender_user == state_key) - { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - }, - }; - - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } - - // Reset lazy loading because this is an initial sync - services().rooms.lazy_loading.lazy_load_reset(sender_user, sender_device, room_id)?; - - // The state_events above should contain all timeline_users, let's mark them as - // lazy loaded. 
- services() - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - (heroes, joined_member_count, invited_member_count, true, state_events) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); - - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); - - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; - let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?; - - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { + { let pdu = match services().rooms.timeline.get_pdu(&id)? { Some(pdu) => pdu, None => { @@ -683,137 +763,210 @@ async fn load_joined_room( }, }; - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse(pdu.state_key.as_ref().expect("State event has state key").clone()) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - }, - Err(e) => error!("Invalid state key for member event: {}", e), - } + // This check is in case a bad user ID made it into the database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); } - state_events.push(pdu); - tokio::task::yield_now().await; + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } } } - } - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } + // Reset lazy loading because this is an initial sync + services() + .rooms + .lazy_loading + .lazy_load_reset(sender_user, sender_device, room_id)?; - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant - { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); + // The state_events above should contain all timeline_users, let's mark them as + // lazy loaded. + services() + .rooms + .lazy_loading + .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) + .await; + + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.unwrap(); + + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); + + if since_shortstatehash != current_shortstatehash { + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + let since_state_ids = services() + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .await?; + + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let pdu = match services().rooms.timeline.get_pdu(&id)? 
{ + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse( + pdu.state_key + .as_ref() + .expect("State event has state key") + .clone(), + ) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), + } + } + + state_events.push(pdu); + tokio::task::yield_now().await; + } } } - } - services() - .rooms - .lazy_loading - .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) - .await; - - let encrypted_room = services() - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); - - let since_encryption = - services().rooms.state_accessor.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; - - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); - - let send_member_count = state_events.iter().any(|event| event.kind == TimelineEventType::RoomMember); - - if encrypted_room { - for state_event in &state_events { - if state_event.kind != TimelineEventType::RoomMember { + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { continue; } - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()) - .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - - if user_id == sender_user { - continue; - } - - let new_membership = serde_json::from_str::(state_event.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; - - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(sender_user, &user_id, room_id)? { - device_list_updates.insert(user_id); - } - }, - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - }, - _ => {}, + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + room_id, + &event.sender, + )? || lazy_load_send_redundant + { + if let Some(member_event) = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomMember, + event.sender.as_str(), + )? { + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); } } } + + services() + .rooms + .lazy_loading + .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) + .await; + + let encrypted_room = services() + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? 
+ .is_some(); + + let since_encryption = services().rooms.state_accessor.state_get( + since_shortstatehash, + &StateEventType::RoomEncryption, + "", + )?; + + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_none(); + + let send_member_count = state_events + .iter() + .any(|event| event.kind == TimelineEventType::RoomMember); + + if encrypted_room { + for state_event in &state_events { + if state_event.kind != TimelineEventType::RoomMember { + continue; + } + + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; + + if user_id == sender_user { + continue; + } + + let new_membership = + serde_json::from_str::(state_event.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; + + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(sender_user, &user_id, room_id)? { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } + + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services() + .rooms + .state_cache + .room_members(room_id) + .flatten() + .filter(|user_id| { + // Don't send key updates from the sender to the sender + sender_user != user_id + }) + .filter(|user_id| { + // Only send keys if the sender doesn't share an encrypted room with the target + // already + !share_encrypted_room(sender_user, user_id, room_id).unwrap_or(false) + }), + ); + } + + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts()? + } else { + (None, None, Vec::new()) + }; + + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) } - - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services() - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target - // already - !share_encrypted_room(sender_user, user_id, room_id).unwrap_or(false) - }), - ); - } - - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? 
- } else { - (None, None, Vec::new()) - }; - - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; + }; // Look for device list updates in this room - device_list_updates.extend(services().users.keys_changed(room_id.as_ref(), since, None).filter_map(Result::ok)); + device_list_updates.extend( + services() + .users + .keys_changed(room_id.as_ref(), since, None) + .filter_map(Result::ok), + ); let notification_count = if send_notification_counts { Some( @@ -841,17 +994,22 @@ async fn load_joined_room( None }; - let prev_batch = timeline_pdus.first().map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })?; + let prev_batch = timeline_pdus + .first() + .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + }, + PduCount::Normal(c) => c.to_string(), + })) + })?; - let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); + let room_events: Vec<_> = timeline_pdus + .iter() + .map(|(_, pdu)| pdu.to_sync_room_event()) + .collect(); let mut edus: Vec<_> = services() .rooms @@ -862,7 +1020,13 @@ async fn load_joined_room( .map(|(_, _, v)| v) .collect(); - if services().rooms.edus.typing.last_typing_update(room_id).await? > since { + if services() + .rooms + .edus + .typing + .last_typing_update(room_id) + .await? > since + { edus.push( serde_json::from_str( &serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id).await?) @@ -874,7 +1038,10 @@ async fn load_joined_room( // Save the state after this sync so we can send the correct state diff next // sync - services().rooms.user.associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; + services() + .rooms + .user + .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; Ok(JoinedRoom { account_data: RoomAccountData { @@ -904,7 +1071,10 @@ async fn load_joined_room( events: room_events, }, state: State { - events: state_events.iter().map(|pdu| pdu.to_sync_state_event()).collect(), + events: state_events + .iter() + .map(|pdu| pdu.to_sync_state_event()) + .collect(), }, ephemeral: Ephemeral { events: edus, @@ -918,7 +1088,12 @@ fn load_timeline( ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { let timeline_pdus; let limited; - if services().rooms.timeline.last_timeline_count(sender_user, room_id)? > roomsincecount { + if services() + .rooms + .timeline + .last_timeline_count(sender_user, room_id)? 
+ > roomsincecount
+ {
 let mut non_timeline_pdus = services()
 .rooms
 .timeline
@@ -933,8 +1108,13 @@ fn load_timeline(
 .take_while(|(pducount, _)| pducount > &roomsincecount);

 // Take the last events for the timeline
- timeline_pdus =
- non_timeline_pdus.by_ref().take(limit as usize).collect::<Vec<_>>().into_iter().rev().collect::<Vec<_>>();
+ timeline_pdus = non_timeline_pdus
+ .by_ref()
+ .take(limit as usize)
+ .collect::<Vec<_>>()
+ .into_iter()
+ .rev()
+ .collect::<Vec<_>>();

 // The /sync response doesn't always return all messages, so we say the output
 // is limited unless there are events in non_timeline_pdus
@@ -980,7 +1160,11 @@ pub async fn sync_events_v4_route(

 let next_batch = services().globals.next_count()?;

- let globalsince = body.pos.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0);
+ let globalsince = body
+ .pos
+ .as_ref()
+ .and_then(|string| string.parse().ok())
+ .unwrap_or(0);

 if globalsince == 0 {
 if let Some(conn_id) = &body.conn_id {
@@ -994,13 +1178,21 @@ pub async fn sync_events_v4_route(

 // Get sticky parameters from cache
 let known_rooms =
- services().users.update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body);
+ services()
+ .users
+ .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body);

- let all_joined_rooms =
- services().rooms.state_cache.rooms_joined(&sender_user).filter_map(Result::ok).collect::<Vec<_>>();
+ let all_joined_rooms = services()
+ .rooms
+ .state_cache
+ .rooms_joined(&sender_user)
+ .filter_map(Result::ok)
+ .collect::<Vec<_>>();

 if body.extensions.to_device.enabled.unwrap_or(false) {
- services().users.remove_to_device_events(&sender_user, &sender_device, globalsince)?;
+ services()
+ .users
+ .remove_to_device_events(&sender_user, &sender_device, globalsince)?;
 }

 let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
@@ -1009,8 +1201,12 @@ pub async fn sync_events_v4_route(

 if body.extensions.e2ee.enabled.unwrap_or(false) {
 // Look for device list updates of this account
- device_list_changes
- .extend(services().users.keys_changed(sender_user.as_ref(), globalsince, None).filter_map(Result::ok));
+ device_list_changes.extend(
+ services()
+ .users
+ .keys_changed(sender_user.as_ref(), globalsince, None)
+ .filter_map(Result::ok),
+ );

 for room_id in &all_joined_rooms {
 let current_shortstatehash = if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)?
{
@@ -1020,7 +1216,10 @@ pub async fn sync_events_v4_route(
 continue;
 };

- let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, globalsince)?;
+ let since_shortstatehash = services()
+ .rooms
+ .user
+ .get_token_shortstatehash(room_id, globalsince)?;

 let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash
 .and_then(|shortstatehash| {
@@ -1060,9 +1259,16 @@ pub async fn sync_events_v4_route(
 let new_encrypted_room = encrypted_room && since_encryption.is_none();

 if encrypted_room {
- let current_state_ids =
- services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?;
- let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?;
+ let current_state_ids = services()
+ .rooms
+ .state_accessor
+ .state_full_ids(current_shortstatehash)
+ .await?;
+ let since_state_ids = services()
+ .rooms
+ .state_accessor
+ .state_full_ids(since_shortstatehash)
+ .await?;

 for (key, id) in current_state_ids {
 if since_state_ids.get(&key) != Some(&id) {
@@ -1126,8 +1332,12 @@ pub async fn sync_events_v4_route(
 }
 }
 // Look for device list updates in this room
- device_list_changes
- .extend(services().users.keys_changed(room_id.as_ref(), globalsince, None).filter_map(Result::ok));
+ device_list_changes.extend(
+ services()
+ .users
+ .keys_changed(room_id.as_ref(), globalsince, None)
+ .filter_map(Result::ok),
+ );
 }
 for user_id in left_encrypted_users {
 let dont_share_encrypted_room = services()
@@ -1171,19 +1381,33 @@ pub async fn sync_events_v4_route(
 .ranges
 .into_iter()
 .map(|mut r| {
- r.0 = r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1));
- r.1 = r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1));
+ r.0 =
+ r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1));
+ r.1 =
+ r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1));
 let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec();
 new_known_rooms.extend(room_ids.iter().cloned());
 for room_id in &room_ids {
- let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX));
- let limit = list.room_details.timeline_limit.map_or(10, u64::from).min(100);
- todo_room.0.extend(list.room_details.required_state.iter().cloned());
+ let todo_room = todo_rooms
+ .entry(room_id.clone())
+ .or_insert((BTreeSet::new(), 0, u64::MAX));
+ let limit = list
+ .room_details
+ .timeline_limit
+ .map_or(10, u64::from)
+ .min(100);
+ todo_room
+ .0
+ .extend(list.room_details.required_state.iter().cloned());
 todo_room.1 = todo_room.1.max(limit);
 // 0 means unknown because it got out of date
- todo_room.2 = todo_room
- .2
- .min(known_rooms.get(&list_id).and_then(|k| k.get(room_id)).copied().unwrap_or(0));
+ todo_room.2 = todo_room.2.min(
+ known_rooms
+ .get(&list_id)
+ .and_then(|k| k.get(room_id))
+ .copied()
+ .unwrap_or(0),
+ );
 }
 sync_events::v4::SyncOp {
 op: SlidingOp::Sync,
@@ -1215,13 +1439,20 @@ pub async fn sync_events_v4_route(
 if !services().rooms.metadata.exists(room_id)?
{
 continue;
 }
- let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX));
+ let todo_room = todo_rooms
+ .entry(room_id.clone())
+ .or_insert((BTreeSet::new(), 0, u64::MAX));
 let limit = room.timeline_limit.map_or(10, u64::from).min(100);
 todo_room.0.extend(room.required_state.iter().cloned());
 todo_room.1 = todo_room.1.max(limit);
 // 0 means unknown because it got out of date
- todo_room.2 =
- todo_room.2.min(known_rooms.get("subscriptions").and_then(|k| k.get(room_id)).copied().unwrap_or(0));
+ todo_room.2 = todo_room.2.min(
+ known_rooms
+ .get("subscriptions")
+ .and_then(|k| k.get(room_id))
+ .copied()
+ .unwrap_or(0),
+ );
 known_subscription_rooms.insert(room_id.clone());
 }

@@ -1279,11 +1510,19 @@ pub async fn sync_events_v4_route(
 }
 });

- let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect();
+ let room_events: Vec<_> = timeline_pdus
+ .iter()
+ .map(|(_, pdu)| pdu.to_sync_room_event())
+ .collect();

 let required_state = required_state_request
 .iter()
- .map(|state| services().rooms.state_accessor.room_state_get(room_id, &state.0, &state.1))
+ .map(|state| {
+ services()
+ .rooms
+ .state_accessor
+ .room_state_get(room_id, &state.0, &state.1)
+ })
 .filter_map(Result::ok)
 .flatten()
 .map(|state| state.to_sync_state_event())
@@ -1298,12 +1537,18 @@ pub async fn sync_events_v4_route(
 .filter(|member| member != &sender_user)
 .map(|member| {
 Ok::<_, Error>(
- services().rooms.state_accessor.get_member(room_id, &member)?.map(|memberevent| {
- (
- memberevent.displayname.unwrap_or_else(|| member.to_string()),
- memberevent.avatar_url,
- )
- }),
+ services()
+ .rooms
+ .state_accessor
+ .get_member(room_id, &member)?
+ .map(|memberevent| {
+ (
+ memberevent
+ .displayname
+ .unwrap_or_else(|| member.to_string()),
+ memberevent.avatar_url,
+ )
+ }),
 )
 })
 .filter_map(Result::ok)
@@ -1313,7 +1558,13 @@ pub async fn sync_events_v4_route(
 let name = match heroes.len().cmp(&(1_usize)) {
 Ordering::Greater => {
 let last = heroes[0].0.clone();
- Some(heroes[1..].iter().map(|h| h.0.clone()).collect::<Vec<_>>().join(", ") + " and " + &last)
+ Some(
+ heroes[1..]
+ .iter()
+ .map(|h| h.0.clone())
+ .collect::<Vec<_>>()
+ .join(", ") + " and " + &last,
+ )
 },
 Ordering::Equal => Some(heroes[0].0.clone()),
 Ordering::Less => None,
@@ -1364,10 +1615,20 @@ pub async fn sync_events_v4_route(
 prev_batch,
 limited,
 joined_count: Some(
- (services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0) as u32).into(),
+ (services()
+ .rooms
+ .state_cache
+ .room_joined_count(room_id)?
+ .unwrap_or(0) as u32)
+ .into(),
 ),
 invited_count: Some(
- (services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0) as u32).into(),
+ (services()
+ .rooms
+ .state_cache
+ .room_invited_count(room_id)?
+ .unwrap_or(0) as u32) + .into(), ), num_live: None, // Count events in timeline greater than global sync counter timestamp: None, @@ -1375,7 +1636,10 @@ pub async fn sync_events_v4_route( ); } - if rooms.iter().all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) { + if rooms + .iter() + .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) + { // Hang a few seconds so requests are not spammed // Stop hanging if new info arrives let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); @@ -1394,7 +1658,9 @@ pub async fn sync_events_v4_route( extensions: sync_events::v4::Extensions { to_device: if body.extensions.to_device.enabled.unwrap_or(false) { Some(sync_events::v4::ToDevice { - events: services().users.get_to_device_events(&sender_user, &sender_device)?, + events: services() + .users + .get_to_device_events(&sender_user, &sender_device)?, next_batch: next_batch.to_string(), }) } else { @@ -1405,7 +1671,9 @@ pub async fn sync_events_v4_route( changed: device_list_changes.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, + device_one_time_keys_count: services() + .users + .count_one_time_keys(&sender_user, &sender_device)?, // Fallback keys are not yet supported device_unused_fallback_key_types: None, }, diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 406dcaa7..eb00e9fd 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -18,7 +18,9 @@ use crate::{services, Error, Result, Ruma}; pub async fn update_tag_route(body: Ruma) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + let event = services() + .account_data + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let mut tags_event = event.map_or_else( || { @@ -31,7 +33,10 @@ pub async fn update_tag_route(body: Ruma) -> Result) -> Result) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + let event = services() + .account_data + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let mut tags_event = event.map_or_else( || { @@ -84,7 +91,9 @@ pub async fn delete_tag_route(body: Ruma) -> Result) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; + let event = services() + .account_data + .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; let tags_event = event.map_or_else( || { diff --git a/src/api/client_server/threads.rs b/src/api/client_server/threads.rs index 72deebed..c752d782 100644 --- a/src/api/client_server/threads.rs +++ b/src/api/client_server/threads.rs @@ -7,10 +7,15 @@ pub async fn get_threads_route(body: Ruma) -> Result) -> Result) -> Result avatar_url: services().users.avatar_url(&user_id).ok()?, }; - let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase()); + let user_id_matches = user + .user_id + .to_string() + .to_lowercase() + .contains(&body.search_term.to_lowercase()); let user_displayname_matches = user .display_name .as_ref() - .filter(|name| 
name.to_lowercase().contains(&body.search_term.to_lowercase())) + .filter(|name| { + name.to_lowercase() + .contains(&body.search_term.to_lowercase()) + }) .is_some(); if !user_id_matches && !user_displayname_matches { @@ -44,24 +51,34 @@ pub async fn search_users_route(body: Ruma) -> Result // It's a matching user, but is the sender allowed to see them? let mut user_visible = false; - let user_is_in_public_rooms = - services().rooms.state_cache.rooms_joined(&user_id).filter_map(Result::ok).any(|room| { - services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or( - false, - |event| { + let user_is_in_public_rooms = services() + .rooms + .state_cache + .rooms_joined(&user_id) + .filter_map(Result::ok) + .any(|room| { + services() + .rooms + .state_accessor + .room_state_get(&room, &StateEventType::RoomJoinRules, "") + .map_or(false, |event| { event.map_or(false, |event| { serde_json::from_str(event.content.get()) .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) }) - }, - ) + }) }); if user_is_in_public_rooms { user_visible = true; } else { - let user_is_in_shared_rooms = - services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some(); + let user_is_in_shared_rooms = services() + .rooms + .user + .get_shared_rooms(vec![sender_user.clone(), user_id]) + .ok()? + .next() + .is_some(); if user_is_in_shared_rooms { user_visible = true; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index bd0110d5..36f184fc 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -44,14 +44,16 @@ where let (mut parts, mut body) = match req.with_limited_body() { Ok(limited_req) => { let (parts, body) = limited_req.into_parts(); - let body = - to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; (parts, body) }, Err(original_req) => { let (parts, body) = original_req.into_parts(); - let body = - to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; (parts, body) }, }; @@ -180,8 +182,10 @@ where return Err(Error::bad_config("Federation is disabled.")); } - let TypedHeader(Authorization(x_matrix)) = - parts.extract::>>().await.map_err(|e| { + let TypedHeader(Authorization(x_matrix)) = parts + .extract::>>() + .await + .map_err(|e| { warn!("Missing or invalid Authorization header: {}", e); let msg = match e.reason() { @@ -265,7 +269,11 @@ where AuthScheme::None => match parts.uri.path() { // allow_public_room_directory_without_auth "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { - if !services().globals.config.allow_public_room_directory_without_auth { + if !services() + .globals + .config + .allow_public_room_directory_without_auth + { let token = match token { Some(token) => token, _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), @@ -362,7 +370,9 @@ impl Credentials for XMatrix { "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", ); - let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start(); + let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) + .ok()? 
+ .trim_start(); let mut origin = None; let mut destination = None; @@ -374,7 +384,10 @@ impl Credentials for XMatrix { // It's not at all clear why some fields are quoted and others not in the spec, // let's simply accept either form for every field. - let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value); + let value = value + .strip_prefix('"') + .and_then(|rest| rest.strip_suffix('"')) + .unwrap_or(value); // FIXME: Catch multiple fields of the same name match name { diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 2fe5d538..36187c89 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -159,7 +159,13 @@ where let mut write_destination_to_cache = false; - let cached_result = services().globals.actual_destinations().read().await.get(destination).cloned(); + let cached_result = services() + .globals + .actual_destinations() + .read() + .await + .get(destination) + .cloned(); let (actual_destination, host) = if let Some(result) = cached_result { result @@ -196,7 +202,12 @@ where request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); request_map.insert( "uri".to_owned(), - http_request.uri().path_and_query().expect("all requests have a path").to_string().into(), + http_request + .uri() + .path_and_query() + .expect("all requests have a path") + .to_string() + .into(), ); request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); request_map.insert("destination".to_owned(), destination.as_str().into()); @@ -217,7 +228,12 @@ where .as_object() .unwrap() .values() - .map(|v| v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap()))); + .map(|v| { + v.as_object() + .unwrap() + .iter() + .map(|(k, v)| (k, v.as_str().unwrap())) + }); for signature_server in signatures { for s in signature_server { @@ -257,7 +273,12 @@ where } debug!("Sending request to {destination} at {url}"); - let response = services().globals.client.federation.execute(reqwest_request).await; + let response = services() + .globals + .client + .federation + .execute(reqwest_request) + .await; debug!("Received response from {destination} at {url}"); match response { @@ -283,10 +304,14 @@ where } let status = response.status(); - let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); mem::swap( response.headers_mut(), - http_response_builder.headers_mut().expect("http::response::Builder is usable"), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), ); debug!("Getting response bytes from {destination}"); @@ -301,11 +326,16 @@ where "Response not successful\n{} {}: {}", url, status, - String::from_utf8_lossy(&body).lines().collect::>().join(" ") + String::from_utf8_lossy(&body) + .lines() + .collect::>() + .join(" ") ); } - let http_response = http_response_builder.body(body).expect("reqwest body is valid http body"); + let http_response = http_response_builder + .body(body) + .expect("reqwest body is valid http body"); if status.is_success() { debug!("Parsing response bytes from {destination}"); @@ -490,7 +520,12 @@ async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDe } async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { - match services().globals.dns_resolver().lookup_ip(hostname.to_owned()).await { + match services() + .globals + 
.dns_resolver() + .lookup_ip(hostname.to_owned()) + .await + { Ok(override_ip) => { debug!("Caching result of {:?} overriding {:?}", hostname, overname); @@ -521,7 +556,11 @@ async fn query_srv_record(hostname: &'_ str) -> Option { async fn lookup_srv(hostname: &str) -> Result { debug!("querying SRV for {:?}", hostname); let hostname = hostname.trim_end_matches('.'); - services().globals.dns_resolver().srv_lookup(hostname.to_owned()).await + services() + .globals + .dns_resolver() + .srv_lookup(hostname.to_owned()) + .await } let first_hostname = format!("_matrix-fed._tcp.{hostname}."); @@ -539,7 +578,14 @@ async fn query_srv_record(hostname: &'_ str) -> Option { } async fn request_well_known(destination: &str) -> Option { - if !services().globals.resolver.overrides.read().unwrap().contains_key(destination) { + if !services() + .globals + .resolver + .overrides + .read() + .unwrap() + .contains_key(destination) + { query_and_cache_override(destination, destination, 8448).await; } @@ -663,7 +709,10 @@ pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_serve pub async fn get_public_rooms_filtered_route( body: Ruma, ) -> Result { - if !services().globals.allow_public_room_directory_over_federation() { + if !services() + .globals + .allow_public_room_directory_over_federation() + { return Err(Error::bad_config("Room directory is not public.")); } @@ -690,7 +739,10 @@ pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_route( body: Ruma, ) -> Result { - if !services().globals.allow_public_room_directory_over_federation() { + if !services() + .globals + .allow_public_room_directory_over_federation() + { return Err(Error::bad_config("Room directory is not public.")); } @@ -743,7 +795,10 @@ pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, Canonical pub async fn send_transaction_message_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); let mut resolved_map = BTreeMap::new(); @@ -799,8 +854,15 @@ pub async fn send_transaction_message_route( }); for (event_id, value, room_id) in parsed_pdus { - let mutex = - Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default()); + let mutex = Arc::clone( + services() + .globals + .roomid_mutex_federation + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let mutex_lock = mutex.lock().await; let start_time = Instant::now(); resolved_map.insert( @@ -831,7 +893,11 @@ pub async fn send_transaction_message_route( } } - for edu in body.edus.iter().filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { + for edu in body + .edus + .iter() + .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) + { match edu { Edu::Presence(presence) => { if !services().globals.allow_incoming_presence() { @@ -862,7 +928,13 @@ pub async fn send_transaction_message_route( .event_ids .iter() .filter_map(|id| { - services().rooms.timeline.get_pdu_count(id).ok().flatten().map(|r| (id, r)) + services() + .rooms + .timeline + .get_pdu_count(id) + .ok() + .flatten() + .map(|r| (id, r)) }) .max_by_key(|(_, count)| *count) { @@ -879,7 +951,11 @@ pub async fn send_transaction_message_route( content: ReceiptEventContent(receipt_content), room_id: room_id.clone(), }; - services().rooms.edus.read_receipt.readreceipt_update(&user_id, &room_id, event)?; + services() + .rooms + 
.edus + .read_receipt + .readreceipt_update(&user_id, &room_id, event)?; } else { // TODO fetch missing events debug!("No known event ids in read receipt: {:?}", user_updates); @@ -888,7 +964,11 @@ pub async fn send_transaction_message_route( } }, Edu::Typing(typing) => { - if services().rooms.state_cache.is_joined(&typing.user_id, &typing.room_id)? { + if services() + .rooms + .state_cache + .is_joined(&typing.user_id, &typing.room_id)? + { if typing.typing { services() .rooms @@ -897,7 +977,12 @@ pub async fn send_transaction_message_route( .typing_add(&typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch()) .await?; } else { - services().rooms.edus.typing.typing_remove(&typing.user_id, &typing.room_id).await?; + services() + .rooms + .edus + .typing + .typing_remove(&typing.user_id, &typing.room_id) + .await?; } } }, @@ -914,7 +999,11 @@ pub async fn send_transaction_message_route( messages, }) => { // Check if this is a new transaction id - if services().transaction_ids.existing_txnid(&sender, None, &message_id)?.is_some() { + if services() + .transaction_ids + .existing_txnid(&sender, None, &message_id)? + .is_some() + { continue; } @@ -952,7 +1041,9 @@ pub async fn send_transaction_message_route( } // Save transaction id with empty data - services().transaction_ids.add_txnid(&sender, None, &message_id, &[])?; + services() + .transaction_ids + .add_txnid(&sender, None, &message_id, &[])?; }, Edu::SigningKeyUpdate(SigningKeyUpdateContent { user_id, @@ -963,7 +1054,9 @@ pub async fn send_transaction_message_route( continue; } if let Some(master_key) = master_key { - services().users.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + services() + .users + .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; } }, Edu::_Custom(_) => {}, @@ -971,7 +1064,10 @@ pub async fn send_transaction_message_route( } Ok(send_transaction_message::v1::Response { - pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))).collect(), + pdus: resolved_map + .into_iter() + .map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))) + .collect(), }) } @@ -982,12 +1078,19 @@ pub async fn send_transaction_message_route( /// - Only works if a user of this server is currently invited or joined the /// room pub async fn get_event_route(body: Ruma) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services() + .rooms + .timeline + .get_pdu_json(&body.event_id)? + .ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; let room_id_str = event .get("room_id") @@ -997,11 +1100,19 @@ pub async fn get_event_route(body: Ruma) -> Result::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if !services().rooms.state_cache.server_in_room(sender_servername, room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, room_id)? 
+ { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); } - if !services().rooms.state_accessor.server_can_see_event(sender_servername, room_id, &body.event_id)? { + if !services() + .rooms + .state_accessor + .server_can_see_event(sender_servername, room_id, &body.event_id)? + { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not allowed to see event.")); } @@ -1017,15 +1128,25 @@ pub async fn get_event_route(body: Ruma) -> Result) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); debug!("Got backfill request from: {}", sender_servername); - if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); } - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; let until = body .v @@ -1047,7 +1168,10 @@ pub async fn get_backfill_route(body: Ruma) -> Result .filter_map(Result::ok) .filter(|(_, e)| { matches!( - services().rooms.state_accessor.server_can_see_event(sender_servername, &e.room_id, &e.event_id,), + services() + .rooms + .state_accessor + .server_can_see_event(sender_servername, &e.room_id, &e.event_id,), Ok(true), ) }) @@ -1069,13 +1193,23 @@ pub async fn get_backfill_route(body: Ruma) -> Result pub async fn get_missing_events_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); } - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; let mut queued_events = body.latest_events.clone(); let mut events = Vec::new(); @@ -1142,18 +1276,32 @@ pub async fn get_missing_events_route( pub async fn get_event_authorization_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); } - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; - let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services() + .rooms + .timeline + .get_pdu_json(&body.event_id)? 
+ .ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; let room_id_str = event .get("room_id") @@ -1163,7 +1311,11 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]) + .await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids @@ -1177,13 +1329,23 @@ pub async fn get_event_authorization_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_route(body: Ruma) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? + { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); } - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; let shortstatehash = services() .rooms @@ -1199,13 +1361,21 @@ pub async fn get_room_state_route(body: Ruma) -> Re .into_values() .map(|id| { PduEvent::convert_to_outgoing_federation_event( - services().rooms.timeline.get_pdu_json(&id).unwrap().unwrap(), + services() + .rooms + .timeline + .get_pdu_json(&id) + .unwrap() + .unwrap(), ) }) .collect(); - let auth_chain_ids = - services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids @@ -1227,13 +1397,23 @@ pub async fn get_room_state_route(body: Ruma) -> Re pub async fn get_room_state_ids_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(sender_servername, &body.room_id)? 
+ { return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); } - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; let shortstatehash = services() .rooms @@ -1250,8 +1430,11 @@ pub async fn get_room_state_ids_route( .map(|id| (*id).to_owned()) .collect(); - let auth_chain_ids = - services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) + .await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -1269,17 +1452,33 @@ pub async fn create_join_event_template_route( return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(body.room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = - services().rooms.state_accessor.room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; + services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option = join_rules_event .as_ref() @@ -1376,11 +1575,17 @@ async fn create_join_event( return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); } - services().rooms.event_handler.acl_check(sender_servername, room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, room_id)?; // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option = join_rules_event .as_ref() @@ -1431,16 +1636,29 @@ async fn create_join_event( let origin: OwnedServerName = serde_json::from_value( serde_json::to_value( - value.get("origin").ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?, + value + .get("origin") + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?, ) .expect("CanonicalJson is valid json value"), ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; + services() + .rooms + .event_handler + .fetch_required_signing_keys([&value], &pub_key_map) + .await?; - let mutex = - Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.to_owned()).or_default()); + let mutex = Arc::clone( + services() + .globals + .roomid_mutex_federation + .write() + .await 
+ .entry(room_id.to_owned()) + .or_default(), + ); let mutex_lock = mutex.lock().await; let pdu_id: Vec = services() .rooms @@ -1453,9 +1671,16 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; - let auth_chain_ids = - services().rooms.auth_chain.get_auth_chain(room_id, state_ids.values().cloned().collect()).await?; + let state_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await?; + let auth_chain_ids = services() + .rooms + .auth_chain + .get_auth_chain(room_id, state_ids.values().cloned().collect()) + .await?; let servers = services() .rooms @@ -1486,7 +1711,10 @@ async fn create_join_event( pub async fn create_join_event_v1_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; @@ -1501,7 +1729,10 @@ pub async fn create_join_event_v1_route( pub async fn create_join_event_v2_route( body: Ruma, ) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); let create_join_event::v1::RoomState { auth_chain, @@ -1525,11 +1756,21 @@ pub async fn create_join_event_v2_route( /// /// Invites a remote user to a room. pub async fn create_invite_route(body: Ruma) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); - services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; + services() + .rooms + .event_handler + .acl_check(sender_servername, &body.room_id)?; - if !services().globals.supported_room_versions().contains(&body.room_version) { + if !services() + .globals + .supported_room_versions() + .contains(&body.room_version) + { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: body.room_version.clone(), @@ -1619,7 +1860,11 @@ pub async fn create_invite_route(body: Ruma) -> Resu // If we are active in the room, the remote server will notify us about the join // via /send - if !services().rooms.state_cache.server_in_room(services().globals.server_name(), &body.room_id)? { + if !services() + .rooms + .state_cache + .server_in_room(services().globals.server_name(), &body.room_id)? 
+ { services() .rooms .state_cache @@ -1650,7 +1895,10 @@ pub async fn get_devices_route(body: Ruma) -> Result) -> Result) -> Result { - if body.device_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { + if body + .device_keys + .iter() + .any(|(u, _)| u.server_name() != services().globals.server_name()) + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this server.", @@ -1774,7 +2031,11 @@ pub async fn get_keys_route(body: Ruma) -> Result) -> Result { - if body.one_time_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { + if body + .one_time_keys + .iter() + .any(|(u, _)| u.server_name() != services().globals.server_name()) + { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Tried to access user from other server.", @@ -1809,10 +2070,17 @@ pub async fn well_known_server_route() -> Result { /// Gets the space tree in a depth-first manner to locate child rooms of a given /// space. pub async fn get_hierarchy_route(body: Ruma) -> Result { - let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); + let sender_servername = body + .sender_servername + .as_ref() + .expect("server is authenticated"); if services().rooms.metadata.exists(&body.room_id)? { - services().rooms.spaces.get_federation_hierarchy(&body.room_id, sender_servername, body.suggested_only).await + services() + .rooms + .spaces + .get_federation_hierarchy(&body.room_id, sender_servername, body.suggested_only) + .await } else { Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) } diff --git a/src/config/mod.rs b/src/config/mod.rs index e8310b2b..9c17a585 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -251,7 +251,11 @@ impl Config { pub fn warn_deprecated(&self) { debug!("Checking for deprecated config keys"); let mut was_deprecated = false; - for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) { + for key in self + .catchall + .keys() + .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) + { warn!("Config parameter \"{}\" is deprecated, ignoring.", key); was_deprecated = true; } @@ -268,8 +272,10 @@ impl Config { /// if there are any. 
pub fn warn_unknown_key(&self) { debug!("Checking for unknown config keys"); - for key in - self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) + for key in self + .catchall + .keys() + .filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) { warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 0c9b0e54..a2ca5fa7 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -155,7 +155,8 @@ impl KeyValueDatabaseEngine for Arc { let db = rust_rocksdb::DBWithThreadMode::::open_cf_descriptors( &db_opts, &config.database_path, - cfs.iter().map(|name| rust_rocksdb::ColumnFamilyDescriptor::new(name, db_opts.clone())), + cfs.iter() + .map(|name| rust_rocksdb::ColumnFamilyDescriptor::new(name, db_opts.clone())), )?; Ok(Arc::new(Engine { @@ -200,13 +201,15 @@ impl KeyValueDatabaseEngine for Arc { fn corked(&self) -> bool { self.corks.load(std::sync::atomic::Ordering::Relaxed) > 0 } fn cork(&self) -> Result<()> { - self.corks.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + self.corks + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); Ok(()) } fn uncork(&self) -> Result<()> { - self.corks.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + self.corks + .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); Ok(()) } @@ -293,7 +296,9 @@ impl KeyValueDatabaseEngine for Arc { format_args!( "#{} {}: {} bytes, {} files\n", info.backup_id, - DateTime::::from_timestamp(info.timestamp, 0).unwrap_or_default().to_rfc2822(), + DateTime::::from_timestamp(info.timestamp, 0) + .unwrap_or_default() + .to_rfc2822(), info.size, info.num_files, ), @@ -358,7 +363,9 @@ impl KvTree for RocksDbEngineTree<'_> { let writeoptions = rust_rocksdb::WriteOptions::default(); let lock = self.write_lock.read().unwrap(); - self.db.rocks.put_cf_opt(&self.cf(), key, value, &writeoptions)?; + self.db + .rocks + .put_cf_opt(&self.cf(), key, value, &writeoptions)?; drop(lock); @@ -465,7 +472,9 @@ impl KvTree for RocksDbEngineTree<'_> { let old = self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?; let new = utils::increment(old.as_deref()); - self.db.rocks.put_cf_opt(&self.cf(), key, &new, &writeoptions)?; + self.db + .rocks + .put_cf_opt(&self.cf(), key, &new, &writeoptions)?; drop(lock); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 3594d5a9..45723254 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -68,15 +68,18 @@ impl Engine { fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } fn read_lock(&self) -> &Connection { - self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + self.read_conn_tls + .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } fn read_lock_iterator(&self) -> &Connection { - self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + self.read_iterator_conn_tls + .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) } pub fn flush_wal(self: &Arc) -> Result<()> { - self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; + self.write_lock() + .pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; Ok(()) } } @@ -153,7 +156,9 @@ impl SqliteTable { pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box 
+ 'a> { let statement = Box::leak(Box::new( - guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(), + guard + .prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)) + .unwrap(), )); let statement_ref = NonAliasingBox(statement); @@ -161,7 +166,10 @@ impl SqliteTable { //let name = self.name.clone(); let iterator = Box::new( - statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(Result::unwrap), + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(Result::unwrap), ); Box::new(PreparedStatementIterator { @@ -293,7 +301,10 @@ impl KvTree for SqliteTable { } fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { - Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix))) + Box::new( + self.iter_from(&prefix, false) + .take_while(move |(key, _)| key.starts_with(&prefix)), + ) } fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { @@ -302,7 +313,9 @@ impl KvTree for SqliteTable { fn clear(&self) -> Result<()> { debug!("clear: running"); - self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?; + self.engine + .write_lock() + .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 87455a66..54b61742 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -18,7 +18,11 @@ impl service::account_data::Data for KeyValueDatabase { &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { - let mut prefix = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); prefix.push(0xFF); prefix.extend_from_slice(user_id.as_bytes()); prefix.push(0xFF); @@ -45,7 +49,8 @@ impl service::account_data::Data for KeyValueDatabase { let prev = self.roomusertype_roomuserdataid.get(&key)?; - self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; + self.roomusertype_roomuserdataid + .insert(&key, &roomuserdataid)?; // Remove old entry if let Some(prev) = prev { @@ -60,7 +65,11 @@ impl service::account_data::Data for KeyValueDatabase { fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, ) -> Result>> { - let mut key = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + let mut key = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); key.push(0xFF); @@ -68,7 +77,11 @@ impl service::account_data::Data for KeyValueDatabase { self.roomusertype_roomuserdataid .get(&key)? - .and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose()) + .and_then(|roomuserdataid| { + self.roomuserdataid_accountdata + .get(&roomuserdataid) + .transpose() + }) .transpose()? 
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) .transpose() @@ -81,7 +94,11 @@ impl service::account_data::Data for KeyValueDatabase { ) -> Result>> { let mut userdata = HashMap::new(); - let mut prefix = room_id.map(ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + let mut prefix = room_id + .map(ToString::to_string) + .unwrap_or_default() + .as_bytes() + .to_vec(); prefix.push(0xFF); prefix.extend_from_slice(user_id.as_bytes()); prefix.push(0xFF); diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 2e6e3a71..ead37c2b 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -6,7 +6,8 @@ impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller fn register_appservice(&self, yaml: Registration) -> Result { let id = yaml.id.as_str(); - self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + self.id_appserviceregistrations + .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; Ok(id.to_owned()) } @@ -17,7 +18,8 @@ impl service::appservice::Data for KeyValueDatabase { /// /// * `service_name` - the name you send to register the service previously fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations.remove(service_name.as_bytes())?; + self.id_appserviceregistrations + .remove(service_name.as_bytes())?; Ok(()) } @@ -44,7 +46,8 @@ impl service::appservice::Data for KeyValueDatabase { .map(move |id| { Ok(( id.clone(), - self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"), + self.get_registration(&id)? + .expect("iter_ids only returns appservices that exist"), )) }) .collect() diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index dc4988b6..246767bc 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -31,14 +31,17 @@ impl service::globals::Data for KeyValueDatabase { } fn last_check_for_updates_id(&self) -> Result { - self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) - }) + self.global + .get(LAST_CHECK_FOR_UPDATES_COUNT)? 
+ .map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) } fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + self.global + .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; Ok(()) } @@ -62,11 +65,19 @@ impl service::globals::Data for KeyValueDatabase { futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); // Events for rooms we are in - for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(Result::ok) { + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(Result::ok) + { let short_roomid = services() .rooms .short @@ -98,13 +109,19 @@ impl service::globals::Data for KeyValueDatabase { let mut roomuser_prefix = roomid_prefix.clone(); roomuser_prefix.extend_from_slice(&userid_prefix); - futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix)); + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); } let mut globaluserdata_prefix = vec![0xFF]; globaluserdata_prefix.extend_from_slice(&userid_prefix); - futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix)); + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); // More key changes (used when user is not joined to any rooms) futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); @@ -207,7 +224,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" utils::string_from_bytes( // 1. 
version - parts.next().expect("splitn always returns at least one element"), + parts + .next() + .expect("splitn always returns at least one element"), ) .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) .and_then(|version| { @@ -231,10 +250,12 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" // Not atomic, but this is not critical let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); let ServerSigningKeys { verify_keys, @@ -251,7 +272,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" )?; let mut tree = keys.verify_keys; - tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); Ok(tree) } @@ -265,7 +290,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n" .and_then(|bytes| serde_json::from_slice(&bytes).ok()) .map_or_else(BTreeMap::new, |keys: ServerSigningKeys| { let mut tree = keys.verify_keys; - tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); tree }); diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index 1eda4322..7ed1da4c 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -23,7 +23,8 @@ impl service::key_backups::Data for KeyValueDatabase { &key, &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; - self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version) } @@ -53,8 +54,10 @@ impl service::key_backups::Data for KeyValueDatabase { return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); } - self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; + self.backupid_algorithm + .insert(&key, backup_metadata.json().get().as_bytes())?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version.to_owned()) } @@ -69,8 +72,12 @@ impl service::key_backups::Data for KeyValueDatabase { .take_while(move |(k, _)| k.starts_with(&prefix)) .next() .map(|(key, _)| { - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) }) .transpose() } @@ -87,7 +94,9 @@ impl service::key_backups::Data for KeyValueDatabase { .next() .map(|(key, value)| { let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"), + 
key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), ) .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; @@ -105,10 +114,12 @@ impl service::key_backups::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) + self.backupid_algorithm + .get(&key)? + .map_or(Ok(None), |bytes| { + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) + }) } fn add_key( @@ -122,14 +133,16 @@ impl service::key_backups::Data for KeyValueDatabase { return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); } - self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; + self.backupid_etag + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); key.push(0xFF); key.extend_from_slice(session_id.as_bytes()); - self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?; + self.backupkeyid_backup + .insert(&key, key_data.json().get().as_bytes())?; Ok(()) } @@ -148,7 +161,10 @@ impl service::key_backups::Data for KeyValueDatabase { key.extend_from_slice(version.as_bytes()); Ok(utils::u64_from_bytes( - &self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?, + &self + .backupid_etag + .get(&key)? + .ok_or_else(|| Error::bad_database("Backup has no etag."))?, ) .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? .to_string()) @@ -162,27 +178,34 @@ impl service::key_backups::Data for KeyValueDatabase { let mut rooms = BTreeMap::::new(); - for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xFF); + for result in self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); - let session_id = utils::string_from_bytes( - parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - - let room_id = RoomId::parse( - utils::string_from_bytes( - parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + let session_id = utils::string_from_bytes( + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; + .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - let key_data = serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; + let room_id = RoomId::parse( + utils::string_from_bytes( + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - Ok::<_, Error>((room_id, session_id, key_data)) - }) { + let key_data = serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; + + 
Ok::<_, Error>((room_id, session_id, key_data)) + }) { let (room_id, session_id, key_data) = result?; rooms .entry(room_id) @@ -213,7 +236,9 @@ impl service::key_backups::Data for KeyValueDatabase { let mut parts = key.rsplit(|&b| b == 0xFF); let session_id = utils::string_from_bytes( - parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + parts + .next() + .ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, ) .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index af7a883a..d86b50f1 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -18,9 +18,19 @@ impl service::media::Data for KeyValueDatabase { key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes()); key.push(0xFF); - key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); + key.extend_from_slice( + content_disposition + .as_ref() + .map(|f| f.as_bytes()) + .unwrap_or_default(), + ); key.push(0xFF); - key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); + key.extend_from_slice( + content_type + .as_ref() + .map(|c| c.as_bytes()) + .unwrap_or_default(), + ); self.mediaid_file.insert(&key, &[])?; @@ -107,8 +117,9 @@ impl service::media::Data for KeyValueDatabase { }) .transpose()?; - let content_disposition_bytes = - parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; + let content_disposition_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; let content_disposition = if content_disposition_bytes.is_empty() { None @@ -139,11 +150,26 @@ impl service::media::Data for KeyValueDatabase { let mut value = Vec::::new(); value.extend_from_slice(&timestamp.as_secs().to_be_bytes()); value.push(0xFF); - value.extend_from_slice(data.title.as_ref().map(String::as_bytes).unwrap_or_default()); + value.extend_from_slice( + data.title + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); value.push(0xFF); - value.extend_from_slice(data.description.as_ref().map(String::as_bytes).unwrap_or_default()); + value.extend_from_slice( + data.description + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); value.push(0xFF); - value.extend_from_slice(data.image.as_ref().map(String::as_bytes).unwrap_or_default()); + value.extend_from_slice( + data.image + .as_ref() + .map(String::as_bytes) + .unwrap_or_default(), + ); value.push(0xFF); value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); value.push(0xFF); @@ -166,27 +192,45 @@ impl service::media::Data for KeyValueDatabase { x => x, };*/ - let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + let title = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { Some(s) if s.is_empty() => None, x => x, }; - let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + let description = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { Some(s) if s.is_empty() => None, x => x, }; - let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + let image = match values + .next() + .and_then(|b| String::from_utf8(b.to_vec()).ok()) + { Some(s) if s.is_empty() => None, x => x, }; - let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default()))
{ + let image_size = match values + .next() + .map(|b| usize::from_be_bytes(b.try_into().unwrap_or_default())) + { Some(0) => None, x => x, }; - let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) { + let image_width = match values + .next() + .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) + { Some(0) => None, x => x, }; - let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) { + let image_height = match values + .next() + .map(|b| u32::from_be_bytes(b.try_into().unwrap_or_default())) + { Some(0) => None, x => x, }; diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 537c7343..851831ec 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -53,7 +53,9 @@ impl service::pusher::Data for KeyValueDatabase { Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { let mut parts = k.splitn(2, |&b| b == 0xFF); let _senderkey = parts.next(); - let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; + let push_key = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; let push_key_string = utils::string_from_bytes(push_key) .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index e035784a..b5a976c7 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -4,7 +4,8 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result} impl service::rooms::alias::Data for KeyValueDatabase { fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { - self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?; + self.alias_roomid + .insert(alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xFF); aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); @@ -55,16 +56,20 @@ impl service::rooms::alias::Data for KeyValueDatabase { } fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { - Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; + Box::new( + self.alias_roomid + .iter() + .map(|(room_alias_bytes, room_id_bytes)| { + let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; + let room_id = utils::string_from_bytes(&room_id_bytes) + .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? 
+ .try_into() + .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - Ok((room_id, room_alias_localpart)) - })) + Ok((room_id, room_alias_localpart)) + }), + ) } } diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 89e8502f..1dacb470 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -12,18 +12,24 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { // We only save auth chains for single events in the db if key.len() == 1 { // Check DB cache - let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| { - chain - .chunks_exact(size_of::()) - .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) - .collect() - }); + let chain = self + .shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) + .collect() + }); if let Some(chain) = chain { let chain = Arc::new(chain); // Cache in RAM - self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain)); + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); return Ok(Some(chain)); } @@ -37,12 +43,18 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { if key.len() == 1 { self.shorteventid_authchain.insert( &key[0].to_be_bytes(), - &auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::>(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(), )?; } // Cache in RAM - self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); + self.auth_chain_cache + .lock() + .unwrap() + .insert(key, auth_chain); Ok(()) } diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 56feeb03..b3e458fc 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -65,7 +65,8 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None), }; - self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?; + self.roomuserid_presence + .insert(&key, &new_presence.to_json_bytes()?)?; } let timeout = match new_state { @@ -73,10 +74,12 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { _ => services().globals.config.presence_offline_timeout_s, }; - self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { - error!("Failed to add presence timer: {}", e); - Error::bad_database("Failed to add presence timer") - }) + self.presence_timer_sender + .send((user_id.to_owned(), Duration::from_secs(timeout))) + .map_err(|e| { + error!("Failed to add presence timer: {}", e); + Error::bad_database("Failed to add presence timer") + }) } fn set_presence( @@ -104,12 +107,15 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase { _ => services().globals.config.presence_offline_timeout_s, }; - self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { - error!("Failed to add presence timer: {}", e); - Error::bad_database("Failed to add presence timer") - })?; + self.presence_timer_sender + .send((user_id.to_owned(), Duration::from_secs(timeout))) + .map_err(|e| { + error!("Failed to add presence timer: {}", e); + Error::bad_database("Failed to add presence 
timer") + })?; - self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?; + self.roomuserid_presence + .insert(&key, &presence.to_json_bytes()?)?; Ok(()) } diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index ab191b39..4f6ad3af 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -18,7 +18,10 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { .iter_from(&last_possible_key, true) .take_while(|(key, _)| key.starts_with(&prefix)) .find(|(key, _)| { - key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes() + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element") + == user_id.as_bytes() }) { // This is the old room_latest self.readreceiptid_readreceipt.remove(&old)?; @@ -78,9 +81,11 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?; + self.roomuserid_privateread + .insert(&key, &count.to_be_bytes())?; - self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes()) + self.roomuserid_lastprivatereadupdate + .insert(&key, &services().globals.next_count()?.to_be_bytes()) } fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { @@ -88,11 +93,13 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| { - Ok(Some( - utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, - )) - }) + self.roomuserid_privateread + .get(&key)? + .map_or(Ok(None), |v| { + Ok(Some( + utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, + )) + }) } fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 78fba9d3..9528da1e 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -11,7 +11,12 @@ impl service::rooms::metadata::Data for KeyValueDatabase { }; // Look for PDUs in that room. - Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some()) + Ok(self + .pduid_pdu + .iter_from(&prefix, false) + .next() + .filter(|(k, _)| k.starts_with(&prefix)) + .is_some()) } fn iter_ids<'a>(&'a self) -> Box> + 'a> { diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index f45e78fc..933660e8 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -4,15 +4,19 @@ use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; impl service::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + self.eventid_outlierpdu + .get(event_id.as_bytes())? 
+ .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) } fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) } fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 44c621c7..b5c81f62 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -32,8 +32,10 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { current.extend_from_slice(&count_raw.to_be_bytes()); Ok(Box::new( - self.tofrom_relation.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( - move |(tofrom, _data)| { + self.tofrom_relation + .iter_from(&current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(tofrom, _data)| { let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; @@ -49,8 +51,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { pdu.remove_transaction_id()?; } Ok((PduCount::Normal(from), pdu)) - }, - ), + }), )) } @@ -75,6 +76,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase { } fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some()) + self.softfailedeventids + .get(event_id.as_bytes()) + .map(|o| o.is_some()) } } diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 05f6f749..f06598ae 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -23,7 +23,13 @@ impl service::rooms::search::Data for KeyValueDatabase { } fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)?
+ .expect("room exists") + .to_be_bytes() + .to_vec(); let words: Vec<_> = search_string .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 1ae774c1..6fb2e99f 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -17,20 +17,28 @@ impl service::rooms::short::Data for KeyValueDatabase { }, None => { let shorteventid = services().globals.next_count()?; - self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; shorteventid }, }; - self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short); + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), short); Ok(short) } fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { - if let Some(short) = - self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) { return Ok(Some(*short)); } @@ -48,15 +56,21 @@ impl service::rooms::short::Data for KeyValueDatabase { .transpose()?; if let Some(s) = short { - self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s); + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); } Ok(short) } fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { - if let Some(short) = - self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) { return Ok(*short); } @@ -70,19 +84,29 @@ impl service::rooms::short::Data for KeyValueDatabase { .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, None => { let shortstatekey = services().globals.next_count()?; - self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; + self.statekey_shortstatekey + .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; shortstatekey }, }; - self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short); + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); Ok(short) } fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { return Ok(Arc::clone(id)); } @@ -97,13 +121,21 @@ impl service::rooms::short::Data for KeyValueDatabase { ) .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id)); + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, Arc::clone(&event_id)); Ok(event_id) } fn 
get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) { + if let Some(id) = self + .shortstatekey_cache + .lock() + .unwrap() + .get_mut(&shortstatekey) + { return Ok(id.clone()); } @@ -114,8 +146,9 @@ impl service::rooms::short::Data for KeyValueDatabase { let mut parts = bytes.splitn(2, |&b| b == 0xFF); let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = - parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + let statekey_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { warn!("Event type in shortstatekey_statekey is invalid: {}", e); @@ -127,7 +160,10 @@ impl service::rooms::short::Data for KeyValueDatabase { let result = (event_type, state_key); - self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone()); + self.shortstatekey_cache + .lock() + .unwrap() + .insert(shortstatekey, result.clone()); Ok(result) } @@ -142,7 +178,8 @@ impl service::rooms::short::Data for KeyValueDatabase { ), None => { let shortstatehash = services().globals.next_count()?; - self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?; + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes())?; (shortstatehash, false) }, }) @@ -162,7 +199,8 @@ impl service::rooms::short::Data for KeyValueDatabase { }, None => { let short = services().globals.next_count()?; - self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; short }, }) diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index f11d2df3..79ef202f 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -7,11 +7,13 @@ use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) + self.roomid_shortstatehash + .get(room_id.as_bytes())? 
+ .map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + })?)) + }) } fn set_room_state( @@ -20,12 +22,14 @@ impl service::rooms::state::Data for KeyValueDatabase { new_shortstatehash: u64, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { - self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + self.roomid_shortstatehash + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; Ok(()) } fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { - self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + self.shorteventid_shortstatehash + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 24bf6247..3ea29ac1 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -19,7 +19,10 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { let mut result = HashMap::new(); let mut i = 0; for compressed in full_state.iter() { - let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; + let parsed = services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed)?; result.insert(parsed.0, parsed.1); i += 1; @@ -43,7 +46,10 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { let mut result = HashMap::new(); let mut i = 0; for compressed in full_state.iter() { - let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; + let (_, eventid) = services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed)?; if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { result.insert( ( @@ -71,7 +77,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { - let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? { + let shortstatekey = match services() + .rooms + .short + .get_shortstatekey(event_type, state_key)? + { Some(s) => s, None => return Ok(None), }; @@ -82,11 +92,17 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { .pop() .expect("there is always one layer") .1; - Ok( - full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| { - services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id) - }), - ) + Ok(full_state + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed) + .ok() + .map(|(_, id)| id) + })) } /// Returns a single PDU from `room_id` with key (`event_type`, @@ -100,15 +116,18 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { /// Returns the state hash for this pdu. fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? 
- .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")) - }) - .transpose() - }) + self.eventid_shorteventid + .get(event_id.as_bytes())? + .map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash") + }) + }) + .transpose() + }) } /// Returns the full room state. diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 4736e225..d37a6b9a 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -58,7 +58,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &userroom_id, &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), )?; - self.roomuserid_invitecount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.roomuserid_invitecount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_leftstate.remove(&userroom_id)?; @@ -80,7 +81,8 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { &userroom_id, &serde_json::to_vec(&Vec::>::new()).unwrap(), )?; // TODO - self.roomuserid_leftcount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.roomuserid_leftcount + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; self.userroomid_joined.remove(&userroom_id)?; self.roomuserid_joined.remove(&roomuser_id)?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -109,11 +111,16 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { invitedcount += 1; } - self.roomid_joinedcount.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + self.roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - self.roomid_invitedcount.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; + self.roomid_invitedcount + .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - self.our_real_users_cache.write().unwrap().insert(room_id.to_owned(), Arc::new(real_users)); + self.our_real_users_cache + .write() + .unwrap() + .insert(room_id.to_owned(), Arc::new(real_users)); for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) { if !joined_servers.remove(&old_joined_server) { @@ -145,19 +152,33 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.serverroomids.insert(&serverroom_id, &[])?; } - self.appservice_in_room_cache.write().unwrap().remove(room_id); + self.appservice_in_room_cache + .write() + .unwrap() + .remove(room_id); Ok(()) } #[tracing::instrument(skip(self, room_id))] fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { - let maybe = self.our_real_users_cache.read().unwrap().get(room_id).cloned(); + let maybe = self + .our_real_users_cache + .read() + .unwrap() + .get(room_id) + .cloned(); if let Some(users) = maybe { Ok(users) } else { self.update_joined_count(room_id)?; - Ok(Arc::clone(self.our_real_users_cache.read().unwrap().get(room_id).unwrap())) + Ok(Arc::clone( + self.our_real_users_cache + .read() + .unwrap() + .get(room_id) + .unwrap(), + )) } } @@ -221,8 +242,12 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { 
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { ServerName::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, ) .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) })) @@ -246,8 +271,12 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { RoomId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, ) .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) })) @@ -261,8 +290,12 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { UserId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, ) .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) })) @@ -292,13 +325,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); - Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - })) + Box::new( + self.roomuseroncejoinedids + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) + }), + ) } /// Returns an iterator over all invited members of a room. 
@@ -307,13 +348,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); - Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - })) + Box::new( + self.roomuserid_invitecount + .scan_prefix(prefix) + .map(|(key, _)| { + UserId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) + }), + ) } #[tracing::instrument(skip(self))] @@ -322,11 +371,13 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, - )) - }) + self.roomuserid_invitecount + .get(&key)? + .map_or(Ok(None), |bytes| { + Ok(Some( + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, + )) + }) } #[tracing::instrument(skip(self))] @@ -344,13 +395,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { - Box::new(self.userroomid_joined.scan_prefix(user_id.as_bytes().to_vec()).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - })) + Box::new( + self.userroomid_joined + .scan_prefix(user_id.as_bytes().to_vec()) + .map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) + }), + ) } /// Returns an iterator over all rooms a user was invited to. 
@@ -359,18 +418,26 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); - Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + Box::new( + self.userroomid_invitestate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - Ok((room_id, state)) - })) + Ok((room_id, state)) + }), + ) } #[tracing::instrument(skip(self))] @@ -413,18 +480,26 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); - Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + Box::new( + self.userroomid_leftstate + .scan_prefix(prefix) + .map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - Ok((room_id, state)) - })) + Ok((room_id, state)) + }), + ) } #[tracing::instrument(skip(self))] diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index 9be3a196..0043a1ba 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -58,6 +58,7 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase { } } - self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value) + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value) } } diff --git a/src/database/key_value/rooms/threads.rs b/src/database/key_value/rooms/threads.rs index 108c11e4..4cb2591b 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/database/key_value/rooms/threads.rs @@ -10,14 +10,22 @@ impl service::rooms::threads::Data for KeyValueDatabase { fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, ) -> 
PduEventIterResult<'a> { - let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until - 1).to_be_bytes()); Ok(Box::new( - self.threadid_userids.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( - move |(pduid, _users)| { + self.threadid_userids + .iter_from(&current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pduid, _users)| { let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; let mut pdu = services() @@ -29,13 +37,16 @@ impl service::rooms::threads::Data for KeyValueDatabase { pdu.remove_transaction_id()?; } Ok((count, pdu)) - }, - ), + }), )) } fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { - let users = participants.iter().map(|user| user.as_bytes()).collect::>().join(&[0xFF][..]); + let users = participants + .iter() + .map(|user| user.as_bytes()) + .collect::>() + .join(&[0xFF][..]); self.threadid_userids.insert(root_id, &users)?; diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 63b57f9b..0b1573b4 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -8,15 +8,22 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) { + match self + .lasttimelinecount_cache + .lock() + .unwrap() + .entry(room_id.to_owned()) + { hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) { + if let Some(last_count) = self + .pdus_until(sender_user, room_id, PduCount::max())? + .find_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) { Ok(*v.insert(last_count.0)) } else { Ok(PduCount::Normal(0)) @@ -28,7 +35,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { /// Returns the `count` of this pdu's id. fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose() + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pdu_id| pdu_count(&pdu_id)) + .transpose() } /// Returns the json of a pdu. @@ -49,7 +59,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { self.eventid_pduid .get(event_id.as_bytes())? .map(|pduid| { - self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) }) .transpose()? .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) @@ -64,7 +76,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase { self.eventid_pduid .get(event_id.as_bytes())?
.map(|pduid| { - self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + self.pduid_pdu + .get(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) }) .transpose()? .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) @@ -92,7 +106,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { )? .map(Arc::new) { - self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu)); + self.pdu_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), Arc::clone(&pdu)); Ok(Some(pdu)) } else { Ok(None) @@ -125,7 +142,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { &serde_json::to_vec(json).expect("CanonicalJsonObject is always valid"), )?; - self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count)); + self.lasttimelinecount_cache + .lock() + .unwrap() + .insert(pdu.room_id.clone(), PduCount::Normal(count)); self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; @@ -156,7 +176,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); } - self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned()); + self.pdu_cache + .lock() + .unwrap() + .remove(&(*pdu.event_id).to_owned()); Ok(()) } @@ -172,8 +195,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let user_id = user_id.to_owned(); Ok(Box::new( - self.pduid_pdu.iter_from(&current, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( - move |(pdu_id, v)| { + self.pduid_pdu + .iter_from(&current, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; if pdu.sender != user_id { @@ -182,8 +207,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { pdu.add_age()?; let count = pdu_count(&pdu_id)?; Ok((count, pdu)) - }, - ), + }), )) } @@ -195,8 +219,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let user_id = user_id.to_owned(); Ok(Box::new( - self.pduid_pdu.iter_from(&current, false).take_while(move |(k, _)| k.starts_with(&prefix)).map( - move |(pdu_id, v)| { + self.pduid_pdu + .iter_from(&current, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(move |(pdu_id, v)| { let mut pdu = serde_json::from_slice::(&v) .map_err(|_| Error::bad_database("PDU in db is invalid."))?; if pdu.sender != user_id { @@ -205,8 +231,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { pdu.add_age()?; let count = pdu_count(&pdu_id)?; Ok((count, pdu)) - }, - ), + }), )) } @@ -228,8 +253,10 @@ impl service::rooms::timeline::Data for KeyValueDatabase { highlights_batch.push(userroom_id); } - self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?; - self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?; + self.userroomid_notificationcount + .increment_batch(&mut notifies_batch.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights_batch.into_iter())?; Ok(()) } } diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4fa4dc14..d773a577 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -11,10 +11,13 @@ impl service::rooms::user::Data for KeyValueDatabase { roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_notificationcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount + .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.roomuserid_lastnotificationread + .insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; Ok(()) } @@ -24,9 +27,11 @@ impl service::rooms::user::Data for KeyValueDatabase { userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_notificationcount.get(&userroom_id)?.map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) + self.userroomid_notificationcount + .get(&userroom_id)? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) } fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { @@ -34,9 +39,11 @@ impl service::rooms::user::Data for KeyValueDatabase { userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_highlightcount.get(&userroom_id)?.map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) + self.userroomid_highlightcount + .get(&userroom_id)? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) } fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { @@ -56,16 +63,25 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { - let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? + .expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); - self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes()) + self.roomsynctoken_shortstatehash + .insert(&key, &shortstatehash.to_be_bytes()) } fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); + let shortroomid = services() + .rooms + .short + .get_shortroomid(room_id)? 
+ .expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); @@ -106,13 +122,15 @@ impl service::rooms::user::Data for KeyValueDatabase { // We use the default compare function because keys are sorted correctly (not // reversed) Ok(Box::new( - utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| { - RoomId::parse( - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, - ) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), + utils::common_elements(iterators, Ord::cmp) + .expect("users is not empty") + .map(|bytes| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, + ) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + }), )) } } diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 5087cbe7..9f871ee7 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -73,7 +73,8 @@ impl service::sending::Data for KeyValueDatabase { batch.push((key.clone(), value.to_owned())); keys.push(key); } - self.servernameevent_data.insert_batch(&mut batch.into_iter())?; + self.servernameevent_data + .insert_batch(&mut batch.into_iter())?; Ok(keys) } @@ -107,13 +108,16 @@ impl service::sending::Data for KeyValueDatabase { } fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { - self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes()) + self.servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()) } fn get_latest_educount(&self, server_name: &ServerName) -> Result { - self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) + self.servername_educount + .get(server_name.as_bytes())? 
+ .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + }) } } @@ -124,7 +128,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(OutgoingKind, let mut parts = key[1..].splitn(2, |&b| b == 0xFF); let server = parts.next().expect("splitn always returns one element"); - let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let server = utils::string_from_bytes(server) .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; @@ -146,11 +152,15 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(OutgoingKind, let user_id = UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; - let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let pushkey = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let pushkey_string = utils::string_from_bytes(pushkey) .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; - let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; ( OutgoingKind::Push(user_id, pushkey_string), @@ -165,7 +175,9 @@ fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(OutgoingKind, let mut parts = key.splitn(2, |&b| b == 0xFF); let server = parts.next().expect("splitn always returns one element"); - let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; let server = utils::string_from_bytes(server) .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 9ec364dc..c8b7611e 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -9,10 +9,13 @@ impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, ) -> Result<()> { - self.userdevicesessionid_uiaarequest.write().unwrap().insert( - (user_id.to_owned(), device_id.to_owned(), session.to_owned()), - request.to_owned(), - ); + self.userdevicesessionid_uiaarequest + .write() + .unwrap() + .insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); Ok(()) } @@ -40,7 +43,8 @@ impl service::uiaa::Data for KeyValueDatabase { &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), )?; } else { - self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?; + self.userdevicesessionid_uiaainfo + .remove(&userdevicesessionid)?; } Ok(()) diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 550d46c9..b83e456f 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -34,23 +34,27 @@ impl service::users::Data for KeyValueDatabase { /// Find out which user an access token belongs to. 
fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid.get(token.as_bytes())?.map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xFF); - let user_bytes = - parts.next().ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; - let device_bytes = - parts.next().ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; + self.token_userdeviceid + .get(token.as_bytes())? + .map_or(Ok(None), |bytes| { + let mut parts = bytes.split(|&b| b == 0xFF); + let user_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; + let device_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; - Ok(Some(( - UserId::parse( - utils::string_from_bytes(user_bytes) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, - utils::string_from_bytes(device_bytes) - .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, - ))) - }) + Ok(Some(( + UserId::parse( + utils::string_from_bytes(user_bytes) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, + utils::string_from_bytes(device_bytes) + .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, + ))) + }) } /// Returns an iterator over all users on this homeserver. @@ -79,18 +83,21 @@ impl service::users::Data for KeyValueDatabase { /// Returns the password hash for the given user. fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not a valid string.") - })?)) - }) + self.userid_password + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Password hash in db is not a valid string.") + })?)) + }) } /// Hash and set the user's password to the Argon2 hash fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_password_hash(password) { - self.userid_password.insert(user_id.as_bytes(), hash.as_bytes())?; + self.userid_password + .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) } else { Err(Error::BadRequest( @@ -106,18 +113,22 @@ impl service::users::Data for KeyValueDatabase { /// Returns the displayname of a user on this homeserver. fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { - Ok(Some( - utils::string_from_bytes(&bytes).map_err(|_| Error::bad_database("Displayname in db is invalid."))?, - )) - }) + self.userid_displayname + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + Ok(Some( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Displayname in db is invalid."))?, + )) + }) } /// Sets a new displayname or removes it if displayname is None. You still /// need to notify all rooms of this change.
fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { if let Some(displayname) = displayname { - self.userid_displayname.insert(user_id.as_bytes(), displayname.as_bytes())?; + self.userid_displayname + .insert(user_id.as_bytes(), displayname.as_bytes())?; } else { self.userid_displayname.remove(user_id.as_bytes())?; } @@ -143,7 +154,8 @@ impl service::users::Data for KeyValueDatabase { /// Sets a new avatar_url or removes it if avatar_url is None. fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { if let Some(avatar_url) = avatar_url { - self.userid_avatarurl.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + self.userid_avatarurl + .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; } else { self.userid_avatarurl.remove(user_id.as_bytes())?; } @@ -167,7 +179,8 @@ impl service::users::Data for KeyValueDatabase { /// Sets a new blurhash or removes it if blurhash is None. fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { if let Some(blurhash) = blurhash { - self.userid_blurhash.insert(user_id.as_bytes(), blurhash.as_bytes())?; + self.userid_blurhash + .insert(user_id.as_bytes(), blurhash.as_bytes())?; } else { self.userid_blurhash.remove(user_id.as_bytes())?; } @@ -190,7 +203,8 @@ impl service::users::Data for KeyValueDatabase { userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); - self.userid_devicelistversion.increment(user_id.as_bytes())?; + self.userid_devicelistversion + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( &userdeviceid, @@ -230,7 +244,8 @@ impl service::users::Data for KeyValueDatabase { // TODO: Remove onetimekeys - self.userid_devicelistversion.increment(user_id.as_bytes())?; + self.userid_devicelistversion + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.remove(&userdeviceid)?; @@ -242,16 +257,20 @@ impl service::users::Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); // All devices have metadata - Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? - .into()) - })) + Box::new( + self.userdeviceid_metadata + .scan_prefix(prefix) + .map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .into()) + }), + ) } /// Replaces the access token of one device.
@@ -278,8 +297,10 @@ impl service::users::Data for KeyValueDatabase { } // Assign token to user device combination - self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?; + self.userdeviceid_token + .insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid + .insert(token.as_bytes(), &userdeviceid)?; Ok(()) } @@ -310,7 +331,9 @@ impl service::users::Data for KeyValueDatabase { // TODO: Use DeviceKeyId::to_string when it's available (and update everything, // because there are no wrapping quotation marks anymore) key.extend_from_slice( - serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works").as_bytes(), + serde_json::to_string(one_time_key_key) + .expect("DeviceKeyId::to_string always works") + .as_bytes(), ); self.onetimekeyid_onetimekeys.insert( @@ -318,16 +341,19 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), )?; - self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; Ok(()) } fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate.get(user_id.as_bytes())?.map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) - }) + self.userid_lastonetimekeyupdate + .get(user_id.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) + }) } fn take_one_time_key( @@ -341,7 +367,8 @@ impl service::users::Data for KeyValueDatabase { prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); prefix.push(b':'); - self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + self.userid_lastonetimekeyupdate + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; self.onetimekeyid_onetimekeys .scan_prefix(prefix) @@ -372,18 +399,21 @@ impl service::users::Data for KeyValueDatabase { let mut counts = BTreeMap::new(); - for algorithm in self.onetimekeyid_onetimekeys.scan_prefix(userdeviceid).map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes - .rsplit(|&b| b == 0xFF) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, + for algorithm in self + .onetimekeyid_onetimekeys + .scan_prefix(userdeviceid) + .map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? + .algorithm(), ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? 
- .algorithm(), - ) - }) { + }) { *counts.entry(algorithm?).or_default() += UInt::from(1_u32); } @@ -415,9 +445,11 @@ impl service::users::Data for KeyValueDatabase { let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; - self.keyid_key.insert(&master_key_key, master_key.json().get().as_bytes())?; + self.keyid_key + .insert(&master_key_key, master_key.json().get().as_bytes())?; - self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?; + self.userid_masterkeyid + .insert(user_id.as_bytes(), &master_key_key)?; // Self-signing key if let Some(self_signing_key) = self_signing_key { @@ -441,9 +473,11 @@ impl service::users::Data for KeyValueDatabase { let mut self_signing_key_key = prefix.clone(); self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - self.keyid_key.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; + self.keyid_key + .insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; - self.userid_selfsigningkeyid.insert(user_id.as_bytes(), &self_signing_key_key)?; + self.userid_selfsigningkeyid + .insert(user_id.as_bytes(), &self_signing_key_key)?; } // User-signing key @@ -468,9 +502,11 @@ impl service::users::Data for KeyValueDatabase { let mut user_signing_key_key = prefix; user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - self.keyid_key.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; + self.keyid_key + .insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; - self.userid_usersigningkeyid.insert(user_id.as_bytes(), &user_signing_key_key)?; + self.userid_usersigningkeyid + .insert(user_id.as_bytes(), &user_signing_key_key)?; } if notify { @@ -559,9 +595,18 @@ impl service::users::Data for KeyValueDatabase { fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(Result::ok) { + for room_id in services() + .rooms + .state_cache + .rooms_joined(user_id) + .filter_map(Result::ok) + { // Don't send key updates to unencrypted rooms - if services().rooms.state_accessor.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?.is_none() + if services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? 
+ .is_none() { continue; } @@ -599,11 +644,13 @@ impl service::users::Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); - let master_key = - master_key.deserialize().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let master_key = master_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; let mut master_key_ids = master_key.keys.values(); - let master_key_id = - master_key_ids.next().ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + let master_key_id = master_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; if master_key_ids.next().is_some() { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -646,14 +693,16 @@ impl service::users::Data for KeyValueDatabase { } fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some( - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, - )) + self.userid_usersigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, + )) + }) }) - }) } fn add_to_device_event( @@ -743,7 +792,8 @@ impl service::users::Data for KeyValueDatabase { )); } - self.userid_devicelistversion.increment(user_id.as_bytes())?; + self.userid_devicelistversion + .increment(user_id.as_bytes())?; self.userdeviceid_metadata.insert( &userdeviceid, @@ -759,27 +809,37 @@ impl service::users::Data for KeyValueDatabase { userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); - self.userdeviceid_metadata.get(&userdeviceid)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) + self.userdeviceid_metadata + .get(&userdeviceid)? + .map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + })?)) + }) } fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid devicelistversion in db.")).map(Some) - }) + self.userid_devicelistversion + .get(user_id.as_bytes())? + .map_or(Ok(None), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) + .map(Some) + }) } fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); - Box::new(self.userdeviceid_metadata.scan_prefix(key).map(|(_, bytes)| { - serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - })) + Box::new( + self.userdeviceid_metadata + .scan_prefix(key) + .map(|(_, bytes)| { + serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + }), + ) } /// Creates a new sync filter. Returns the filter id. 
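Every hunk in this patch is mechanical: with chain_width = 60, rustfmt keeps a method chain inline only while its single-line rendering stays at or under 60 columns, and otherwise breaks it onto one line per call, as above. A toy illustration of the threshold (made-up names; only the formatting is the point):

    fn chain_width_demo(values: &[Result<u64, ()>]) -> usize {
        // Short chain: the one-line form fits within 60 columns, so it stays inline.
        let kept = values.iter().count();

        // Long chain: the one-line form would overflow, so rustfmt now
        // formats it vertically, one `.method()` per line.
        let broken = values
            .iter()
            .copied()
            .filter_map(Result::ok)
            .filter(|v| *v > 0)
            .count();

        kept + broken
    }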
@@ -790,7 +850,8 @@ impl service::users::Data for KeyValueDatabase { key.push(0xFF); key.extend_from_slice(filter_id.as_bytes()); - self.userfilterid_filter.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; + self.userfilterid_filter + .insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; Ok(filter_id) } diff --git a/src/database/mod.rs b/src/database/mod.rs index 55b4a7af..7d884fac 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -375,7 +375,10 @@ impl KeyValueDatabase { server_signingkeys: builder.open_tree("server_signingkeys")?, pdu_cache: Mutex::new(LruCache::new( - config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"), + config + .pdu_cache_capacity + .try_into() + .expect("pdu cache capacity fits into usize"), )), auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)), shorteventid_cache: Mutex::new(LruCache::new( @@ -457,8 +460,11 @@ impl KeyValueDatabase { .argon .hash_password(b"", &salt) .expect("our own password to be properly hashed"); - let empty_hashed_password = - services().globals.argon.verify_password(&password, &empty_pass).is_ok(); + let empty_hashed_password = services() + .globals + .argon + .verify_password(&password, &empty_pass) + .is_ok(); if empty_hashed_password { db.userid_password.insert(&userid, b"")?; @@ -526,7 +532,8 @@ impl KeyValueDatabase { key.push(0xFF); key.extend_from_slice(event_type); - db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; + db.roomusertype_roomuserdataid + .insert(&key, &roomuserdataid)?; } services().globals.bump_database_version(5)?; @@ -565,16 +572,24 @@ impl KeyValueDatabase { let states_parents = last_roomsstatehash.map_or_else( || Ok(Vec::new()), |&last_roomsstatehash| { - services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash) + services() + .rooms + .state_compressor + .load_shortstatehash_info(last_roomsstatehash) }, )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = - current_state.difference(&parent_stateinfo.1).copied().collect::>(); + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .copied() + .collect::>(); - let statediffremoved = - parent_stateinfo.1.difference(¤t_state).copied().collect::>(); + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .copied() + .collect::>(); (statediffnew, statediffremoved) } else { @@ -634,7 +649,12 @@ impl KeyValueDatabase { let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); let string = utils::string_from_bytes(&event_id).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap(); + let pdu = services() + .rooms + .timeline + .get_pdu(event_id) + .unwrap() + .unwrap(); if Some(&pdu.room_id) != current_room.as_ref() { current_room = Some(pdu.room_id.clone()); @@ -676,7 +696,11 @@ impl KeyValueDatabase { let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); - let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); let mut new_key = short_room_id; new_key.extend_from_slice(count); @@ -694,7 +718,11 @@ impl KeyValueDatabase { let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); - let short_room_id = 
db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); let mut new_value = short_room_id; new_value.extend_from_slice(count); @@ -724,8 +752,11 @@ impl KeyValueDatabase { let _pdu_id_room = parts.next().unwrap(); let pdu_id_count = parts.next().unwrap(); - let short_room_id = - db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + let short_room_id = db + .roomid_shortroomid + .get(room_id) + .unwrap() + .expect("shortroomid should exist"); let mut new_key = short_room_id; new_key.extend_from_slice(word); new_key.push(0xFF); @@ -765,7 +796,8 @@ impl KeyValueDatabase { if services().globals.database_version()? < 10 { // Add other direction for shortstatekeys for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { - db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?; + db.shortstatekey_statekey + .insert(&shortstatekey, &statekey)?; } // Force E2EE device list updates so we can send them over federation @@ -779,7 +811,9 @@ impl KeyValueDatabase { } if services().globals.database_version()? < 11 { - db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; + db.db + .open_tree("userdevicesessionid_uiaarequest")? + .clear()?; services().globals.bump_database_version(11)?; warn!("Migration: 10 -> 11 finished"); @@ -813,7 +847,9 @@ impl KeyValueDatabase { if rule.is_some() { let mut rule = rule.unwrap().clone(); content_rule_transformation[1].clone_into(&mut rule.rule_id); - rules_list.content.shift_remove(content_rule_transformation[0]); + rules_list + .content + .shift_remove(content_rule_transformation[0]); rules_list.content.insert(rule); } } @@ -874,7 +910,10 @@ impl KeyValueDatabase { let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); let user_default_rules = Ruleset::server_default(&user); - account_data.content.global.update_with_server_default(user_default_rules); + account_data + .content + .global + .update_with_server_default(user_default_rules); services().account_data.update( None, @@ -936,7 +975,10 @@ impl KeyValueDatabase { warn!( "User {} matches the following forbidden username patterns: {}", user_id.to_string(), - matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") ); } } @@ -957,7 +999,10 @@ impl KeyValueDatabase { "Room with alias {} ({}) matches the following forbidden room name patterns: {}", room_alias, &room_id, - matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") + matches + .into_iter() + .map(|x| &patterns.patterns()[x]) + .join(", ") ); } } @@ -971,7 +1016,9 @@ impl KeyValueDatabase { latest_database_version ); } else { - services().globals.bump_database_version(latest_database_version)?; + services() + .globals + .bump_database_version(latest_database_version)?; // Create the admin room and server user on first run services().admin.create_admin_room().await?; @@ -993,10 +1040,12 @@ impl KeyValueDatabase { "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ account recovery!" ); - services().admin.send_message(RoomMessageEventContent::text_plain( - "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ - account recovery!", - )); + services() + .admin + .send_message(RoomMessageEventContent::text_plain( + "The Conduit account emergency password is set! 
Please unset it as soon as you finish \ + admin account recovery!", + )); } }, Err(e) => { @@ -1047,8 +1096,13 @@ impl KeyValueDatabase { } async fn try_handle_updates() -> Result<()> { - let response = - services().globals.client.default.get("https://pupbrain.dev/check-for-updates/stable").send().await?; + let response = services() + .globals + .client + .default + .get("https://pupbrain.dev/check-for-updates/stable") + .send() + .await?; let response = serde_json::from_str::(&response.text().await?).map_err(|e| { error!("Bad check for updates response: {e}"); @@ -1060,13 +1114,17 @@ impl KeyValueDatabase { last_update_id = last_update_id.max(update.id); if update.id > services().globals.last_check_for_updates_id()? { error!("{}", update.message); - services().admin.send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))); + services() + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", + update.date, update.message + ))); } } - services().globals.update_check_for_updates_id(last_update_id)?; + services() + .globals + .update_check_for_updates_id(last_update_id)?; Ok(()) } @@ -1138,7 +1196,9 @@ fn set_emergency_access() -> Result { let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("@conduit:server_name is a valid UserId"); - services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?; + services() + .users + .set_password(&conduit_user, services().globals.emergency_password().as_deref())?; let (ruleset, res) = match services().globals.emergency_password() { Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), diff --git a/src/lib.rs b/src/lib.rs index 4ff91f56..e287f1cb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,5 +20,8 @@ pub use utils::error::{Error, Result}; pub static SERVICES: RwLock>> = RwLock::new(None); pub fn services() -> &'static Services<'static> { - SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called") + SERVICES + .read() + .unwrap() + .expect("SERVICES should be initialized when this is called") } diff --git a/src/main.rs b/src/main.rs index aa95ef53..48cf93f6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -111,7 +111,9 @@ async fn main() { }, }; - let subscriber = tracing_subscriber::Registry::default().with(filter_layer).with(telemetry); + let subscriber = tracing_subscriber::Registry::default() + .with(filter_layer) + .with(telemetry); tracing::subscriber::set_global_default(subscriber).unwrap(); } } else if config.tracing_flame { @@ -292,21 +294,30 @@ async fn main() { ); } - if config.url_preview_domain_contains_allowlist.contains(&"*".to_owned()) { + if config + .url_preview_domain_contains_allowlist + .contains(&"*".to_owned()) + { warn!( "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \ This opens up significant attack surface to your server. You are expected to be aware of the risks by \ doing this." ); } - if config.url_preview_domain_explicit_allowlist.contains(&"*".to_owned()) { + if config + .url_preview_domain_explicit_allowlist + .contains(&"*".to_owned()) + { warn!( "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \ This opens up significant attack surface to your server. 
You are expected to be aware of the risks by \ doing this." ); } - if config.url_preview_url_contains_allowlist.contains(&"*".to_owned()) { + if config + .url_preview_url_contains_allowlist + .contains(&"*".to_owned()) + { warn!( "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \ opens up significant attack surface to your server. You are expected to be aware of the risks by doing \ @@ -338,9 +349,17 @@ async fn run_server() -> io::Result<()> { // Left is only 1 value, so make a vec with 1 value only let port_vec = [port]; - port_vec.iter().copied().map(|port| SocketAddr::from((config.address, *port))).collect::>() + port_vec + .iter() + .copied() + .map(|port| SocketAddr::from((config.address, *port))) + .collect::>() }, - Right(ports) => ports.iter().copied().map(|port| SocketAddr::from((config.address, port))).collect::>(), + Right(ports) => ports + .iter() + .copied() + .map(|port| SocketAddr::from((config.address, port))) + .collect::>(), }; let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -385,7 +404,10 @@ async fn run_server() -> io::Result<()> { .max_age(Duration::from_secs(86400)), ) .layer(DefaultBodyLimit::max( - config.max_request_size.try_into().expect("failed to convert max request size"), + config + .max_request_size + .try_into() + .expect("failed to convert max request size"), )); let app; @@ -395,7 +417,9 @@ async fn run_server() -> io::Result<()> { { app = if cfg!(feature = "zstd_compression") && config.zstd_compression { debug!("zstd body compression is enabled"); - routes().layer(middlewares.compression()).into_make_service() + routes() + .layer(middlewares.compression()) + .into_make_service() } else { routes().layer(middlewares).into_make_service() } @@ -432,7 +456,9 @@ async fn run_server() -> io::Result<()> { let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); let listener = tokio::net::UnixListener::bind(path.clone())?; - tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)).await.unwrap(); + tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)) + .await + .unwrap(); let socket = SocketIncoming::from_listener(listener); #[allow(clippy::let_underscore_untyped)] // error[E0658]: attributes on expressions are experimental @@ -483,7 +509,11 @@ async fn run_server() -> io::Result<()> { } } else { for addr in &addrs { - join_set.spawn(bind_rustls(*addr, conf.clone()).handle(handle.clone()).serve(app.clone())); + join_set.spawn( + bind_rustls(*addr, conf.clone()) + .handle(handle.clone()) + .serve(app.clone()), + ); } } @@ -528,7 +558,9 @@ async fn spawn_task( if services().globals.shutdown.load(atomic::Ordering::Relaxed) { return Err(StatusCode::SERVICE_UNAVAILABLE); } - tokio::spawn(next.run(req)).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + tokio::spawn(next.run(req)) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } async fn unrecognized_method( @@ -758,7 +790,9 @@ fn routes() -> Router { async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { let ctrl_c = async { - signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] diff --git a/src/service/admin/appservice.rs b/src/service/admin/appservice.rs index db1cda70..ff0611e0 100644 --- a/src/service/admin/appservice.rs +++ b/src/service/admin/appservice.rs @@ -62,7 +62,11 @@ pub(crate) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Resu 
}, AppserviceCommand::Unregister { appservice_identifier, - } => match services().appservice.unregister_appservice(&appservice_identifier).await { + } => match services() + .appservice + .unregister_appservice(&appservice_identifier) + .await + { Ok(()) => Ok(RoomMessageEventContent::text_plain("Appservice unregistered.")), Err(e) => Ok(RoomMessageEventContent::text_plain(format!( "Failed to unregister appservice: {e}" @@ -70,7 +74,11 @@ pub(crate) async fn process(command: AppserviceCommand, body: Vec<&str>) -> Resu }, AppserviceCommand::Show { appservice_identifier, - } => match services().appservice.get_registration(&appservice_identifier).await { + } => match services() + .appservice + .get_registration(&appservice_identifier) + .await + { Some(config) => { let config_str = serde_yaml::to_string(&config).expect("config should've been validated on register"); let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,); diff --git a/src/service/admin/debug.rs b/src/service/admin/debug.rs index bd415ddc..309111af 100644 --- a/src/service/admin/debug.rs +++ b/src/service/admin/debug.rs @@ -81,7 +81,12 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; let start = Instant::now(); - let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id]).await?.count(); + let count = services() + .rooms + .auth_chain + .get_auth_chain(room_id, vec![event_id]) + .await? + .count(); let elapsed = start.elapsed(); RoomMessageEventContent::text_plain(format!("Loaded auth chain with length {count} in {elapsed:?}")) } else { @@ -119,7 +124,10 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result { let mut outlier = false; - let mut pdu_json = services().rooms.timeline.get_non_outlier_pdu_json(&event_id)?; + let mut pdu_json = services() + .rooms + .timeline + .get_non_outlier_pdu_json(&event_id)?; if pdu_json.is_none() { outlier = true; pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; @@ -224,7 +232,11 @@ pub(crate) async fn process(command: DebugCommand, body: Vec<&str>) -> Result) -> Resu Ok(value) => { let pub_key_map = RwLock::new(BTreeMap::new()); - services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; + services() + .rooms + .event_handler + .fetch_required_signing_keys([&value], &pub_key_map) + .await?; let pub_key_map = pub_key_map.read().await; match ruma::signatures::verify_json(&pub_key_map, &value) { diff --git a/src/service/admin/media.rs b/src/service/admin/media.rs index f467c0dd..ee86401e 100644 --- a/src/service/admin/media.rs +++ b/src/service/admin/media.rs @@ -202,7 +202,10 @@ pub(crate) async fn process(command: MediaCommand, body: Vec<&str>) -> Result { - let deleted_count = services().media.delete_all_remote_media_at_after_time(duration).await?; + let deleted_count = services() + .media + .delete_all_remote_media_at_after_time(duration) + .await?; Ok(RoomMessageEventContent::text_plain(format!( "Deleted {} total files.", diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index c09bbacd..23fba1f4 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -169,11 +169,15 @@ impl Service { } pub fn process_message(&self, room_message: String, event_id: Arc) { - self.sender.send(AdminRoomEvent::ProcessMessage(room_message, event_id)).unwrap(); + self.sender + 
.send(AdminRoomEvent::ProcessMessage(room_message, event_id)) + .unwrap(); } pub fn send_message(&self, message_content: RoomMessageEventContent) { - self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap(); + self.sender + .send(AdminRoomEvent::SendMessage(message_content)) + .unwrap(); } // Parse and process a message from the admin room @@ -276,7 +280,10 @@ impl Service { if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { text_lines.remove(line_index); - while text_lines.get(line_index).is_some_and(|line| line.starts_with('#')) { + while text_lines + .get(line_index) + .is_some_and(|line| line.starts_with('#')) + { command_body += if text_lines[line_index].starts_with("# ") { &text_lines[line_index][2..] } else { @@ -304,7 +311,9 @@ impl Service { // Add HTML line-breaks - text.replace("\n\n\n", "\n\n").replace('\n', "
      \n").replace("[nobr]
      ", "") + text.replace("\n\n\n", "\n\n") + .replace('\n', "
      \n") + .replace("[nobr]
      ", "") } /// Create the admin room. @@ -316,8 +325,15 @@ impl Service { services().rooms.short.get_or_create_shortroomid(&room_id)?; - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // Create a user for the server @@ -560,7 +576,10 @@ impl Service { .try_into() .expect("#admins:server_name is a valid alias name"); - services().rooms.alias.resolve_local_alias(&admin_room_alias) + services() + .rooms + .alias + .resolve_local_alias(&admin_room_alias) } /// Invite the user to the conduit admin room. @@ -568,8 +587,15 @@ impl Service { /// In conduit, this is equivalent to granting admin privileges. pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { if let Some(room_id) = Self::get_admin_room()? { - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.clone()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // Use the server user to grant the new admin's power level @@ -681,13 +707,29 @@ impl Service { } } -fn escape_html(s: &str) -> String { s.replace('&', "&").replace('<', "<").replace('>', ">") } +fn escape_html(s: &str) -> String { + s.replace('&', "&") + .replace('<', "<") + .replace('>', ">") +} fn get_room_info(id: &OwnedRoomId) -> (OwnedRoomId, u64, String) { ( id.clone(), - services().rooms.state_cache.room_joined_count(id).ok().flatten().unwrap_or(0), - services().rooms.state_accessor.get_name(id).ok().flatten().unwrap_or_else(|| id.to_string()), + services() + .rooms + .state_cache + .room_joined_count(id) + .ok() + .flatten() + .unwrap_or(0), + services() + .rooms + .state_accessor + .get_name(id) + .ok() + .flatten() + .unwrap_or_else(|| id.to_string()), ) } @@ -705,7 +747,9 @@ mod test { fn get_help_subcommand() { get_help_inner("help"); } fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]).unwrap_err().to_string(); + let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) + .unwrap_err() + .to_string(); // Search for a handful of keywords that suggest the help printed properly assert!(error.contains("Usage:")); diff --git a/src/service/admin/room.rs b/src/service/admin/room.rs index 040bc0f2..721191b1 100644 --- a/src/service/admin/room.rs +++ b/src/service/admin/room.rs @@ -55,7 +55,11 @@ pub(crate) async fn process(command: RoomCommand, body: Vec<&str>) -> Result>(); + let rooms = rooms + .into_iter() + .skip(page.saturating_sub(1) * PAGE_SIZE) + .take(PAGE_SIZE) + .collect::>(); if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("No more rooms.")); @@ -72,17 +76,19 @@ pub(crate) async fn process(command: RoomCommand, body: Vec<&str>) -> ResultRoom list - page \ {page}\nid\tmembers\tname\n{}", - rooms.iter().fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "{}\t{}\t{}", - escape_html(id.as_ref()), - members, - escape_html(name) - ) - .unwrap(); - output - }) + rooms + .iter() + .fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name) + ) + .unwrap(); 
+ output + }) ); Ok(RoomMessageEventContent::text_html(output_plain, output_html)) }, diff --git a/src/service/admin/room_alias.rs b/src/service/admin/room_alias.rs index 14ca078d..f1621344 100644 --- a/src/service/admin/room_alias.rs +++ b/src/service/admin/room_alias.rs @@ -107,7 +107,11 @@ pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu room_id, } => { if let Some(room_id) = room_id { - let aliases = services().rooms.alias.local_aliases_for_room(&room_id).collect::, _>>(); + let aliases = services() + .rooms + .alias + .local_aliases_for_room(&room_id) + .collect::, _>>(); match aliases { Ok(aliases) => { let plain_list = aliases.iter().fold(String::new(), |mut output, alias| { @@ -127,26 +131,34 @@ pub(crate) async fn process(command: RoomAliasCommand, _body: Vec<&str>) -> Resu Err(err) => Ok(RoomMessageEventContent::text_plain(format!("Unable to list aliases: {}", err))), } } else { - let aliases = services().rooms.alias.all_local_aliases().collect::, _>>(); + let aliases = services() + .rooms + .alias + .all_local_aliases() + .collect::, _>>(); match aliases { Ok(aliases) => { let server_name = services().globals.server_name(); - let plain_list = aliases.iter().fold(String::new(), |mut output, (alias, id)| { - writeln!(output, "- `{alias}` -> #{id}:{server_name}").unwrap(); - output - }); + let plain_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{alias}` -> #{id}:{server_name}").unwrap(); + output + }); - let html_list = aliases.iter().fold(String::new(), |mut output, (alias, id)| { - writeln!( - output, - "
<li><code>{}</code> -> #{}:{}</li>", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .unwrap(); - output - }); + let html_list = aliases + .iter() + .fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "<li><code>{}</code> -> #{}:{}</li>", + escape_html(alias.as_ref()), + escape_html(id.as_ref()), + server_name + ) + .unwrap(); + output + }); let plain = format!("Aliases:\n{plain_list}"); let html = format!("Aliases:\n<ul>{html_list}</ul>"); diff --git a/src/service/admin/room_directory.rs b/src/service/admin/room_directory.rs index d0fdcf6a..86dc03d6 100644 --- a/src/service/admin/room_directory.rs +++ b/src/service/admin/room_directory.rs @@ -58,7 +58,11 @@ pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { rooms.sort_by_key(|r| r.1); rooms.reverse(); - let rooms = rooms.into_iter().skip(page.saturating_sub(1) * PAGE_SIZE).take(PAGE_SIZE).collect::<Vec<_>>(); + let rooms = rooms + .into_iter() + .skip(page.saturating_sub(1) * PAGE_SIZE) + .take(PAGE_SIZE) + .collect::<Vec<_>>(); if rooms.is_empty() { return Ok(RoomMessageEventContent::text_plain("No more rooms.")); @@ -75,17 +79,19 @@ pub(crate) async fn process(command: RoomDirectoryCommand, _body: Vec<&str>) -> Result<RoomMessageEventContent> { let output_html = format!( "<table><caption>Room directory - page \ {page}</caption>\n<tr><th>id</th>\t<th>members</th>\t<th>name</th></tr>\n{}</table>
      ", - rooms.iter().fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "{}\t{}\t{}", - escape_html(id.as_ref()), - members, - escape_html(name.as_ref()) - ) - .unwrap(); - output - }) + rooms + .iter() + .fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name.as_ref()) + ) + .unwrap(); + output + }) ); Ok(RoomMessageEventContent::text_html(output_plain, output_html)) }, diff --git a/src/service/admin/room_moderation.rs b/src/service/admin/room_moderation.rs index ca20ca20..22d5ff31 100644 --- a/src/service/admin/room_moderation.rs +++ b/src/service/admin/room_moderation.rs @@ -446,7 +446,11 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> )) }, RoomModerationCommand::ListBannedRooms => { - let rooms = services().rooms.metadata.list_banned_rooms().collect::, _>>(); + let rooms = services() + .rooms + .metadata + .list_banned_rooms() + .collect::, _>>(); match rooms { Ok(room_ids) => { diff --git a/src/service/admin/user.rs b/src/service/admin/user.rs index 977e80e4..5f1bbce9 100644 --- a/src/service/admin/user.rs +++ b/src/service/admin/user.rs @@ -115,13 +115,18 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> Result) -> Result) -> Result Ok(RoomMessageEventContent::text_plain(format!( "Successfully reset the password for user {user_id}: `{new_password}`" ))), @@ -343,17 +355,19 @@ pub(crate) async fn process(command: UserCommand, body: Vec<&str>) -> ResultRooms {user_id} \ Joined\nid\tmembers\tname\n{}", - rooms.iter().fold(String::new(), |mut output, (id, members, name)| { - writeln!( - output, - "{}\t{}\t{}", - escape_html(id.as_ref()), - members, - escape_html(name) - ) - .unwrap(); - output - }) + rooms + .iter() + .fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name) + ) + .unwrap(); + output + }) ); Ok(RoomMessageEventContent::text_html(output_plain, output_html)) }, diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index d0a37961..4f25aade 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -107,7 +107,10 @@ impl Service { for appservice in db.all()? 
{ registration_info.insert( appservice.0, - appservice.1.try_into().expect("Should be validated on registration"), + appservice + .1 + .try_into() + .expect("Should be validated on registration"), ); } @@ -119,7 +122,12 @@ impl Service { /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { - services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?); + services() + .appservice + .registration_info + .write() + .await + .insert(yaml.id.clone(), yaml.clone().try_into()?); self.db.register_appservice(yaml) } @@ -130,19 +138,40 @@ impl Service { /// /// * `service_name` - the name you send to register the service previously pub async fn unregister_appservice(&self, service_name: &str) -> Result<()> { - services().appservice.registration_info.write().await.remove(service_name); + services() + .appservice + .registration_info + .write() + .await + .remove(service_name); self.db.unregister_appservice(service_name) } pub async fn get_registration(&self, id: &str) -> Option { - self.registration_info.read().await.get(id).cloned().map(|info| info.registration) + self.registration_info + .read() + .await + .get(id) + .cloned() + .map(|info| info.registration) } - pub async fn iter_ids(&self) -> Vec { self.registration_info.read().await.keys().cloned().collect() } + pub async fn iter_ids(&self) -> Vec { + self.registration_info + .read() + .await + .keys() + .cloned() + .collect() + } pub async fn find_from_token(&self, token: &str) -> Option { - self.read().await.values().find(|info| info.registration.as_token == token).cloned() + self.read() + .await + .values() + .find(|info| info.registration.as_token == token) + .cloned() } pub fn read(&self) -> impl Future>> { diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 61a00e15..0422b857 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -106,8 +106,10 @@ impl Service<'_> { }, }; - let jwt_decoding_key = - config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); + let jwt_decoding_key = config + .jwt_secret + .as_ref() + .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); let resolver = Arc::new(resolver::Resolver::new(config)); @@ -159,7 +161,10 @@ impl Service<'_> { fs::create_dir_all(s.get_media_folder())?; - if !s.supported_room_versions().contains(&s.config.default_room_version) { + if !s + .supported_room_versions() + .contains(&s.config.default_room_version) + { error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); s.config.default_room_version = crate::config::default_default_room_version(); }; diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index f2acc5f8..95cf10b6 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -65,9 +65,7 @@ impl Resolver { } impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { - resolve_to_reqwest(self.resolver.clone(), name) - } + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } } impl Resolve for Hooked { @@ -78,7 +76,7 @@ impl Resolve for Hooked { .get(name.as_str()) .map_or_else( || resolve_to_reqwest(self.resolver.clone(), name), - |(override_name, port)| cached_to_reqwest(override_name, *port) + |(override_name, port)| 
cached_to_reqwest(override_name, *port), ) } } diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 202f72d2..f4bb5c3b 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -44,7 +44,8 @@ impl Service { pub fn add_key( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, ) -> Result<()> { - self.db.add_key(user_id, version, room_id, session_id, key_data) + self.db + .add_key(user_id, version, room_id, session_id, key_data) } pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } @@ -76,6 +77,7 @@ impl Service { } pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { - self.db.delete_room_key(user_id, version, room_id, session_id) + self.db + .delete_room_key(user_id, version, room_id, session_id) } } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 363e6039..8114a6b7 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -50,9 +50,11 @@ impl Service { ) -> Result<()> { // Width, Height = 0 if it's not a thumbnail let key = if let Some(user) = sender_user { - self.db.create_file_metadata(Some(user.as_str()), mxc, 0, 0, content_disposition, content_type)? + self.db + .create_file_metadata(Some(user.as_str()), mxc, 0, 0, content_disposition, content_type)? } else { - self.db.create_file_metadata(None, mxc, 0, 0, content_disposition, content_type)? + self.db + .create_file_metadata(None, mxc, 0, 0, content_disposition, content_type)? }; let path; @@ -118,9 +120,11 @@ impl Service { content_type: Option<&str>, width: u32, height: u32, file: &[u8], ) -> Result<()> { let key = if let Some(user) = sender_user { - self.db.create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? + self.db + .create_file_metadata(Some(user.as_str()), mxc, width, height, content_disposition, content_type)? } else { - self.db.create_file_metadata(None, mxc, width, height, content_disposition, content_type)? + self.db + .create_file_metadata(None, mxc, width, height, content_disposition, content_type)? }; let path; @@ -161,7 +165,9 @@ impl Service { }; let mut file = Vec::new(); - BufReader::new(File::open(path).await?).read_to_end(&mut file).await?; + BufReader::new(File::open(path).await?) + .read_to_end(&mut file) + .await?; Ok(Some(FileMeta { content_disposition, @@ -308,7 +314,9 @@ impl Service { /// For width,height <= 96 the server uses another thumbnailing algorithm /// which crops the image afterwards. 
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result> { - let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file + let (width, height, crop) = self + .thumbnail_properties(width, height) + .unwrap_or((0, 0, false)); // 0, 0 because that's the original file if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) { // Using saved thumbnail @@ -466,7 +474,9 @@ impl Service { } pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { - let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time"); + let now = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .expect("valid system time"); self.db.set_url_preview(url, data, now) } } @@ -492,9 +502,19 @@ mod tests { key.extend_from_slice(&width.to_be_bytes()); key.extend_from_slice(&height.to_be_bytes()); key.push(0xFF); - key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); + key.extend_from_slice( + content_disposition + .as_ref() + .map(|f| f.as_bytes()) + .unwrap_or_default(), + ); key.push(0xFF); - key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); + key.extend_from_slice( + content_type + .as_ref() + .map(|c| c.as_bytes()) + .unwrap_or_default(), + ); Ok(key) } diff --git a/src/service/mod.rs b/src/service/mod.rs index 8f19e029..e2c49de4 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -168,11 +168,41 @@ impl Services<'_> { async fn memory_usage(&self) -> String { let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); - let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len(); - let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len(); - let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); - let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len(); - let roomid_spacehierarchy_cache = self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.len(); + let server_visibility_cache = self + .rooms + .state_accessor + .server_visibility_cache + .lock() + .unwrap() + .len(); + let user_visibility_cache = self + .rooms + .state_accessor + .user_visibility_cache + .lock() + .unwrap() + .len(); + let stateinfo_cache = self + .rooms + .state_compressor + .stateinfo_cache + .lock() + .unwrap() + .len(); + let lasttimelinecount_cache = self + .rooms + .timeline + .lasttimelinecount_cache + .lock() + .await + .len(); + let roomid_spacehierarchy_cache = self + .rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .len(); format!( "\ @@ -187,22 +217,52 @@ roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}" async fn clear_caches(&self, amount: u32) { if amount > 0 { - self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear(); + self.rooms + .lazy_loading + .lazy_load_waiting + .lock() + .await + .clear(); } if amount > 1 { - self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear(); + self.rooms + .state_accessor + .server_visibility_cache + .lock() + .unwrap() + .clear(); } if amount > 2 { - self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear(); + self.rooms + .state_accessor + .user_visibility_cache + .lock() + .unwrap() + .clear(); } if amount > 3 { - 
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear(); + self.rooms + .state_compressor + .stateinfo_cache + .lock() + .unwrap() + .clear(); } if amount > 4 { - self.rooms.timeline.lasttimelinecount_cache.lock().await.clear(); + self.rooms + .timeline + .lasttimelinecount_cache + .lock() + .await + .clear(); } if amount > 5 { - self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.clear(); + self.rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .clear(); } if amount > 6 { self.globals.resolver.overrides.write().unwrap().clear(); diff --git a/src/service/pdu.rs b/src/service/pdu.rs index e38e0cdc..e4b8deae 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -277,11 +277,17 @@ impl PduEvent { /// This does not return a full `Pdu` it is only to satisfy ruma's types. #[tracing::instrument] pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { - if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) { + if let Some(unsigned) = pdu_json + .get_mut("unsigned") + .and_then(|val| val.as_object_mut()) + { unsigned.remove("transaction_id"); } - if let Some(room_id) = pdu_json.get("room_id").and_then(|val| RoomId::parse(val.as_str()?).ok()) { + if let Some(room_id) = pdu_json + .get("room_id") + .and_then(|val| RoomId::parse(val.as_str()?).ok()) + { if let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) { // room v3 and above removed the "event_id" field from remote PDU format match room_version_id { diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 91170b97..70d303ca 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -83,7 +83,12 @@ impl Service { } } - let response = services().globals.client.pusher.execute(reqwest_request).await; + let response = services() + .globals + .client + .pusher + .execute(reqwest_request) + .await; match response { Ok(mut response) => { @@ -108,10 +113,14 @@ impl Service { } let status = response.status(); - let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); mem::swap( response.headers_mut(), - http_response_builder.headers_mut().expect("http::response::Builder is usable"), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), ); let body = response.bytes().await.unwrap_or_else(|e| { @@ -130,7 +139,9 @@ impl Service { } let response = T::IncomingResponse::try_from_http_response( - http_response_builder.body(body).expect("reqwest body is valid http body"), + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), ); response.map_err(|_| { info!("Push gateway returned invalid response bytes {}\n{}", destination, url); @@ -202,9 +213,18 @@ impl Service { let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), - member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32), + member_count: UInt::from( + services() + .rooms + .state_cache + .room_joined_count(room_id)? + .unwrap_or(1) as u32, + ), user_id: user.to_owned(), - user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()), + user_display_name: services() + .users + .displayname(user)? 
+ .unwrap_or_else(|| user.localpart().to_owned()), power_levels: Some(power_levels), }; @@ -242,13 +262,16 @@ impl Service { notifi.counts = NotificationCounts::new(unread, uint!(0)); if event.kind == TimelineEventType::RoomEncrypted - || tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + || tweaks + .iter() + .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) { notifi.prio = NotificationPriority::High; } if event_id_only { - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) + .await?; } else { notifi.sender = Some(event.sender.clone()); notifi.event_type = Some(event.kind.clone()); @@ -262,7 +285,8 @@ impl Service { notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) + .await?; } Ok(()) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 2aa80442..2225b7a4 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -53,7 +53,11 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { + if let Some(cached) = services() + .rooms + .auth_chain + .get_cached_eventid_authchain(&chunk_key)? + { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; @@ -65,13 +69,20 @@ impl Service { let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? { + if let Some(cached) = services() + .rooms + .auth_chain + .get_cached_eventid_authchain(&[sevent_id])? 
+ { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); - services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + services() + .rooms + .auth_chain + .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; debug!( event_id = ?event_id, chain_length = ?auth_chain.len(), @@ -92,7 +103,10 @@ impl Service { "Chunk missed", ); let chunk_cache = Arc::new(chunk_cache); - services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + services() + .rooms + .auth_chain + .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; full_auth_chain.extend(chunk_cache.iter()); } @@ -103,7 +117,9 @@ impl Service { "Auth chain stats", ); - Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + Ok(full_auth_chain + .into_iter() + .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) } #[tracing::instrument(skip(self, event_id))] @@ -118,7 +134,10 @@ impl Service { return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); } for auth_event in &pdu.auth_events { - let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?; + let sauthevent = services() + .rooms + .short + .get_or_create_shorteventid(auth_event)?; if !found.contains(&sauthevent) { found.insert(sauthevent); diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 077fc915..f066cc70 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -91,7 +91,8 @@ impl Service { &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { - self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg) + self.db + .set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg) } /// Removes the presence record for the given user from the database. @@ -142,7 +143,11 @@ fn process_presence_timer(user_id: &OwnedUserId) -> Result<()> { let mut status_msg = None; for room_id in services().rooms.state_cache.rooms_joined(user_id) { - let presence_event = services().rooms.edus.presence.get_presence(&room_id?, user_id)?; + let presence_event = services() + .rooms + .edus + .presence + .get_presence(&room_id?, user_id)?; if let Some(presence_event) = presence_event { presence_state = presence_event.content.presence; diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 01a197ba..8437597f 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -16,16 +16,32 @@ impl Service { /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. 
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - self.typing.write().await.entry(room_id.to_owned()).or_default().insert(user_id.to_owned(), timeout); - self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); + self.typing + .write() + .await + .entry(room_id.to_owned()) + .or_default() + .insert(user_id.to_owned(), timeout); + self.last_typing_update + .write() + .await + .insert(room_id.to_owned(), services().globals.next_count()?); _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } /// Removes a user from typing before the timeout is reached. pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.typing.write().await.entry(room_id.to_owned()).or_default().remove(user_id); - self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); + self.typing + .write() + .await + .entry(room_id.to_owned()) + .or_default() + .remove(user_id); + self.last_typing_update + .write() + .await + .insert(room_id.to_owned(), services().globals.next_count()?); _ = self.typing_update_sender.send(room_id.to_owned()); Ok(()) } @@ -67,7 +83,10 @@ impl Service { for user in removable { room.remove(&user); } - self.last_typing_update.write().await.insert(room_id.to_owned(), services().globals.next_count()?); + self.last_typing_update + .write() + .await + .insert(room_id.to_owned(), services().globals.next_count()?); _ = self.typing_update_sender.send(room_id.to_owned()); } @@ -77,7 +96,13 @@ impl Service { /// Returns the count of the last typing update in this room. pub async fn last_typing_update(&self, room_id: &RoomId) -> Result { self.typings_maintain(room_id).await?; - Ok(self.last_typing_update.read().await.get(room_id).copied().unwrap_or(0)) + Ok(self + .last_typing_update + .read() + .await + .get(room_id) + .copied() + .unwrap_or(0)) } /// Returns a new typing EDU. diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 67cee839..7e798cbf 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -125,8 +125,9 @@ impl Service { .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - let (incoming_pdu, val) = - self.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map).await?; + let (incoming_pdu, val) = self + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map) + .await?; self.check_room_id(room_id, &incoming_pdu)?; // 8. 
if not timeline event: stop @@ -167,7 +168,13 @@ impl Service { )); } - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&*prev_id) + { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -182,7 +189,13 @@ impl Service { if errors >= 5 { // Timeout other events - match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { + match services() + .globals + .bad_event_ratelimiter + .write() + .await + .entry((*prev_id).to_owned()) + { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -207,12 +220,19 @@ impl Service { .await .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - if let Err(e) = - self.upgrade_outlier_to_timeline_pdu(pdu, json, &create_event, origin, room_id, pub_key_map).await + if let Err(e) = self + .upgrade_outlier_to_timeline_pdu(pdu, json, &create_event, origin, room_id, pub_key_map) + .await { errors += 1; warn!("Prev event {} failed: {}", prev_id, e); - match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { + match services() + .globals + .bad_event_ratelimiter + .write() + .await + .entry((*prev_id).to_owned()) + { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -222,7 +242,12 @@ impl Service { } } let elapsed = start_time.elapsed(); - services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); + services() + .globals + .roomid_federationhandletime + .write() + .await + .remove(&room_id.to_owned()); debug!( "Handling prev event {} took {}m{}s", prev_id, @@ -246,7 +271,12 @@ impl Service { .event_handler .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) .await; - services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); + services() + .globals + .roomid_federationhandletime + .write() + .await + .remove(&room_id.to_owned()); r } @@ -323,7 +353,11 @@ impl Service { debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); self.fetch_and_handle_outliers( origin, - &incoming_pdu.auth_events.iter().map(|x| Arc::from(&**x)).collect::>(), + &incoming_pdu + .auth_events + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(), create_event, room_id, room_version_id, @@ -351,7 +385,10 @@ impl Service { match auth_events.entry(( auth_event.kind.to_string().into(), - auth_event.state_key.clone().expect("all auth events have state keys"), + auth_event + .state_key + .clone() + .expect("all auth events have state keys"), )) { hash_map::Entry::Vacant(v) => { v.insert(auth_event); @@ -367,7 +404,9 @@ impl Service { // The original create event must be in the auth events if !matches!( - auth_events.get(&(StateEventType::RoomCreate, String::new())).map(AsRef::as_ref), + auth_events + .get(&(StateEventType::RoomCreate, String::new())) + .map(AsRef::as_ref), Some(_) | None ) { return Err(Error::BadRequest( @@ -390,7 +429,10 @@ impl Service { debug!("Validation successful."); // 7. Persist the event as an outlier. 
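The `bad_event_ratelimiter` reads and writes reformatted above all implement one backoff rule. A sketch of that rule, using a plain `String` key where the real map is keyed by event ID:

use std::collections::HashMap;
use std::time::{Duration, Instant};

// Backoff rule from the hunks above: wait 5 minutes times the square
// of the failure count, capped at 24 hours, before retrying.
fn should_retry(ratelimiter: &HashMap<String, (Instant, u32)>, event_id: &str) -> bool {
    let Some(&(last_attempt, tries)) = ratelimiter.get(event_id) else {
        return true; // no recorded failure, try immediately
    };
    let mut min_elapsed = Duration::from_secs(5 * 60) * tries * tries;
    let cap = Duration::from_secs(60 * 60 * 24);
    if min_elapsed > cap {
        min_elapsed = cap;
    }
    last_attempt.elapsed() >= min_elapsed
}

On each failure the `hash_map::Entry` arms seen above insert or bump the `(Instant::now(), tries)` pair, so repeat offenders back off quadratically.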
- services().rooms.outlier.add_pdu_outlier(&incoming_pdu.event_id, &val)?; + services() + .rooms + .outlier + .add_pdu_outlier(&incoming_pdu.event_id, &val)?; debug!("Added pdu as outlier."); @@ -407,7 +449,11 @@ impl Service { return Ok(Some(pduid)); } - if services().rooms.pdu_metadata.is_event_soft_failed(&incoming_pdu.event_id)? { + if services() + .rooms + .pdu_metadata + .is_event_soft_failed(&incoming_pdu.event_id)? + { return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } @@ -434,10 +480,19 @@ impl Service { if incoming_pdu.prev_events.len() == 1 { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = services().rooms.state_accessor.pdu_shortstatehash(prev_event)?; + let prev_event_sstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(prev_event)?; let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some(services().rooms.state_accessor.state_full_ids(shortstatehash).await) + Some( + services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await, + ) } else { None }; @@ -477,7 +532,11 @@ impl Service { break; }; - let sstatehash = if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) { + let sstatehash = if let Ok(Some(s)) = services() + .rooms + .state_accessor + .pdu_shortstatehash(prev_eventid) + { s } else { okay = false; @@ -492,8 +551,11 @@ impl Service { let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = - services().rooms.state_accessor.state_full_ids(sstatehash).await?; + let mut leaf_state: HashMap<_, _> = services() + .rooms + .state_accessor + .state_full_ids(sstatehash) + .await?; if let Some(state_key) = &prev_event.state_key { let shortstatekey = services() @@ -518,8 +580,14 @@ impl Service { starting_events.push(id); } - auth_chain_sets - .push(services().rooms.auth_chain.get_auth_chain(room_id, starting_events).await?.collect()); + auth_chain_sets.push( + services() + .rooms + .auth_chain + .get_auth_chain(room_id, starting_events) + .await? + .collect(), + ); fork_states.push(state); } @@ -579,7 +647,11 @@ impl Service { Ok(res) => { debug!("Fetching state events at event."); - let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::>(); + let collect = res + .pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(); let state_vec = self .fetch_and_handle_outliers( @@ -679,8 +751,15 @@ impl Service { // 13. 
Use state resolution to find new room state // We start looking at current room state now, so lets lock the room - let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); + let mutex_state = Arc::clone( + services() + .globals + .roomid_mutex_state + .write() + .await + .entry(room_id.to_owned()) + .or_default(), + ); let state_lock = mutex_state.lock().await; // Now we calculate the set of extremities this room has after the incoming @@ -698,13 +777,26 @@ impl Service { } // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); + extremities.retain(|id| { + !matches!( + services() + .rooms + .pdu_metadata + .is_event_referenced(room_id, id), + Ok(true) + ) + }); debug!("Compressing state at event"); let state_ids_compressed = Arc::new( state_at_incoming_event .iter() - .map(|(shortstatekey, id)| services().rooms.state_compressor.compress_state_event(*shortstatekey, id)) + .map(|(shortstatekey, id)| { + services() + .rooms + .state_compressor + .compress_state_event(*shortstatekey, id) + }) .collect::>()?, ); @@ -722,14 +814,23 @@ impl Service { state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); } - let new_room_state = self.resolve_state(room_id, room_version_id, state_after).await?; + let new_room_state = self + .resolve_state(room_id, room_version_id, state_after) + .await?; // Set the new room state to the resolved state debug!("Forcing new room state"); - let (sstatehash, new, removed) = services().rooms.state_compressor.save_state(room_id, new_room_state)?; + let (sstatehash, new, removed) = services() + .rooms + .state_compressor + .save_state(room_id, new_room_state)?; - services().rooms.state.force_state(room_id, sstatehash, new, removed, &state_lock).await?; + services() + .rooms + .state + .force_state(room_id, sstatehash, new, removed, &state_lock) + .await?; } // 14. Check if the event passes auth based on the "current state" of the room, @@ -752,7 +853,10 @@ impl Service { // Soft fail, we keep the event as an outlier but don't add it to the timeline warn!("Event was soft failed: {:?}", incoming_pdu); - services().rooms.pdu_metadata.mark_event_soft_failed(&incoming_pdu.event_id)?; + services() + .rooms + .pdu_metadata + .mark_event_soft_failed(&incoming_pdu.event_id)?; return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); } @@ -787,10 +891,17 @@ impl Service { &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, ) -> Result>> { debug!("Loading current room state ids"); - let current_sstatehash = - services().rooms.state.get_room_shortstatehash(room_id)?.expect("every room has state"); + let current_sstatehash = services() + .rooms + .state + .get_room_shortstatehash(room_id)? 
+ .expect("every room has state"); - let current_state_ids = services().rooms.state_accessor.state_full_ids(current_sstatehash).await?; + let current_state_ids = services() + .rooms + .state_accessor + .state_full_ids(current_sstatehash) + .await?; let fork_states = [current_state_ids, incoming_state]; @@ -852,9 +963,14 @@ impl Service { let new_room_state = state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - services().rooms.state_compressor.compress_state_event(shortstatekey, &event_id) + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + services() + .rooms + .state_compressor + .compress_state_event(shortstatekey, &event_id) }) .collect::>()?; @@ -877,7 +993,13 @@ impl Service { ) -> AsyncRecursiveCanonicalJsonVec<'a> { Box::pin(async move { let back_off = |id| async { - match services().globals.bad_event_ratelimiter.write().await.entry(id) { + match services() + .globals + .bad_event_ratelimiter + .write() + .await + .entry(id) + { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -904,7 +1026,13 @@ impl Service { let mut events_all = HashSet::new(); let mut i = 0; while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&*next_id) + { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1010,7 +1138,13 @@ impl Service { pdus.push((local_pdu, None)); } for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&**next_id) + { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1087,9 +1221,14 @@ impl Service { continue; } - if let Some(json) = - json_opt.or_else(|| services().rooms.outlier.get_outlier_pdu_json(&prev_event_id).ok().flatten()) - { + if let Some(json) = json_opt.or_else(|| { + services() + .rooms + .outlier + .get_outlier_pdu_json(&prev_event_id) + .ok() + .flatten() + }) { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { amount += 1; for prev_prev in &pdu.prev_events { @@ -1122,7 +1261,9 @@ impl Service { Ok(( int!(0), MilliSecondsSinceUnixEpoch( - eventid_info.get(event_id).map_or_else(|| uint!(0), |info| info.0.origin_server_ts), + eventid_info + .get(event_id) + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), ), )) }) @@ -1172,7 +1313,11 @@ impl Service { info!( "Fetch keys for {}", - server_key_ids.keys().cloned().collect::>().join(", ") + server_key_ids + .keys() + .cloned() + .collect::>() + .join(", ") ); let mut server_keys: FuturesUnordered<_> = server_key_ids @@ -1204,7 +1349,10 @@ impl Service { while let Some(fetch_res) = server_keys.next().await { match fetch_res { Ok((signature_server, keys)) => { - pub_key_map.write().await.insert(signature_server.clone(), keys); + pub_key_map + .write() + .await + .insert(signature_server.clone(), keys); }, Err((signature_server, e)) 
=> { warn!("Failed to fetch keys for {}: {:?}", signature_server, e); @@ -1235,7 +1383,13 @@ impl Service { ); let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(event_id) + { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1275,8 +1429,12 @@ impl Service { debug!("Loading signing keys for {}", origin); - let result: BTreeMap<_, _> = - services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); + let result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); if !contains_all_ids(&result) { debug!("Signing key not loaded for {}", origin); @@ -1353,7 +1511,10 @@ impl Service { .into_keys() .map(|server| async move { ( - services().sending.send_federation_request(&server, get_server_keys::v2::Request::new()).await, + services() + .sending + .send_federation_request(&server, get_server_keys::v2::Request::new()) + .await, server, ) }) @@ -1391,10 +1552,14 @@ impl Service { // Try to fetch keys, failure is okay // Servers we couldn't find in the cache will be added to `servers` for pdu in &event.room_state.state { - _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; + _ = self + .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) + .await; } for pdu in &event.room_state.auth_chain { - _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; + _ = self + .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) + .await; } drop(pkm); @@ -1411,7 +1576,8 @@ impl Service { homeserver signing keys." ); - self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; + self.batch_request_signing_keys(servers.clone(), pub_key_map) + .await?; if servers.is_empty() { info!("Trusted server supplied all signing keys, no more keys to fetch"); @@ -1420,11 +1586,13 @@ impl Service { info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); - self.request_signing_keys(servers.clone(), pub_key_map).await?; + self.request_signing_keys(servers.clone(), pub_key_map) + .await?; } else { info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); - self.request_signing_keys(servers.clone(), pub_key_map).await?; + self.request_signing_keys(servers.clone(), pub_key_map) + .await?; if servers.is_empty() { info!("Individual homeservers supplied all signing keys, no more keys to fetch"); @@ -1433,7 +1601,8 @@ impl Service { info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); - self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; + self.batch_request_signing_keys(servers.clone(), pub_key_map) + .await?; } info!("Search for signing keys done"); @@ -1448,7 +1617,11 @@ impl Service { /// Returns Ok if the acl allows the server pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { let acl_event = - match services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomServerAcl, "")? 
{ + match services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? + { Some(acl) => { debug!("ACL event found: {acl:?}"); acl @@ -1493,21 +1666,36 @@ impl Service { ) -> Result> { let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - let permit = - services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned()); + let permit = services() + .globals + .servername_ratelimiter + .read() + .await + .get(origin) + .map(|s| Arc::clone(s).acquire_owned()); let permit = if let Some(p) = permit { p } else { let mut write = services().globals.servername_ratelimiter.write().await; - let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1)))); + let s = Arc::clone( + write + .entry(origin.to_owned()) + .or_insert_with(|| Arc::new(Semaphore::new(1))), + ); s.acquire_owned() } .await; let back_off = |id| async { - match services().globals.bad_signature_ratelimiter.write().await.entry(id) { + match services() + .globals + .bad_signature_ratelimiter + .write() + .await + .entry(id) + { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -1515,7 +1703,13 @@ impl Service { } }; - if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) { + if let Some((time, tries)) = services() + .globals + .bad_signature_ratelimiter + .read() + .await + .get(&signature_ids) + { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1530,8 +1724,12 @@ impl Service { debug!("Loading signing keys for {origin} from our database if available"); - let mut result: BTreeMap<_, _> = - services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); + let mut result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? 
+ .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); if contains_all_ids(&result) { debug!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); @@ -1554,20 +1752,34 @@ impl Service { get_remote_server_keys::v2::Request::new( origin.to_owned(), MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now().checked_add(Duration::from_secs(3600)).expect("SystemTime too large"), + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime too large"), ) .expect("time is valid"), ), ) .await .ok() - .map(|resp| resp.server_keys.into_iter().filter_map(|e| e.deserialize().ok()).collect::>()) - { + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::>() + }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { services().globals.add_signing_key(origin, k.clone())?; - result.extend(k.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); - result.extend(k.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); } if contains_all_ids(&result) { @@ -1584,10 +1796,22 @@ impl Service { .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services().globals.add_signing_key(origin, server_key.clone())?; + services() + .globals + .add_signing_key(origin, server_key.clone())?; - result.extend(server_key.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); - result.extend(server_key.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); if contains_all_ids(&result) { return Ok(result); @@ -1604,10 +1828,22 @@ impl Service { .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - services().globals.add_signing_key(origin, server_key.clone())?; + services() + .globals + .add_signing_key(origin, server_key.clone())?; - result.extend(server_key.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); - result.extend(server_key.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); if contains_all_ids(&result) { return Ok(result); @@ -1623,20 +1859,34 @@ impl Service { get_remote_server_keys::v2::Request::new( origin.to_owned(), MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now().checked_add(Duration::from_secs(3600)).expect("SystemTime too large"), + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime too large"), ) .expect("time is valid"), ), ) .await .ok() - .map(|resp| resp.server_keys.into_iter().filter_map(|e| e.deserialize().ok()).collect::>()) - { + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::>() + }) { debug!("Got signing keys: {:?}", server_keys); for k in server_keys { services().globals.add_signing_key(origin, k.clone())?; - result.extend(k.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); - result.extend(k.old_verify_keys.into_iter().map(|(k, v)| 
(k.to_string(), v.key))); + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); } if contains_all_ids(&result) { diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index c94eb92b..13a45987 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -20,7 +20,8 @@ impl Service { pub fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result { - self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) + self.db + .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) } #[tracing::instrument(skip(self))] @@ -44,7 +45,8 @@ impl Service { room_id.to_owned(), since, )) { - self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; + self.db + .lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; } else { // Ignore } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 1b82146c..5feafbaf 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -126,8 +126,10 @@ impl Service { next_token = events_before.last().map(|(count, _)| count).copied(); - let events_before: Vec<_> = - events_before.into_iter().map(|(_, pdu)| pdu.to_message_like_event()).collect(); + let events_before: Vec<_> = events_before + .into_iter() + .map(|(_, pdu)| pdu.to_message_like_event()) + .collect(); Ok(get_relating_events::v1::Response { chunk: events_before, diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index ec8d5914..bea0aa5b 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -148,7 +148,13 @@ impl Arena { )]; while let Some(parent) = self.parent(parents.last().expect("Has at least one value, as above").0) { - parents.push((parent, self.get(parent).expect("It is some, as above").room_id.clone())); + parents.push(( + parent, + self.get(parent) + .expect("It is some, as above") + .room_id + .clone(), + )); } // If at max_depth, don't add new rooms @@ -178,7 +184,10 @@ impl Arena { } if self.first_untraversed.is_none() - || parent >= self.first_untraversed.expect("Should have already continued if none") + || parent + >= self + .first_untraversed + .expect("Should have already continued if none") { self.first_untraversed = next_id; } @@ -187,7 +196,9 @@ impl Arena { // This is done as if we use an if-let above, we cannot reference self.nodes // above as then we would have multiple mutable references - let node = self.get_mut(parent).expect("Must be some, as inside this block"); + let node = self + .get_mut(parent) + .expect("Must be some, as inside this block"); node.first_child = next_id; } @@ -324,7 +335,10 @@ impl Service { pub async fn get_federation_hierarchy( &self, room_id: &RoomId, server_name: &ServerName, suggested_only: bool, ) -> Result { - match self.get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None).await? { + match self + .get_summary_and_children(&room_id.to_owned(), suggested_only, Identifier::None) + .await? 
+ { Some(SummaryAccessibility::Accessible(room)) => { let mut children = Vec::new(); let mut inaccessible_children = Vec::new(); @@ -360,7 +374,13 @@ impl Service { async fn get_summary_and_children( &self, current_room: &OwnedRoomId, suggested_only: bool, identifier: Identifier<'_>, ) -> Result> { - if let Some(cached) = self.roomid_spacehierarchy_cache.lock().await.get_mut(¤t_room.to_owned()).as_ref() { + if let Some(cached) = self + .roomid_spacehierarchy_cache + .lock() + .await + .get_mut(¤t_room.to_owned()) + .as_ref() + { return Ok(if let Some(cached) = cached { if is_accessable_child( current_room, @@ -395,7 +415,9 @@ impl Service { // Federation requests should not request information from other // servers } else if let Identifier::UserId(_) = identifier { - let server = current_room.server_name().expect("Room IDs should always have a server name"); + let server = current_room + .server_name() + .expect("Room IDs should always have a server name"); if server == services().globals.server_name() { return Ok(None); } @@ -473,7 +495,10 @@ impl Service { Some(SummaryAccessibility::Inaccessible) } } else { - self.roomid_spacehierarchy_cache.lock().await.insert(current_room.clone(), None); + self.roomid_spacehierarchy_cache + .lock() + .await + .insert(current_room.clone(), None); None } @@ -494,10 +519,12 @@ impl Service { .state_accessor .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { - serde_json::from_str(s.content.get()).map(|c: RoomJoinRulesEventContent| c.join_rule).map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| c.join_rule) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) }) .transpose()? .unwrap_or(JoinRule::Invite); @@ -534,15 +561,18 @@ impl Service { .try_into() .expect("user count should not be that big"), room_id: room_id.to_owned(), - topic: services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomTopic, "")?.map_or( - Ok(None), - |s| { - serde_json::from_str(s.content.get()).map(|c: RoomTopicEventContent| Some(c.topic)).map_err(|_| { - error!("Invalid room topic event in database for room {}", room_id); - Error::bad_database("Invalid room topic event in database.") - }) - }, - )?, + topic: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomTopic, "")? 
+ .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomTopicEventContent| Some(c.topic)) + .map_err(|_| { + error!("Invalid room topic event in database for room {}", room_id); + Error::bad_database("Invalid room topic event in database.") + }) + })?, world_readable: world_readable(room_id)?, guest_can_join: guest_can_join(room_id)?, avatar_url: services() @@ -588,7 +618,9 @@ impl Service { let mut arena = Arena::new(summary.room_id.clone(), max_depth); let mut results = Vec::new(); - let root = arena.first_untraversed().expect("The node just added is not traversed"); + let root = arena + .first_untraversed() + .expect("The node just added is not traversed"); arena.push(root, get_parent_children(&summary.clone(), suggested_only)); results.push(summary_to_chunk(*summary.clone())); @@ -597,7 +629,10 @@ impl Service { if limit > results.len() { if let Some(SummaryAccessibility::Accessible(summary)) = self .get_summary_and_children( - &arena.get(current_room).expect("We added this node, it must exist").room_id, + &arena + .get(current_room) + .expect("We added this node, it must exist") + .room_id, suggested_only, Identifier::UserId(sender_user), ) @@ -651,7 +686,11 @@ async fn get_stripped_space_child_events( room_id: &RoomId, ) -> Result>>, Error> { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { - let state = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; + let state = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; let mut children_pdus = Vec::new(); for (key, id) in state { let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; @@ -702,13 +741,24 @@ fn is_accessable_child_recurse( let room_id: &RoomId = current_room; // Checks if ACLs allow for the server to participate - if services().rooms.event_handler.acl_check(server_name, room_id).is_err() { + if services() + .rooms + .event_handler + .acl_check(server_name, room_id) + .is_err() + { return Ok(false); } }, Identifier::UserId(user_id) => { - if services().rooms.state_cache.is_joined(user_id, current_room)? - || services().rooms.state_cache.is_invited(user_id, current_room)? + if services() + .rooms + .state_cache + .is_joined(user_id, current_room)? + || services() + .rooms + .state_cache + .is_invited(user_id, current_room)? { return Ok(true); } @@ -746,26 +796,28 @@ fn is_accessable_child_recurse( /// Checks if guests are able to join a given room fn guest_can_join(room_id: &RoomId) -> Result { - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?.map_or( - Ok(false), - |s| { + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? + .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) - }, - ) + }) } /// Checks if guests are able to view room content without joining fn world_readable(room_id: &RoomId) -> Result { - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( - Ok(false), - |s| { + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? 
+ .map_or(Ok(false), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable) .map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) - }, - ) + }) } /// Returns the join rule for a given room diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index db2c1921..580c5bad 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -36,7 +36,12 @@ impl Service { state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { for event_id in statediffnew.iter().filter_map(|new| { - services().rooms.state_compressor.parse_compressed_state_event(new).ok().map(|(_, id)| id) + services() + .rooms + .state_compressor + .parse_compressed_state_event(new) + .ok() + .map(|(_, id)| id) }) { let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? { Some(pdu) => pdu, @@ -74,7 +79,13 @@ impl Service { .await?; }, TimelineEventType::SpaceChild => { - services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); + services() + .rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .remove(&pdu.room_id); }, _ => continue, } @@ -82,7 +93,8 @@ impl Service { services().rooms.state_cache.update_joined_count(room_id)?; - self.db.set_room_state(room_id, shortstatehash, state_lock)?; + self.db + .set_room_state(room_id, shortstatehash, state_lock)?; Ok(()) } @@ -95,25 +107,47 @@ impl Service { pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, ) -> Result { - let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?; + let shorteventid = services() + .rooms + .short + .get_or_create_shorteventid(event_id)?; let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; - let state_hash = calculate_hash(&state_ids_compressed.iter().map(|s| &s[..]).collect::>()); + let state_hash = calculate_hash( + &state_ids_compressed + .iter() + .map(|s| &s[..]) + .collect::>(), + ); - let (shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; + let (shortstatehash, already_existed) = services() + .rooms + .short + .get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), - |p| services().rooms.state_compressor.load_shortstatehash_info(p), + |p| { + services() + .rooms + .state_compressor + .load_shortstatehash_info(p) + }, )?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); + let statediffnew: HashSet<_> = state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); - let statediffremoved: HashSet<_> = - parent_stateinfo.1.difference(&state_ids_compressed).copied().collect(); + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&state_ids_compressed) + .copied() + .collect(); (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { @@ -139,7 +173,10 @@ impl Service { /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
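`set_event_state` above (and `save_state` later in this patch) persist room state as a diff against the parent layer, computed with two `HashSet::difference` passes. A minimal sketch, with `u128` standing in for the fixed-size compressed (shortstatekey, shorteventid) entries:

use std::collections::HashSet;

// Only the additions and removals relative to the parent state layer
// are persisted; everything else is shared with the parent.
fn diff_against_parent(
    new_state: &HashSet<u128>,
    parent_state: &HashSet<u128>,
) -> (HashSet<u128>, HashSet<u128>) {
    let added: HashSet<u128> = new_state.difference(parent_state).copied().collect();
    let removed: HashSet<u128> = parent_state.difference(new_state).copied().collect();
    (added, removed)
}

fn main() {
    let parent: HashSet<u128> = [1, 2, 3].into();
    let new: HashSet<u128> = [2, 3, 4].into();
    let (added, removed) = diff_against_parent(&new, &parent);
    assert_eq!(added, HashSet::from([4]));
    assert_eq!(removed, HashSet::from([1]));
}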
#[tracing::instrument(skip(self, new_pdu))] pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { - let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?; + let shorteventid = services() + .rooms + .short + .get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; @@ -150,17 +187,31 @@ impl Service { if let Some(state_key) = &new_pdu.state_key { let states_parents = previous_shortstatehash.map_or_else( || Ok(Vec::new()), - |p| services().rooms.state_compressor.load_shortstatehash_info(p), + |p| { + services() + .rooms + .state_compressor + .load_shortstatehash_info(p) + }, )?; - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; - let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?; + let new = services() + .rooms + .state_compressor + .compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() - .map(|info| info.1.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))) + .map(|info| { + info.1 + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) .unwrap_or_default(); if Some(&new) == replaces { @@ -197,12 +248,18 @@ impl Service { let mut state = Vec::new(); // Add recommended events if let Some(e) = - services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? + services() + .rooms + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + services() + .rooms + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? { state.push(e.to_stripped_state_event()); } @@ -214,12 +271,18 @@ impl Service { state.push(e.to_stripped_state_event()); } if let Some(e) = - services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + services() + .rooms + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? { state.push(e.to_stripped_state_event()); } if let Some(e) = - services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + services() + .rooms + .state_accessor + .room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? { state.push(e.to_stripped_state_event()); } @@ -249,7 +312,10 @@ impl Service { /// Returns the room's version. 
#[tracing::instrument(skip(self))] pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; + let create_event = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event_content: RoomCreateEventContent = create_event .as_ref() @@ -279,7 +345,8 @@ impl Service { event_ids: Vec, state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { - self.db.set_forward_extremities(room_id, event_ids, state_lock) + self.db + .set_forward_extremities(room_id, event_ids, state_lock) } /// This fetches auth events from the current state. @@ -321,9 +388,23 @@ impl Service { Ok(full_state .iter() - .filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok()) + .filter_map(|compressed| { + services() + .rooms + .state_compressor + .parse_compressed_state_event(compressed) + .ok() + }) .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) - .filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) + .filter_map(|(k, event_id)| { + services() + .rooms + .timeline + .get_pdu(&event_id) + .ok() + .flatten() + .map(|pdu| (k, pdu)) + }) .collect()) } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index af0ab238..719b943d 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -59,19 +59,18 @@ impl Service { /// Get membership for given user in state fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { - self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?.map_or( - Ok(MembershipState::Leave), - |s| { + self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())? + .map_or(Ok(MembershipState::Leave), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomMemberEventContent| c.membership) .map_err(|_| Error::bad_database("Invalid room membership event in database.")) - }, - ) + }) } /// The user was a joined member at this state (potentially in the past) fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id).is_ok_and(|s| s == MembershipState::Join) + self.user_membership(shortstatehash, user_id) + .is_ok_and(|s| s == MembershipState::Join) // Return sensible default, i.e. // false } @@ -92,20 +91,22 @@ impl Service { return Ok(true); }; - if let Some(visibility) = - self.server_visibility_cache.lock().unwrap().get_mut(&(origin.to_owned(), shortstatehash)) + if let Some(visibility) = self + .server_visibility_cache + .lock() + .unwrap() + .get_mut(&(origin.to_owned(), shortstatehash)) { return Ok(*visibility); } - let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( - Ok(HistoryVisibility::Shared), - |s| { + let history_visibility = self + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
+ .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|_| Error::bad_database("Invalid history visibility event in database.")) - }, - )?; + })?; let mut current_server_members = services() .rooms @@ -130,7 +131,10 @@ impl Service { }, }; - self.server_visibility_cache.lock().unwrap().insert((origin.to_owned(), shortstatehash), visibility); + self.server_visibility_cache + .lock() + .unwrap() + .insert((origin.to_owned(), shortstatehash), visibility); Ok(visibility) } @@ -144,22 +148,24 @@ impl Service { None => return Ok(true), }; - if let Some(visibility) = - self.user_visibility_cache.lock().unwrap().get_mut(&(user_id.to_owned(), shortstatehash)) + if let Some(visibility) = self + .user_visibility_cache + .lock() + .unwrap() + .get_mut(&(user_id.to_owned(), shortstatehash)) { return Ok(*visibility); } let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; - let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( - Ok(HistoryVisibility::Shared), - |s| { + let history_visibility = self + .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? + .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|_| Error::bad_database("Invalid history visibility event in database.")) - }, - )?; + })?; let visibility = match history_visibility { HistoryVisibility::WorldReadable => true, @@ -178,7 +184,10 @@ impl Service { }, }; - self.user_visibility_cache.lock().unwrap().insert((user_id.to_owned(), shortstatehash), visibility); + self.user_visibility_cache + .lock() + .unwrap() + .insert((user_id.to_owned(), shortstatehash), visibility); Ok(visibility) } @@ -189,17 +198,16 @@ impl Service { pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; - let history_visibility = self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( - Ok(HistoryVisibility::Shared), - |s| { + let history_visibility = self + .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .map_or(Ok(HistoryVisibility::Shared), |s| { serde_json::from_str(s.content.get()) .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) .map_err(|e| { error!("Invalid history visibility event in database for room {}: {e}", &room_id); Error::bad_database("Invalid history visibility event in database.") }) - }, - )?; + })?; Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) } @@ -232,28 +240,34 @@ impl Service { } pub fn get_name(&self, room_id: &RoomId) -> Result> { - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomName, "")?.map_or(Ok(None), |s| { - Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) - }) + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomName, "")? 
+ .map_or(Ok(None), |s| { + Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name))) + }) } pub fn get_avatar(&self, room_id: &RoomId) -> Result> { - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomAvatar, "")?.map_or( - Ok(ruma::JsOption::Undefined), - |s| { + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomAvatar, "")? + .map_or(Ok(ruma::JsOption::Undefined), |s| { serde_json::from_str(s.content.get()) .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }, - ) + }) } pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?.map_or( - Ok(None), - |s| { + services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? + .map_or(Ok(None), |s| { serde_json::from_str(s.content.get()) .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }, - ) + }) } } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 65e45926..2a0e492e 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -165,7 +165,9 @@ impl Service { .get( None, // Ignored users are in global account data user_id, // Receiver - GlobalAccountDataEventType::IgnoredUserList.to_string().into(), + GlobalAccountDataEventType::IgnoredUserList + .to_string() + .into(), )? .map(|event| { serde_json::from_str::(event.get()).map_err(|e| { @@ -175,7 +177,11 @@ impl Service { }) .transpose()? .map_or(false, |ignored| { - ignored.content.ignored_users.iter().any(|(user, _details)| user == sender) + ignored + .content + .ignored_users + .iter() + .any(|(user, _details)| user == sender) }); if is_ignored { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 2570d325..36252897 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -55,7 +55,12 @@ impl Service { /// removed diff for the selected shortstatehash and each parent layer. #[tracing::instrument(skip(self))] pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { - if let Some(r) = self.stateinfo_cache.lock().unwrap().get_mut(&shortstatehash) { + if let Some(r) = self + .stateinfo_cache + .lock() + .unwrap() + .get_mut(&shortstatehash) + { return Ok(r.clone()); } @@ -76,19 +81,31 @@ impl Service { response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); - self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); Ok(response) } else { let response = vec![(shortstatehash, added.clone(), added, removed)]; - self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); Ok(response) } } pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { let mut v = shortstatekey.to_be_bytes().to_vec(); - v.extend_from_slice(&services().rooms.short.get_or_create_shorteventid(event_id)?.to_be_bytes()); + v.extend_from_slice( + &services() + .rooms + .short + .get_or_create_shorteventid(event_id)? 
+ .to_be_bytes(), + ); Ok(v.try_into().expect("we checked the size above")) } @@ -238,10 +255,17 @@ impl Service { ) -> HashSetCompressStateEvent { let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; - let state_hash = - utils::calculate_hash(&new_state_ids_compressed.iter().map(|bytes| &bytes[..]).collect::>()); + let state_hash = utils::calculate_hash( + &new_state_ids_compressed + .iter() + .map(|bytes| &bytes[..]) + .collect::>(), + ); - let (new_shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; + let (new_shortstatehash, already_existed) = services() + .rooms + .short + .get_or_create_shortstatehash(&state_hash)?; if Some(new_shortstatehash) == previous_shortstatehash { return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); @@ -251,10 +275,16 @@ impl Service { previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = new_state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); + let statediffnew: HashSet<_> = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .copied() + .collect(); - let statediffremoved: HashSet<_> = - parent_stateinfo.1.difference(&new_state_ids_compressed).copied().collect(); + let statediffremoved: HashSet<_> = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .copied() + .collect(); (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index fd9e8b93..e4f7b5dd 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -60,7 +60,9 @@ impl Service { unsigned.insert( "m.relations".to_owned(), - json!({ "m.thread": content }).try_into().expect("thread is valid json"), + json!({ "m.thread": content }) + .try_into() + .expect("thread is valid json"), ); } else { // New thread @@ -74,11 +76,16 @@ impl Service { unsigned.insert( "m.relations".to_owned(), - json!({ "m.thread": content }).try_into().expect("thread is valid json"), + json!({ "m.thread": content }) + .try_into() + .expect("thread is valid json"), ); } - services().rooms.timeline.replace_pdu(root_id, &root_pdu_json, &root_pdu)?; + services() + .rooms + .timeline + .replace_pdu(root_id, &root_pdu_json, &root_pdu)?; } let mut users = Vec::new(); diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 50ee6500..efad5be2 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -152,7 +152,10 @@ impl Service { /// Returns the version of a room, if known pub fn get_room_version(&self, room_id: &RoomId) -> Result> { - let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; + let create_event = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")?; let create_event_content: Option = create_event .as_ref() @@ -225,16 +228,25 @@ impl Service { // Coalesce database writes for the remainder of this scope. let _cork = services().globals.db.cork_and_flush()?; - let shortroomid = services().rooms.short.get_shortroomid(&pdu.room_id)?.expect("room exists"); + let shortroomid = services() + .rooms + .short + .get_shortroomid(&pdu.room_id)? + .expect("room exists"); // Make unsigned fields correct. 
This is not properly documented in the spec, // but state events need to have previous content in the unsigned field, so // clients can easily interpret things like membership changes if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = - pdu_json.entry("unsigned".to_owned()).or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) + if let CanonicalJsonValue::Object(unsigned) = pdu_json + .entry("unsigned".to_owned()) + .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) { - if let Some(shortstatehash) = services().rooms.state_accessor.pdu_shortstatehash(&pdu.event_id).unwrap() + if let Some(shortstatehash) = services() + .rooms + .state_accessor + .pdu_shortstatehash(&pdu.event_id) + .unwrap() { if let Some(prev_state) = services() .rooms @@ -259,18 +271,38 @@ impl Service { } // We must keep track of all events that have been referenced. - services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + services() + .rooms + .pdu_metadata + .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + services() + .rooms + .state + .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; - let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(pdu.room_id.clone()).or_default()); + let mutex_insert = Arc::clone( + services() + .globals + .roomid_mutex_insert + .write() + .await + .entry(pdu.room_id.clone()) + .or_default(), + ); let insert_lock = mutex_insert.lock().await; let count1 = services().globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if // appending fails - services().rooms.edus.read_receipt.private_read_set(&pdu.room_id, &pdu.sender, count1)?; - services().rooms.user.reset_notification_counts(&pdu.sender, &pdu.room_id)?; + services() + .rooms + .edus + .read_receipt + .private_read_set(&pdu.room_id, &pdu.sender, count1)?; + services() + .rooms + .user + .reset_notification_counts(&pdu.sender, &pdu.room_id)?; let count2 = services().globals.next_count()?; let mut pdu_id = shortroomid.to_be_bytes().to_vec(); @@ -318,7 +350,10 @@ impl Service { let mut notifies = Vec::new(); let mut highlights = Vec::new(); - let mut push_target = services().rooms.state_cache.get_our_real_users(&pdu.room_id)?; + let mut push_target = services() + .rooms + .state_cache + .get_our_real_users(&pdu.room_id)?; if pdu.kind == TimelineEventType::RoomMember { if let Some(state_key) = &pdu.state_key { @@ -354,7 +389,9 @@ impl Service { let mut notify = false; for action in - services().pusher.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? + services() + .pusher + .get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? 
{ match action { Action::Notify => notify = true, @@ -378,7 +415,8 @@ impl Service { } } - self.db.increment_notification_counts(&pdu.room_id, notifies, highlights)?; + self.db + .increment_notification_counts(&pdu.room_id, notifies, highlights)?; match pdu.kind { TimelineEventType::RoomRedaction => { @@ -422,7 +460,13 @@ impl Service { }, TimelineEventType::SpaceChild => { if let Some(_state_key) = &pdu.state_key { - services().rooms.spaces.roomid_spacehierarchy_cache.lock().await.remove(&pdu.room_id); + services() + .rooms + .spaces + .roomid_spacehierarchy_cache + .lock() + .await + .remove(&pdu.room_id); } }, TimelineEventType::RoomMember => { @@ -463,7 +507,10 @@ impl Service { .map_err(|_| Error::bad_database("Invalid content in pdu."))?; if let Some(body) = content.body { - services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; + services() + .rooms + .search + .index_pdu(shortroomid, &pdu_id, &body)?; let server_user = format!("@conduit:{}", services().globals.server_name()); @@ -488,8 +535,15 @@ impl Service { } if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&content.relates_to.event_id)? { - services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; + if let Some(related_pducount) = services() + .rooms + .timeline + .get_pdu_count(&content.relates_to.event_id)? + { + services() + .rooms + .pdu_metadata + .add_relation(PduCount::Normal(count2), related_pducount)?; } } @@ -500,32 +554,52 @@ impl Service { } => { // We need to do it again here, because replies don't have // event_id as a top level field - if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&in_reply_to.event_id)? { - services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; + if let Some(related_pducount) = services() + .rooms + .timeline + .get_pdu_count(&in_reply_to.event_id)? + { + services() + .rooms + .pdu_metadata + .add_relation(PduCount::Normal(count2), related_pducount)?; } }, Relation::Thread(thread) => { - services().rooms.threads.add_to_thread(&thread.event_id, pdu)?; + services() + .rooms + .threads + .add_to_thread(&thread.event_id, pdu)?; }, _ => {}, // TODO: Aggregate other types } } for appservice in services().appservice.read().await.values() { - if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? { - services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + if services() + .rooms + .state_cache + .appservice_in_room(&pdu.room_id, appservice)? + { + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; } // If the RoomMember event has a non-empty state_key, it is targeted at someone. // If it is our appservice user, we send this PDU to it. 
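The push-evaluation loop in the hunk above folds each user's evaluated actions down to two flags that feed the `notifies` and `highlights` vectors. A simplified sketch with stand-in types (ruma's real `Action`/`Tweak` enums carry more variants, and the highlight arm is inferred from the `highlights` vector above):

// Stand-ins for ruma's push types; only the variants the tally
// cares about are modelled here.
enum Tweak {
    Highlight(bool),
}
enum Action {
    Notify,
    SetTweak(Tweak),
}

// Any Notify action marks the user for a notification; a
// Highlight(true) tweak additionally marks the event as
// highlighted for them.
fn tally(actions: &[Action]) -> (bool, bool) {
    let mut notify = false;
    let mut highlight = false;
    for action in actions {
        match action {
            Action::Notify => notify = true,
            Action::SetTweak(Tweak::Highlight(true)) => highlight = true,
            _ => {},
        }
    }
    (notify, highlight)
}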
if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key_uid) = - &pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + if let Some(state_key_uid) = &pdu + .state_key + .as_ref() + .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) { let appservice_uid = appservice.registration.sender_localpart.as_str(); if state_key_uid == appservice_uid { - services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; } } @@ -534,7 +608,10 @@ impl Service { let matching_users = |users: &NamespaceRegex| { appservice.users.is_match(pdu.sender.as_str()) || pdu.kind == TimelineEventType::RoomMember - && pdu.state_key.as_ref().map_or(false, |state_key| users.is_match(state_key)) + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) }; let matching_aliases = |aliases: &NamespaceRegex| { services() @@ -549,7 +626,9 @@ impl Service { || appservice.rooms.is_match(pdu.room_id.as_str()) || matching_users(&appservice.users) { - services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + services() + .sending + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; } } @@ -571,31 +650,43 @@ impl Service { redacts, } = pdu_builder; - let prev_events: Vec<_> = - services().rooms.state.get_forward_extremities(room_id)?.into_iter().take(20).collect(); + let prev_events: Vec<_> = services() + .rooms + .state + .get_forward_extremities(room_id)? + .into_iter() + .take(20) + .collect(); // If there was no create event yet, assume we are creating a room - let room_version_id = services().rooms.state.get_room_version(room_id).or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - #[derive(Deserialize)] - struct RoomCreate { - room_version: RoomVersionId, + let room_version_id = services() + .rooms + .state + .get_room_version(room_id) + .or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + #[derive(Deserialize)] + struct RoomCreate { + room_version: RoomVersionId, + } + let content = + serde_json::from_str::(content.get()).expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) } - let content = - serde_json::from_str::(content.get()).expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; + })?; let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); let auth_events = - services().rooms.state.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + services() + .rooms + .state + .get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; // Our depth is the maximum depth of prev_events + 1 let depth = prev_events @@ -609,7 +700,10 @@ impl Service { if let Some(state_key) = &state_key { if let Some(prev_pdu) = - services().rooms.state_accessor.room_state_get(room_id, &event_type.to_string().into(), state_key)? + services() + .rooms + .state_accessor + .room_state_get(room_id, &event_type.to_string().into(), state_key)? 
 			{
 				unsigned.insert(
 					"prev_content".to_owned(),
@@ -626,13 +720,18 @@ impl Service {
 			event_id: ruma::event_id!("$thiswillbefilledinlater").into(),
 			room_id: room_id.to_owned(),
 			sender: sender.to_owned(),
-			origin_server_ts: utils::millis_since_unix_epoch().try_into().expect("time is valid"),
+			origin_server_ts: utils::millis_since_unix_epoch()
+				.try_into()
+				.expect("time is valid"),
 			kind: event_type,
 			content,
 			state_key,
 			prev_events,
 			depth,
-			auth_events: auth_events.values().map(|pdu| pdu.event_id.clone()).collect(),
+			auth_events: auth_events
+				.values()
+				.map(|pdu| pdu.event_id.clone())
+				.collect(),
 			redacts,
 			unsigned: if unsigned.is_empty() {
 				None
@@ -710,7 +809,10 @@ impl Service {
 		);
 
 		// Generate short event id
-		let _shorteventid = services().rooms.short.get_or_create_shorteventid(&pdu.event_id)?;
+		let _shorteventid = services()
+			.rooms
+			.short
+			.get_or_create_shorteventid(&pdu.event_id)?;
 
 		Ok((pdu, pdu_json))
 	}
@@ -744,7 +846,10 @@ impl Service {
 			membership: MembershipState,
 		}
 
-		let target = pdu.state_key().filter(|v| v.starts_with('@')).unwrap_or(sender.as_str());
+		let target = pdu
+			.state_key()
+			.filter(|v| v.starts_with('@'))
+			.unwrap_or(sender.as_str());
 		let server_name = services().globals.server_name();
 		let server_user = format!("@conduit:{server_name}");
 		let content = serde_json::from_str::<ExtractMembership>(pdu.content.get())
@@ -825,16 +930,25 @@ impl Service {
 		// We set the room state after inserting the pdu, so that we never have a moment
 		// in time where events in the current room state do not exist
-		services().rooms.state.set_room_state(room_id, statehashid, state_lock)?;
+		services()
+			.rooms
+			.state
+			.set_room_state(room_id, statehashid, state_lock)?;
 
-		let mut servers: HashSet<OwnedServerName> =
-			services().rooms.state_cache.room_servers(room_id).filter_map(Result::ok).collect();
+		let mut servers: HashSet<OwnedServerName> = services()
+			.rooms
+			.state_cache
+			.room_servers(room_id)
+			.filter_map(Result::ok)
+			.collect();
 
 		// In case we are kicking or banning a user, we need to inform their server of
 		// the change
 		if pdu.kind == TimelineEventType::RoomMember {
-			if let Some(state_key_uid) =
-				&pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok())
+			if let Some(state_key_uid) = &pdu
+				.state_key
+				.as_ref()
+				.and_then(|state_key| UserId::parse(state_key.as_str()).ok())
 			{
 				servers.insert(state_key_uid.server_name().to_owned());
 			}
 		}
@@ -864,15 +978,28 @@ impl Service {
 		// We append to state before appending the pdu, so we don't have a moment in
 		// time with the pdu without it's state. This is okay because append_pdu can't
 		// fail.
-		services().rooms.state.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?;
+		services()
+			.rooms
+			.state
+			.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?;
 
 		if soft_fail {
-			services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
-			services().rooms.state.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?;
+			services()
+				.rooms
+				.pdu_metadata
+				.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
+			services()
+				.rooms
+				.state
+				.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?;
 			return Ok(None);
 		}
 
-		let pdu_id = services().rooms.timeline.append_pdu(pdu, pdu_json, new_room_leaves, state_lock).await?;
+		let pdu_id = services()
+			.rooms
+			.timeline
+			.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
+			.await?;
 
 		Ok(Some(pdu_id))
 	}
@@ -908,8 +1035,9 @@ impl Service {
 	pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> {
 		// TODO: Don't reserialize, keep original json
 		if let Some(pdu_id) = self.get_pdu_id(event_id)? {
-			let mut pdu =
-				self.get_pdu_from_id(&pdu_id)?.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
+			let mut pdu = self
+				.get_pdu_from_id(&pdu_id)?
+				.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?;
 			let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
 			pdu.redact(room_version_id, reason)?;
 			self.replace_pdu(
@@ -927,8 +1055,10 @@ impl Service {
 	#[tracing::instrument(skip(self, room_id))]
 	pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> {
-		let first_pdu =
-			self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?.next().expect("Room is not empty")?;
+		let first_pdu = self
+			.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?
+			.next()
+			.expect("Room is not empty")?;
 
 		if first_pdu.0 < from {
 			// No backfill required, there are still events between them
 			return Ok(());
 		}
 
 		let mut servers: Vec<&ServerName> = vec![];
 
 		// add server names from room aliases on the room ID
-		let room_aliases = services().rooms.alias.local_aliases_for_room(room_id).collect::<Result<Vec<_>, _>>();
+		let room_aliases = services()
+			.rooms
+			.alias
+			.local_aliases_for_room(room_id)
+			.collect::<Result<Vec<_>, _>>();
 		if let Ok(aliases) = &room_aliases {
 			for alias in aliases {
 				if alias.server_name() != services().globals.server_name() {
@@ -979,8 +1113,10 @@
 		}
 
 		// don't backfill from ourselves (might be noop if we checked it above already)
-		if let Some(server_index) =
-			servers.clone().into_iter().position(|server| server == services().globals.server_name())
+		if let Some(server_index) = servers
+			.clone()
+			.into_iter()
+			.position(|server| server == services().globals.server_name())
 		{
 			servers.remove(server_index);
 		}
@@ -1030,8 +1166,15 @@
 		let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?;
 
 		// Lock so we cannot backfill the same pdu twice at the same time
-		let mutex =
-			Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default());
+		let mutex = Arc::clone(
+			services()
+				.globals
+				.roomid_mutex_federation
+				.write()
+				.await
+				.entry(room_id.clone())
+				.or_default(),
+		);
 		let mutex_lock = mutex.lock().await;
 
 		// Skip the PDU if we already have it as a timeline event
@@ -1040,7 +1183,11 @@
 			return Ok(());
 		}
 
-		services().rooms.event_handler.fetch_required_signing_keys([&value], pub_key_map).await?;
+		services()
+			.rooms
+			.event_handler
+			.fetch_required_signing_keys([&value], pub_key_map)
+			.await?;
 
 		services()
 			.rooms
@@ -1051,10 +1198,21 @@
 		let value = self.get_pdu_json(&event_id)?.expect("We just created it");
 		let pdu = self.get_pdu(&event_id)?.expect("We just created it");
 
-		let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists");
+		let shortroomid = services()
+			.rooms
+			.short
+			.get_shortroomid(&room_id)?
+			.expect("room exists");
 
-		let mutex_insert =
-			Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default());
+		let mutex_insert = Arc::clone(
+			services()
+				.globals
+				.roomid_mutex_insert
+				.write()
+				.await
+				.entry(room_id.clone())
+				.or_default(),
+		);
 		let insert_lock = mutex_insert.lock().await;
 
 		let count = services().globals.next_count()?;
@@ -1077,7 +1235,10 @@
 				.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
 
 			if let Some(body) = content.body {
-				services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?;
+				services()
+					.rooms
+					.search
+					.index_pdu(shortroomid, &pdu_id, &body)?;
 			}
 		}
 		drop(mutex_lock);
diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs
index 7cf49d9e..e1741782 100644
--- a/src/service/rooms/user/mod.rs
+++ b/src/service/rooms/user/mod.rs
@@ -27,7 +27,8 @@ impl Service {
 	}
 
 	pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
-		self.db.associate_token_shortstatehash(room_id, token, shortstatehash)
+		self.db
+			.associate_token_shortstatehash(room_id, token, shortstatehash)
 	}
 
 	pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs
index 3326a214..254edfe9 100644
--- a/src/service/sending/mod.rs
+++ b/src/service/sending/mod.rs
@@ -127,7 +127,9 @@ impl Service {
 		let mut initial_transactions = HashMap::<OutgoingKind, Vec<SendingEventType>>::new();
 
 		for (key, outgoing_kind, event) in self.db.active_requests().filter_map(Result::ok) {
-			let entry = initial_transactions.entry(outgoing_kind.clone()).or_default();
+			let entry = initial_transactions
+				.entry(outgoing_kind.clone())
+				.or_default();
 
 			if entry.len() > 30 {
 				warn!("Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event);
@@ -236,7 +238,11 @@ impl Service {
 
 		if retry {
 			// We retry the previous transaction
-			for (_, e) in self.db.active_requests_for(outgoing_kind).filter_map(Result::ok) {
+			for (_, e) in self
+				.db
+				.active_requests_for(outgoing_kind)
+				.filter_map(Result::ok)
+			{
 				events.push(e);
 			}
 		} else {
@@ -282,7 +288,12 @@ impl Service {
 			// Look for presence updates in this room
 			let mut presence_updates = Vec::new();
 
-			for (user_id, count, presence_event) in services().rooms.edus.presence.presence_since(&room_id, since) {
+			for (user_id, count, presence_event) in services()
+				.rooms
+				.edus
+				.presence
+				.presence_since(&room_id, since)
+			{
 				if count > max_edu_count {
 					max_edu_count = count;
 				}
@@ -295,7 +306,10 @@ impl Service {
 					user_id,
 					presence: presence_event.content.presence,
 					currently_active: presence_event.content.currently_active.unwrap_or(false),
-					last_active_ago: presence_event.content.last_active_ago.unwrap_or_else(|| uint!(0)),
+					last_active_ago: presence_event
+						.content
+						.last_active_ago
+						.unwrap_or_else(|| uint!(0)),
 					status_msg: presence_event.content.status_msg,
 				});
 			}
@@ -305,7 +319,12 @@ impl Service {
 			}
 
 			// Look for read receipts in this room
-			for r in services().rooms.edus.read_receipt.readreceipts_since(&room_id, since) {
+			for r in services()
+				.rooms
+				.edus
+				.read_receipt
+				.readreceipts_since(&room_id, since)
+			{
 				let (user_id, count, read_receipt) = r?;
 
 				if count > max_edu_count {
@@ -321,8 +340,12 @@ impl Service {
 				let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
 					let mut read = BTreeMap::new();
 
-					let (event_id, mut receipt) =
-						r.content.0.into_iter().next().expect("we only use one event per read receipt");
+					let (event_id, mut receipt) = r
+						.content
+						.0
+						.into_iter()
+						.next()
+						.expect("we only use one event per read receipt");
 					let receipt = receipt
 						.remove(&ReceiptType::Read)
 						.expect("our read receipts always set this")
@@ -385,7 +408,9 @@ impl Service {
 		let event = SendingEventType::Pdu(pdu_id.to_owned());
 		let _cork = services().globals.db.cork()?;
 		let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
-		self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap();
+		self.sender
+			.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
+			.unwrap();
 
 		Ok(())
 	}
@@ -397,9 +422,16 @@ impl Service {
 			.map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned())))
 			.collect::<Vec<_>>();
 		let _cork = services().globals.db.cork()?;
-		let keys = self.db.queue_requests(&requests.iter().map(|(o, e)| (o, e.clone())).collect::<Vec<_>>())?;
+		let keys = self.db.queue_requests(
+			&requests
+				.iter()
+				.map(|(o, e)| (o, e.clone()))
+				.collect::<Vec<_>>(),
+		)?;
 		for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) {
-			self.sender.send((outgoing_kind.clone(), event, key)).unwrap();
+			self.sender
+				.send((outgoing_kind.clone(), event, key))
+				.unwrap();
 		}
 
 		Ok(())
@@ -411,7 +443,9 @@ impl Service {
 		let event = SendingEventType::Edu(serialized);
 		let _cork = services().globals.db.cork()?;
 		let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
-		self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap();
+		self.sender
+			.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
+			.unwrap();
 
 		Ok(())
 	}
@@ -422,15 +456,21 @@ impl Service {
 		let event = SendingEventType::Pdu(pdu_id);
 		let _cork = services().globals.db.cork()?;
 		let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?;
-		self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap();
+		self.sender
+			.send((outgoing_kind, event, keys.into_iter().next().unwrap()))
+			.unwrap();
 
 		Ok(())
 	}
 
 	#[tracing::instrument(skip(self, room_id))]
 	pub fn flush_room(&self, room_id: &RoomId) -> Result<()> {
-		let servers: HashSet<OwnedServerName> =
-			services().rooms.state_cache.room_servers(room_id).filter_map(Result::ok).collect();
+		let servers: HashSet<OwnedServerName> = services()
+			.rooms
+			.state_cache
+			.room_servers(room_id)
+			.filter_map(Result::ok)
+			.collect();
 
 		self.flush_servers(servers.into_iter())
 	}
@@ -444,7 +484,9 @@ impl Service {
 			.collect::<Vec<_>>();
 
 		for outgoing_kind in requests {
-			self.sender.send((outgoing_kind, SendingEventType::Flush, Vec::<u8>::new())).unwrap();
+			self.sender
+				.send((outgoing_kind, SendingEventType::Flush, Vec::<u8>::new()))
+				.unwrap();
 		}
 
 		Ok(())
@@ -454,7 +496,8 @@ impl Service {
 	/// Used for instance after we remove an appservice registration
 	#[tracing::instrument(skip(self))]
 	pub fn cleanup_events(&self, appservice_id: String) -> Result<()> {
-		self.db.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?;
+		self.db
+			.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?;
 
 		Ok(())
 	}
@@ -497,12 +540,16 @@ impl Service {
 			let permit = services().sending.maximum_requests.acquire().await;
 
 			let response = match appservice_server::send_request(
-				services().appservice.get_registration(id).await.ok_or_else(|| {
-					(
-						kind.clone(),
-						Error::bad_database("[Appservice] Could not load registration from db."),
-					)
-				})?,
+				services()
+					.appservice
+					.get_registration(id)
+					.await
+					.ok_or_else(|| {
+						(
+							kind.clone(),
+							Error::bad_database("[Appservice] Could not load registration from db."),
+						)
+					})?,
 				appservice::event::push_events::v1::Request {
 					events: pdu_jsons,
 					txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
@@ -520,7 +567,9 @@ impl Service {
 			.await
 			{
 				None => Ok(kind.clone()),
-				Some(op_resp) => op_resp.map(|_response| kind.clone()).map_err(|e| (kind.clone(), e)),
+				Some(op_resp) => op_resp
+					.map(|_response| kind.clone())
+					.map_err(|e| (kind.clone(), e)),
 			};
 
 			drop(permit);
diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs
index f7c0ca74..8682b8db 100644
--- a/src/service/users/mod.rs
+++ b/src/service/users/mod.rs
@@ -44,7 +44,10 @@ impl Service {
 	pub fn exists(&self, user_id: &UserId) -> Result<bool> { self.db.exists(user_id) }
 
 	pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) {
-		self.connections.lock().unwrap().remove(&(user_id, device_id, conn_id));
+		self.connections
+			.lock()
+			.unwrap()
+			.remove(&(user_id, device_id, conn_id));
 	}
 
 	pub fn update_sync_request_with_cache(
@@ -55,14 +58,18 @@ impl Service {
 		};
 
 		let mut cache = self.connections.lock().unwrap();
-		let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| {
-			Arc::new(Mutex::new(SlidingSyncCache {
-				lists: BTreeMap::new(),
-				subscriptions: BTreeMap::new(),
-				known_rooms: BTreeMap::new(),
-				extensions: ExtensionsConfig::default(),
-			}))
-		}));
+		let cached = Arc::clone(
+			cache
+				.entry((user_id, device_id, conn_id))
+				.or_insert_with(|| {
+					Arc::new(Mutex::new(SlidingSyncCache {
+						lists: BTreeMap::new(),
+						subscriptions: BTreeMap::new(),
+						known_rooms: BTreeMap::new(),
+						extensions: ExtensionsConfig::default(),
+					}))
+				}),
+		);
 		let cached = &mut cached.lock().unwrap();
 		drop(cache);
 
@@ -72,12 +79,18 @@ impl Service {
 					list.sort.clone_from(&cached_list.sort);
 				};
 				if list.room_details.required_state.is_empty() {
-					list.room_details.required_state.clone_from(&cached_list.room_details.required_state);
+					list.room_details
+						.required_state
+						.clone_from(&cached_list.room_details.required_state);
 				};
-				list.room_details.timeline_limit =
-					list.room_details.timeline_limit.or(cached_list.room_details.timeline_limit);
-				list.include_old_rooms =
-					list.include_old_rooms.clone().or_else(|| cached_list.include_old_rooms.clone());
+				list.room_details.timeline_limit = list
+					.room_details
+					.timeline_limit
+					.or(cached_list.room_details.timeline_limit);
+				list.include_old_rooms = list
+					.include_old_rooms
+					.clone()
+					.or_else(|| cached_list.include_old_rooms.clone());
 				match (&mut list.filters, cached_list.filters.clone()) {
 					(Some(list_filters), Some(cached_filters)) => {
 						list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
@@ -92,8 +105,10 @@ impl Service {
 						if list_filters.not_room_types.is_empty() {
 							list_filters.not_room_types = cached_filters.not_room_types;
 						}
-						list_filters.room_name_like =
-							list_filters.room_name_like.clone().or(cached_filters.room_name_like);
+						list_filters.room_name_like = list_filters
+							.room_name_like
+							.clone()
+							.or(cached_filters.room_name_like);
 						if list_filters.tags.is_empty() {
 							list_filters.tags = cached_filters.tags;
 						}
@@ -106,26 +121,49 @@ impl Service {
 					(..) => {},
 				}
 				if list.bump_event_types.is_empty() {
-					list.bump_event_types.clone_from(&cached_list.bump_event_types);
+					list.bump_event_types
+						.clone_from(&cached_list.bump_event_types);
 				};
 			}
 			cached.lists.insert(list_id.clone(), list.clone());
 		}
 
-		cached.subscriptions.extend(request.room_subscriptions.clone());
-		request.room_subscriptions.extend(cached.subscriptions.clone());
+		cached
+			.subscriptions
+			.extend(request.room_subscriptions.clone());
+		request
+			.room_subscriptions
+			.extend(cached.subscriptions.clone());
 
-		request.extensions.e2ee.enabled = request.extensions.e2ee.enabled.or(cached.extensions.e2ee.enabled);
+		request.extensions.e2ee.enabled = request
+			.extensions
+			.e2ee
+			.enabled
+			.or(cached.extensions.e2ee.enabled);
 
-		request.extensions.to_device.enabled =
-			request.extensions.to_device.enabled.or(cached.extensions.to_device.enabled);
+		request.extensions.to_device.enabled = request
+			.extensions
+			.to_device
+			.enabled
+			.or(cached.extensions.to_device.enabled);
 
-		request.extensions.account_data.enabled =
-			request.extensions.account_data.enabled.or(cached.extensions.account_data.enabled);
-		request.extensions.account_data.lists =
-			request.extensions.account_data.lists.clone().or_else(|| cached.extensions.account_data.lists.clone());
-		request.extensions.account_data.rooms =
-			request.extensions.account_data.rooms.clone().or_else(|| cached.extensions.account_data.rooms.clone());
+		request.extensions.account_data.enabled = request
+			.extensions
+			.account_data
+			.enabled
+			.or(cached.extensions.account_data.enabled);
+		request.extensions.account_data.lists = request
+			.extensions
+			.account_data
+			.lists
+			.clone()
+			.or_else(|| cached.extensions.account_data.lists.clone());
+		request.extensions.account_data.rooms = request
+			.extensions
+			.account_data
+			.rooms
+			.clone()
+			.or_else(|| cached.extensions.account_data.rooms.clone());
 
 		cached.extensions = request.extensions.clone();
 
@@ -137,14 +175,18 @@ impl Service {
 		subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
 	) {
 		let mut cache = self.connections.lock().unwrap();
-		let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| {
-			Arc::new(Mutex::new(SlidingSyncCache {
-				lists: BTreeMap::new(),
-				subscriptions: BTreeMap::new(),
-				known_rooms: BTreeMap::new(),
-				extensions: ExtensionsConfig::default(),
-			}))
-		}));
+		let cached = Arc::clone(
+			cache
+				.entry((user_id, device_id, conn_id))
+				.or_insert_with(|| {
+					Arc::new(Mutex::new(SlidingSyncCache {
+						lists: BTreeMap::new(),
+						subscriptions: BTreeMap::new(),
+						known_rooms: BTreeMap::new(),
+						extensions: ExtensionsConfig::default(),
+					}))
+				}),
+		);
 		let cached = &mut cached.lock().unwrap();
 		drop(cache);
 
@@ -156,18 +198,27 @@ impl Service {
 		new_cached_rooms: BTreeSet<OwnedRoomId>, globalsince: u64,
 	) {
 		let mut cache = self.connections.lock().unwrap();
-		let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| {
-			Arc::new(Mutex::new(SlidingSyncCache {
-				lists: BTreeMap::new(),
-				subscriptions: BTreeMap::new(),
-				known_rooms: BTreeMap::new(),
-				extensions: ExtensionsConfig::default(),
-			}))
-		}));
+		let cached = Arc::clone(
+			cache
+				.entry((user_id, device_id, conn_id))
+				.or_insert_with(|| {
+					Arc::new(Mutex::new(SlidingSyncCache {
+						lists: BTreeMap::new(),
+						subscriptions: BTreeMap::new(),
+						known_rooms: BTreeMap::new(),
+						extensions: ExtensionsConfig::default(),
+					}))
+				}),
+		);
 		let cached = &mut cached.lock().unwrap();
 		drop(cache);
 
-		for (roomid, lastsince) in cached.known_rooms.entry(list_id.clone()).or_default().iter_mut() {
+		for (roomid, lastsince) in cached
+			.known_rooms
+			.entry(list_id.clone())
+			.or_default()
+			.iter_mut()
+		{
 			if !new_cached_rooms.contains(roomid) {
 				*lastsince = 0;
 			}
 		}
@@ -185,9 +236,16 @@ impl Service {
 	pub fn is_admin(&self, user_id: &UserId) -> Result<bool> {
 		let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name()))
 			.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
-		let admin_room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias_id)?.unwrap();
+		let admin_room_id = services()
+			.rooms
+			.alias
+			.resolve_local_alias(&admin_room_alias_id)?
+			.unwrap();
 
-		services().rooms.state_cache.is_joined(user_id, &admin_room_id)
+		services()
+			.rooms
+			.state_cache
+			.is_joined(user_id, &admin_room_id)
 	}
 
 	/// Create a new user account on this homeserver.
@@ -250,7 +308,8 @@ impl Service {
 	pub fn create_device(
 		&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
 	) -> Result<()> {
-		self.db.create_device(user_id, device_id, token, initial_device_display_name)
+		self.db
+			.create_device(user_id, device_id, token, initial_device_display_name)
 	}
 
 	/// Removes a device from a user.
@@ -272,7 +331,8 @@ impl Service {
 		&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
 		one_time_key_value: &Raw<OneTimeKey>,
 	) -> Result<()> {
-		self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
+		self.db
+			.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value)
 	}
 
 	pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
@@ -299,7 +359,8 @@ impl Service {
 		&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
 		user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
 	) -> Result<()> {
-		self.db.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify)
+		self.db
+			.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify)
 	}
 
 	pub fn sign_key(
@@ -329,19 +390,22 @@ impl Service {
 	pub fn get_key(
 		&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
 	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db.get_key(key, sender_user, user_id, allowed_signatures)
+		self.db
+			.get_key(key, sender_user, user_id, allowed_signatures)
 	}
 
 	pub fn get_master_key(
 		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
 	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db.get_master_key(sender_user, user_id, allowed_signatures)
+		self.db
+			.get_master_key(sender_user, user_id, allowed_signatures)
 	}
 
 	pub fn get_self_signing_key(
 		&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
 	) -> Result<Option<Raw<CrossSigningKey>>> {
-		self.db.get_self_signing_key(sender_user, user_id, allowed_signatures)
+		self.db
+			.get_self_signing_key(sender_user, user_id, allowed_signatures)
 	}
 
 	pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
@@ -352,7 +416,8 @@ impl Service {
 		&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
 		content: serde_json::Value,
 	) -> Result<()> {
-		self.db.add_to_device_event(sender, target_user_id, target_device_id, event_type, content)
+		self.db
+			.add_to_device_event(sender, target_user_id, target_device_id, event_type, content)
 	}
 
 	pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
@@ -411,7 +476,10 @@ impl Service {
 	pub fn clean_signatures<F: Fn(&UserId) -> bool>(
 		cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId,
 		allowed_signatures: F,
 	) -> Result<(), Error> {
-		if let Some(signatures) = cross_signing_key.get_mut("signatures").and_then(|v| v.as_object_mut()) {
+		if let Some(signatures) = cross_signing_key
+			.get_mut("signatures")
+			.and_then(|v| v.as_object_mut())
+		{
 			// Don't allocate for the full size of the current signatures, but require
 			// at most one resize if nothing is dropped
 			let new_capacity = signatures.len() / 2;
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index b19889ad..c4e1c7fc 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -15,7 +15,10 @@ use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonO
 use crate::{services, Error, Result};
 
 pub(crate) fn millis_since_unix_epoch() -> u64 {
-	SystemTime::now().duration_since(UNIX_EPOCH).expect("time is valid").as_millis() as u64
+	SystemTime::now()
+		.duration_since(UNIX_EPOCH)
+		.expect("time is valid")
+		.as_millis() as u64
 }
 
 pub(crate) fn increment(old: Option<&[u8]>) -> Vec<u8> {
@@ -59,13 +62,21 @@ pub fn user_id_from_bytes(bytes: &[u8]) -> Result<OwnedUserId> {
 }
 
 pub fn random_string(length: usize) -> String {
-	thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(length).map(char::from).collect()
+	thread_rng()
+		.sample_iter(&rand::distributions::Alphanumeric)
+		.take(length)
+		.map(char::from)
+		.collect()
 }
 
 /// Calculate a new hash for the given password
pub fn calculate_password_hash(password: &str) -> Result<String> {
 	let salt = SaltString::generate(thread_rng());
-	services().globals.argon.hash_password(password.as_bytes(), &salt).map(|it| it.to_string())
+	services()
+		.globals
+		.argon
+		.hash_password(password.as_bytes(), &salt)
+		.map(|it| it.to_string())
 }
 
 #[tracing::instrument(skip(keys))]