From 4ec2d3ecb5e1f008c52d139cc3340e2bd06ea00c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Tue, 5 Mar 2024 20:52:16 -0500 Subject: [PATCH] refactor: use async-aware RwLocks and Mutexes where possible squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595 Signed-off-by: strawberry --- src/api/client_server/keys.rs | 16 ++-- src/api/client_server/media.rs | 3 +- src/api/client_server/membership.rs | 48 ++++++------ src/api/client_server/message.rs | 6 +- src/api/client_server/profile.rs | 4 +- src/api/client_server/redact.rs | 2 +- src/api/client_server/room.rs | 6 +- src/api/client_server/state.rs | 2 +- src/api/client_server/sync.rs | 36 ++++----- src/api/server_server.rs | 15 ++-- src/lib.rs | 4 + src/service/admin/mod.rs | 23 +++--- src/service/globals/mod.rs | 20 ++--- src/service/media/mod.rs | 9 +-- src/service/mod.rs | 27 +++---- src/service/rooms/event_handler/mod.rs | 103 +++++++++++-------------- src/service/rooms/lazy_loading/mod.rs | 19 ++--- src/service/rooms/spaces/mod.rs | 11 +-- src/service/rooms/state/mod.rs | 2 +- src/service/rooms/timeline/mod.rs | 12 +-- 20 files changed, 174 insertions(+), 194 deletions(-) diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index e32d7a97..1f44f4c0 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -281,17 +281,19 @@ pub(crate) async fn get_keys_helper bool>( let mut failures = BTreeMap::new(); - let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + let back_off = |id| async { + match services().globals.bad_query_ratelimiter.write().await.entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } }; let mut futures: FuturesUnordered<_> = get_over_federation .into_iter() .map(|(server, vec)| async move { - if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) { + if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().await.get(server) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -354,7 +356,7 @@ pub(crate) async fn get_keys_helper bool>( device_keys.extend(response.device_keys); }, _ => { - back_off(server.to_owned()); + back_off(server.to_owned()).await; failures.insert(server.to_string(), json!({})); }, } diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 7d0cf7ab..aa8176d5 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -440,8 +440,7 @@ async fn get_url_preview(url: &str) -> Result { } // ensure that only one request is made per URL - let mutex_request = - Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default()); + let mutex_request = Arc::clone(services().media.url_preview_mutex.write().await.entry(url.to_owned()).or_default()); let _request_lock = mutex_request.lock().await; match services().media.get_url_preview(url).await { diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index e046df96..798575a8 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1,6 +1,6 @@ use std::{ collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, - sync::{Arc, RwLock}, + sync::Arc, time::{Duration, Instant}, }; @@ -29,6 +29,7 @@ use ruma::{ OwnedUserId, RoomId, RoomVersionId, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; use super::get_alias_helper; @@ -242,7 +243,7 @@ pub async fn kick_user_route(body: Ruma) -> Result) -> Result) -> Result t, Err(_) => continue, }; @@ -705,7 +706,7 @@ async fn join_room_by_id_helper( .iter() .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) { - let (event_id, value) = match result { + let (event_id, value) = match result.await { Ok(t) => t, Err(_) => continue, }; @@ -1048,7 +1049,7 @@ async fn make_join_request( make_join_response_and_server } -fn validate_and_add_event_id( +async fn validate_and_add_event_id( pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock>>, ) -> Result<(OwnedEventId, CanonicalJsonObject)> { let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { @@ -1061,14 +1062,16 @@ fn validate_and_add_event_id( )) .expect("ruma's reference hashes are valid event ids"); - let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + let back_off = |id| async { + match services().globals.bad_event_ratelimiter.write().await.entry(id) { + Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } }; - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&event_id) { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&event_id) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1081,13 +1084,9 @@ fn validate_and_add_event_id( } } - if let Err(e) = ruma::signatures::verify_event( - &*pub_key_map.read().map_err(|_| Error::bad_database("RwLock is poisoned."))?, - &value, - room_version, - ) { + if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) { warn!("Event {} failed verification {:?} {}", event_id, pdu, e); - back_off(event_id); + back_off(event_id).await; return Err(Error::BadServerResponse("Event failed verification.")); } @@ -1109,9 +1108,8 @@ pub(crate) async fn invite_helper( if user_id.server_name() != services().globals.server_name() { let (pdu, pdu_json, invite_room_state) = { - let mutex_state = Arc::clone( - services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default(), - ); + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); let state_lock = mutex_state.lock().await; let content = to_raw_value(&RoomMemberEventContent { @@ -1229,7 +1227,7 @@ pub(crate) async fn invite_helper( } let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); let state_lock = mutex_state.lock().await; services() @@ -1314,7 +1312,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option) -> Re for (pdu_builder, room_id) in all_joined_rooms { let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 674a67c6..44e42f08 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -18,7 +18,7 @@ pub async fn redact_event_route(body: Ruma) -> Result let body = body.body; let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; let event_id = services() diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index c0247d27..a6e667d7 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -111,7 +111,7 @@ pub async fn create_room_route(body: Ruma) -> Result = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| { @@ -610,7 +610,7 @@ pub async fn upgrade_room_route(body: Ruma) -> Result services().rooms.short.get_or_create_shortroomid(&replacement_room)?; let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; // Send a m.room.tombstone event to the old room to indicate that it is not @@ -640,7 +640,7 @@ pub async fn upgrade_room_route(body: Ruma) -> Result // Change lock to replacement room drop(state_lock); let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(replacement_room.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(replacement_room.clone()).or_default()); let state_lock = mutex_state.lock().await; // Get the old room creation event diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 6dfb0fcc..d18608a2 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -230,7 +230,7 @@ async fn send_state_event_for_key_helper( } let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); let state_lock = mutex_state.lock().await; let event_id = services() diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 28c2793d..38e1cd3d 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -82,7 +82,7 @@ pub async fn sync_events_route( let body = body.body; let mut rx = - match services().globals.sync_receivers.write().unwrap().entry((sender_user.clone(), sender_device.clone())) { + match services().globals.sync_receivers.write().await.entry((sender_user.clone(), sender_device.clone())) { Entry::Vacant(v) => { let (tx, rx) = tokio::sync::watch::channel(None); @@ -132,7 +132,7 @@ async fn sync_helper_wrapper( if let Ok((_, caching_allowed)) = r { if !caching_allowed { - match services().globals.sync_receivers.write().unwrap().entry((sender_user, sender_device)) { + match services().globals.sync_receivers.write().await.entry((sender_user, sender_device)) { Entry::Occupied(o) => { // Only remove if the device didn't start a different /sync already if o.get().0 == since { @@ -233,7 +233,7 @@ async fn sync_helper( { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; @@ -339,7 +339,7 @@ async fn sync_helper( { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; @@ -485,7 +485,7 @@ async fn load_joined_room( // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.to_owned()).or_default()); + Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.to_owned()).or_default()); let insert_lock = mutex_insert.lock().await; drop(insert_lock); }; @@ -500,7 +500,7 @@ async fn load_joined_room( timeline_users.insert(event.sender.as_str().to_owned()); } - services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount)?; + services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount).await?; // Database queries: @@ -653,13 +653,11 @@ async fn load_joined_room( // The state_events above should contain all timeline_users, let's mark them as // lazy loaded. - services().rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ); + services() + .rooms + .lazy_loading + .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) + .await; (heroes, joined_member_count, invited_member_count, true, state_events) } else { @@ -721,13 +719,11 @@ async fn load_joined_room( } } - services().rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ); + services() + .rooms + .lazy_loading + .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) + .await; let encrypted_room = services() .rooms diff --git a/src/api/server_server.rs b/src/api/server_server.rs index d33d3f2b..1413f55b 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -6,7 +6,7 @@ use std::{ fmt::Debug, mem, net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, + sync::Arc, time::{Duration, Instant, SystemTime}, }; @@ -50,6 +50,7 @@ use ruma::{ OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; +use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; use trust_dns_resolver::{error::ResolveError, lookup::SrvLookup}; @@ -157,7 +158,7 @@ where let mut write_destination_to_cache = false; - let cached_result = services().globals.actual_destination_cache.read().unwrap().get(destination).cloned(); + let cached_result = services().globals.actual_destination_cache.read().await.get(destination).cloned(); let (actual_destination, host) = if let Some(result) = cached_result { result @@ -276,7 +277,7 @@ where .globals .actual_destination_cache .write() - .unwrap() + .await .insert(OwnedServerName::from(destination), (actual_destination, host)); } @@ -291,7 +292,7 @@ where // well-knowns if !write_destination_to_cache { info!("Evicting {destination} from our true destination cache due to failed request."); - services().globals.actual_destination_cache.write().unwrap().remove(destination); + services().globals.actual_destination_cache.write().await.remove(destination); } Err(Error::FederationError( @@ -767,7 +768,7 @@ pub async fn send_transaction_message_route( for (event_id, value, room_id) in parsed_pdus { let mutex = - Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default()); let mutex_lock = mutex.lock().await; let start_time = Instant::now(); resolved_map.insert( @@ -1264,7 +1265,7 @@ pub async fn create_join_event_template_route( services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; // TODO: Conduit does not implement restricted join rules yet, we always reject @@ -1413,7 +1414,7 @@ async fn create_join_event( services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; let mutex = - Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.to_owned()).or_default()); + Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.to_owned()).or_default()); let mutex_lock = mutex.lock().await; let pdu_id: Vec = services() .rooms diff --git a/src/lib.rs b/src/lib.rs index 45b43e9d..39a09141 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,10 @@ mod database; mod service; mod utils; +// Not async due to services() being used in many closures, and async closures +// are not stable as of writing This is the case for every other occurence of +// sync Mutex/RwLock, except for database related ones, where the current +// maintainer (Timo) has asked to not modify those use std::sync::RwLock; pub use api::ruma_wrapper::{Ruma, RumaResponse}; diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 10641c34..fec27e1e 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -1,9 +1,4 @@ -use std::{ - collections::BTreeMap, - fmt::Write as _, - sync::{Arc, RwLock}, - time::Instant, -}; +use std::{collections::BTreeMap, fmt::Write as _, sync::Arc, time::Instant}; use clap::{Parser, Subcommand}; use regex::Regex; @@ -29,7 +24,7 @@ use ruma::{ ServerName, UserId, }; use serde_json::value::to_raw_value; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, Mutex, RwLock}; use tracing::{debug, error, info, warn}; use super::pdu::PduBuilder; @@ -473,7 +468,7 @@ impl Service { services().globals .roomid_mutex_state .write() - .unwrap() + .await .entry(conduit_room.clone()) .or_default(), ); @@ -1773,7 +1768,7 @@ impl Service { RoomMessageEventContent::text_plain("Room enabled.") }, FederationCommand::IncomingFederation => { - let map = services().globals.roomid_federationhandletime.read().unwrap(); + let map = services().globals.roomid_federationhandletime.read().await; let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); for (r, (e, i)) in map.iter() { @@ -1818,7 +1813,7 @@ impl Service { .fetch_required_signing_keys([&value], &pub_key_map) .await?; - let pub_key_map = pub_key_map.read().unwrap(); + let pub_key_map = pub_key_map.read().await; match ruma::signatures::verify_json(&pub_key_map, &value) { Ok(_) => RoomMessageEventContent::text_plain("Signature correct"), Err(e) => RoomMessageEventContent::text_plain(format!( @@ -1841,7 +1836,7 @@ impl Service { RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) }, ServerCommand::MemoryUsage => { - let response1 = services().memory_usage(); + let response1 = services().memory_usage().await; let response2 = services().globals.db.memory_usage(); RoomMessageEventContent::text_plain(format!("Services:\n{response1}\n\nDatabase:\n{response2}")) @@ -1856,7 +1851,7 @@ impl Service { ServerCommand::ClearServiceCaches { amount, } => { - services().clear_caches(amount); + services().clear_caches(amount).await; RoomMessageEventContent::text_plain("Done.") }, @@ -2047,7 +2042,7 @@ impl Service { services().rooms.short.get_or_create_shortroomid(&room_id)?; let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; // Create a user for the server @@ -2291,7 +2286,7 @@ impl Service { let room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias)?.expect("Admin room must exist"); let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default()); let state_lock = mutex_state.lock().await; // Use the server user to grant the new admin's power level diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 1f88cbc3..c0b2db3c 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -8,7 +8,7 @@ use std::{ path::PathBuf, sync::{ atomic::{self, AtomicBool}, - Arc, Mutex, RwLock, + Arc, RwLock as StdRwLock, }, time::{Duration, Instant}, }; @@ -33,7 +33,7 @@ use ruma::{ RoomVersionId, ServerName, UserId, }; use sha2::Digest; -use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; +use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore}; use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; @@ -53,7 +53,7 @@ pub struct Service<'a> { pub db: &'static dyn Data, pub actual_destination_cache: Arc>, // actual_destination, host - pub tls_name_override: Arc>, + pub tls_name_override: Arc>, pub config: Config, keypair: Arc, dns_resolver: TokioAsyncResolver, @@ -68,9 +68,9 @@ pub struct Service<'a> { pub bad_query_ratelimiter: Arc>>, pub servername_ratelimiter: Arc>>>, pub sync_receivers: RwLock>, - pub roomid_mutex_insert: RwLock>>>, - pub roomid_mutex_state: RwLock>>>, - pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub roomid_mutex_insert: RwLock>>>, + pub roomid_mutex_state: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer pub roomid_federationhandletime: RwLock>, pub stateres_mutex: Arc>, pub(crate) rotate: RotationHandler, @@ -109,11 +109,11 @@ impl Default for RotationHandler { struct Resolver { inner: GaiResolver, - overrides: Arc>, + overrides: Arc>, } impl Resolver { - fn new(overrides: Arc>) -> Self { + fn new(overrides: Arc>) -> Self { Resolver { inner: GaiResolver::new(), overrides, @@ -125,7 +125,7 @@ impl Resolve for Resolver { fn resolve(&self, name: Name) -> Resolving { self.overrides .read() - .expect("lock should not be poisoned") + .unwrap() .get(name.as_str()) .and_then(|(override_name, port)| { override_name.first().map(|first_name| { @@ -159,7 +159,7 @@ impl Service<'_> { }, }; - let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); + let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new())); let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 8b3b002a..97652a35 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,10 +1,5 @@ mod data; -use std::{ - collections::HashMap, - io::Cursor, - sync::{Arc, RwLock}, - time::SystemTime, -}; +use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime}; pub(crate) use data::Data; use image::imageops::FilterType; @@ -13,7 +8,7 @@ use serde::Serialize; use tokio::{ fs::{self, File}, io::{AsyncReadExt, AsyncWriteExt, BufReader}, - sync::Mutex, + sync::{Mutex, RwLock}, }; use tracing::{debug, error}; diff --git a/src/service/mod.rs b/src/service/mod.rs index 9bee19fd..117af89a 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,9 +1,10 @@ use std::{ collections::{BTreeMap, HashMap}, - sync::{Arc, Mutex, RwLock}, + sync::{Arc, Mutex as StdMutex}, }; use lru_cache::LruCache; +use tokio::sync::{Mutex, RwLock}; use crate::{Config, Result}; @@ -106,10 +107,10 @@ impl Services<'_> { }, state_accessor: rooms::state_accessor::Service { db, - server_visibility_cache: Mutex::new(LruCache::new( + server_visibility_cache: StdMutex::new(LruCache::new( (100.0 * config.conduit_cache_capacity_modifier) as usize, )), - user_visibility_cache: Mutex::new(LruCache::new( + user_visibility_cache: StdMutex::new(LruCache::new( (100.0 * config.conduit_cache_capacity_modifier) as usize, )), }, @@ -118,7 +119,7 @@ impl Services<'_> { }, state_compressor: rooms::state_compressor::Service { db, - stateinfo_cache: Mutex::new(LruCache::new( + stateinfo_cache: StdMutex::new(LruCache::new( (100.0 * config.conduit_cache_capacity_modifier) as usize, )), }, @@ -146,7 +147,7 @@ impl Services<'_> { }, users: users::Service { db, - connections: Mutex::new(BTreeMap::new()), + connections: StdMutex::new(BTreeMap::new()), }, account_data: account_data::Service { db, @@ -165,13 +166,13 @@ impl Services<'_> { }) } - fn memory_usage(&self) -> String { - let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len(); + async fn memory_usage(&self) -> String { + let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len(); let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len(); let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len(); let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); - let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len(); - let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len(); + let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len(); + let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len(); format!( "\ @@ -184,9 +185,9 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}" ) } - fn clear_caches(&self, amount: u32) { + async fn clear_caches(&self, amount: u32) { if amount > 0 { - self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear(); + self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear(); } if amount > 1 { self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear(); @@ -198,10 +199,10 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}" self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear(); } if amount > 4 { - self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear(); + self.rooms.timeline.lasttimelinecount_cache.lock().await.clear(); } if amount > 5 { - self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear(); + self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear(); } } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 21864e12..ae2fd74d 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -4,7 +4,6 @@ type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; use std::{ collections::{hash_map, HashSet}, pin::Pin, - sync::RwLockWriteGuard, time::{Duration, Instant, SystemTime}, }; @@ -33,12 +32,12 @@ use ruma::{ OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::Semaphore; +use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore}; use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; use crate::{ - service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock}, + service::{pdu, Arc, BTreeMap, HashMap, Result}, services, Error, PduEvent, }; @@ -168,7 +167,7 @@ impl Service { )); } - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -183,7 +182,7 @@ impl Service { if errors >= 5 { // Timeout other events - match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) { + match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -205,7 +204,7 @@ impl Service { .globals .roomid_federationhandletime .write() - .unwrap() + .await .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); if let Err(e) = @@ -213,7 +212,7 @@ impl Service { { errors += 1; warn!("Prev event {} failed: {}", prev_id, e); - match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) { + match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); }, @@ -223,7 +222,7 @@ impl Service { } } let elapsed = start_time.elapsed(); - services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned()); + services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); debug!( "Handling prev event {} took {}m{}s", prev_id, @@ -240,14 +239,14 @@ impl Service { .globals .roomid_federationhandletime .write() - .unwrap() + .await .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); let r = services() .rooms .event_handler .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) .await; - services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned()); + services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned()); r } @@ -275,11 +274,8 @@ impl Service { let room_version_id = &create_event_content.room_version; let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - let mut val = match ruma::signatures::verify_event( - &pub_key_map.read().expect("RwLock is poisoned."), - &value, - room_version_id, - ) { + let guard = pub_key_map.read().await; + let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) { Err(e) => { // Drop warn!("Dropping bad event {}: {}", event_id, e,); @@ -306,6 +302,8 @@ impl Service { Ok(ruma::signatures::Verified::All) => value, }; + drop(guard); + // Now that we have checked the signature and hashes we can add the eventID and // convert to our PduEvent type val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); @@ -580,10 +578,13 @@ impl Service { { Ok(res) => { debug!("Fetching state events at event."); + + let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::>(); + let state_vec = self .fetch_and_handle_outliers( origin, - &res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::>(), + &collect, create_event, room_id, room_version_id, @@ -680,7 +681,7 @@ impl Service { // We start looking at current room state now, so lets lock the room let mutex_state = - Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default()); let state_lock = mutex_state.lock().await; // Now we calculate the set of extremities this room has after the incoming @@ -876,11 +877,13 @@ impl Service { room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveCanonicalJsonVec<'a> { Box::pin(async move { - let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + let back_off = |id| async { + match services().globals.bad_event_ratelimiter.write().await.entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } }; let mut events_with_auth_events = vec![]; @@ -902,8 +905,7 @@ impl Service { let mut events_all = HashSet::new(); let mut i = 0; while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id) - { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1009,9 +1011,7 @@ impl Service { pdus.push((local_pdu, None)); } for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = - services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id) - { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1205,10 +1205,7 @@ impl Service { while let Some(fetch_res) = server_keys.next().await { match fetch_res { Ok((signature_server, keys)) => { - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(signature_server.clone(), keys); + pub_key_map.write().await.insert(signature_server.clone(), keys); }, Err((signature_server, e)) => { warn!("Failed to fetch keys for {}: {:?}", signature_server, e); @@ -1222,7 +1219,7 @@ impl Service { // Gets a list of servers for which we don't have the signing key yet. We go // over the PDUs and either cache the key or add it to the list that needs to be // retrieved. - fn get_server_keys_from_cache( + async fn get_server_keys_from_cache( &self, pdu: &RawJsonValue, servers: &mut BTreeMap>, room_version: &RoomVersionId, @@ -1239,7 +1236,7 @@ impl Service { ); let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1310,7 +1307,7 @@ impl Service { .await { debug!("Got signing keys: {:?}", keys); - let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?; + let mut pkm = pub_key_map.write().await; for k in keys.server_keys { let k = match k.deserialize() { Ok(key) => key, @@ -1365,10 +1362,7 @@ impl Service { .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect(); - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(origin.to_string(), result); + pub_key_map.write().await.insert(origin.to_string(), result); } } debug!("Done handling Future result"); @@ -1384,15 +1378,15 @@ impl Service { let mut servers: BTreeMap> = BTreeMap::new(); { - let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?; + let mut pkm = pub_key_map.write().await; // Try to fetch keys, failure is okay // Servers we couldn't find in the cache will be added to `servers` for pdu in &event.room_state.state { - let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; } for pdu in &event.room_state.auth_chain { - let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await; } drop(pkm); @@ -1491,18 +1485,13 @@ impl Service { ) -> Result> { let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - let permit = services() - .globals - .servername_ratelimiter - .read() - .unwrap() - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); + let permit = + services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned()); let permit = match permit { Some(p) => p, None => { - let mut write = services().globals.servername_ratelimiter.write().unwrap(); + let mut write = services().globals.servername_ratelimiter.write().await; let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1)))); s.acquire_owned() @@ -1510,14 +1499,16 @@ impl Service { } .await; - let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + let back_off = |id| async { + match services().globals.bad_signature_ratelimiter.write().await.entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } }; - if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) { + if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) { // Exponential backoff let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { @@ -1591,7 +1582,7 @@ impl Service { drop(permit); - back_off(signature_ids); + back_off(signature_ids).await; warn!("Failed to find public key for server: {}", origin); Err(Error::BadServerResponse("Failed to find public key for server")) diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index b925fc0f..c94eb92b 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,21 +1,18 @@ mod data; -use std::{ - collections::{HashMap, HashSet}, - sync::Mutex, -}; +use std::collections::{HashMap, HashSet}; pub use data::Data; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use tokio::sync::Mutex; use super::timeline::PduCount; use crate::Result; -type LazyLoadWaitingMutex = Mutex>>; - pub struct Service { pub db: &'static dyn Data, - pub lazy_load_waiting: LazyLoadWaitingMutex, + #[allow(clippy::type_complexity)] + pub lazy_load_waiting: Mutex>>, } impl Service { @@ -27,21 +24,21 @@ impl Service { } #[tracing::instrument(skip(self))] - pub fn lazy_load_mark_sent( + pub async fn lazy_load_mark_sent( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, count: PduCount, ) { self.lazy_load_waiting .lock() - .unwrap() + .await .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); } #[tracing::instrument(skip(self))] - pub fn lazy_load_confirm_delivery( + pub async fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( + if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&( user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 2021df07..51187f59 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use lru_cache::LruCache; use ruma::{ @@ -25,6 +25,7 @@ use ruma::{ space::SpaceRoomJoinRule, OwnedRoomId, RoomId, UserId, }; +use tokio::sync::Mutex; use tracing::{debug, error, warn}; use crate::{services, Error, PduEvent, Result}; @@ -70,7 +71,7 @@ impl Service { break; } - if let Some(cached) = self.roomid_spacechunk_cache.lock().unwrap().get_mut(¤t_room.clone()).as_ref() { + if let Some(cached) = self.roomid_spacechunk_cache.lock().await.get_mut(¤t_room.clone()).as_ref() { if let Some(cached) = cached { let allowed = match &cached.join_rule { //CachedJoinRule::Simplified(s) => { @@ -148,7 +149,7 @@ impl Service { .transpose()? .unwrap_or(JoinRule::Invite); - self.roomid_spacechunk_cache.lock().unwrap().insert( + self.roomid_spacechunk_cache.lock().await.insert( current_room.clone(), Some(CachedSpaceChunk { chunk, @@ -223,7 +224,7 @@ impl Service { } } - self.roomid_spacechunk_cache.lock().unwrap().insert( + self.roomid_spacechunk_cache.lock().await.insert( current_room.clone(), Some(CachedSpaceChunk { chunk, @@ -245,7 +246,7 @@ impl Service { } */ } else { - self.roomid_spacechunk_cache.lock().unwrap().insert(current_room.clone(), None); + self.roomid_spacechunk_cache.lock().await.insert(current_room.clone(), None); } } } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index f7024fa9..3273d9c2 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -74,7 +74,7 @@ impl Service { .await?; }, TimelineEventType::SpaceChild => { - services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id); + services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id); }, _ => continue, } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 0a913e14..95063371 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -3,7 +3,7 @@ pub(crate) mod data; use std::{ cmp::Ordering, collections::{BTreeMap, HashMap, HashSet}, - sync::{Arc, Mutex, RwLock}, + sync::Arc, }; pub use data::Data; @@ -30,7 +30,7 @@ use ruma::{ }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tokio::sync::MutexGuard; +use tokio::sync::{Mutex, MutexGuard, RwLock}; use tracing::{error, info, warn}; use super::state_compressor::CompressedStateEvent; @@ -239,7 +239,7 @@ impl Service { services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?; let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(pdu.room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(pdu.room_id.clone()).or_default()); let insert_lock = mutex_insert.lock().await; let count1 = services().globals.next_count()?; @@ -398,7 +398,7 @@ impl Service { }, TimelineEventType::SpaceChild => { if let Some(_state_key) = &pdu.state_key { - services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id); + services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id); } }, TimelineEventType::RoomMember => { @@ -997,7 +997,7 @@ impl Service { // Lock so we cannot backfill the same pdu twice at the same time let mutex = - Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default()); let mutex_lock = mutex.lock().await; // Skip the PDU if we already have it as a timeline event @@ -1020,7 +1020,7 @@ impl Service { let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists"); let mutex_insert = - Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default()); let insert_lock = mutex_insert.lock().await; let count = services().globals.next_count()?;