refactor: use async-aware RwLocks and Mutexes where possible

squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Matthias Ahouansou 2024-03-05 20:52:16 -05:00 committed by June
parent 46b543eebe
commit 4ec2d3ecb5
20 changed files with 174 additions and 194 deletions

View file

@ -281,17 +281,19 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
let mut failures = BTreeMap::new();
let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) {
let back_off = |id| async {
match services().globals.bad_query_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter()
.map(|(server, vec)| async move {
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) {
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().await.get(server) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -354,7 +356,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
device_keys.extend(response.device_keys);
},
_ => {
back_off(server.to_owned());
back_off(server.to_owned()).await;
failures.insert(server.to_string(), json!({}));
},
}

View file

@ -440,8 +440,7 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
}
// ensure that only one request is made per URL
let mutex_request =
Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default());
let mutex_request = Arc::clone(services().media.url_preview_mutex.write().await.entry(url.to_owned()).or_default());
let _request_lock = mutex_request.lock().await;
match services().media.get_url_preview(url).await {

View file

@ -1,6 +1,6 @@
use std::{
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
sync::{Arc, RwLock},
sync::Arc,
time::{Duration, Instant},
};
@ -29,6 +29,7 @@ use ruma::{
OwnedUserId, RoomId, RoomVersionId, UserId,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use super::get_alias_helper;
@ -242,7 +243,7 @@ pub async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Result<kick_
event.reason = body.reason.clone();
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -303,7 +304,7 @@ pub async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_use
)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -349,7 +350,7 @@ pub async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unb
event.reason = body.reason.clone();
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -480,7 +481,7 @@ async fn join_room_by_id_helper(
let sender_user = sender_user.expect("user is authenticated");
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
// Ask a remote server if we are not participating in this room
@ -680,7 +681,7 @@ async fn join_room_by_id_helper(
.iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map))
{
let (event_id, value) = match result {
let (event_id, value) = match result.await {
Ok(t) => t,
Err(_) => continue,
};
@ -705,7 +706,7 @@ async fn join_room_by_id_helper(
.iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map))
{
let (event_id, value) = match result {
let (event_id, value) = match result.await {
Ok(t) => t,
Err(_) => continue,
};
@ -1048,7 +1049,7 @@ async fn make_join_request(
make_join_response_and_server
}
fn validate_and_add_event_id(
async fn validate_and_add_event_id(
pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
@ -1061,14 +1062,16 @@ fn validate_and_add_event_id(
))
.expect("ruma's reference hashes are valid event ids");
let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) {
Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&event_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&event_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1081,13 +1084,9 @@ fn validate_and_add_event_id(
}
}
if let Err(e) = ruma::signatures::verify_event(
&*pub_key_map.read().map_err(|_| Error::bad_database("RwLock is poisoned."))?,
&value,
room_version,
) {
if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) {
warn!("Event {} failed verification {:?} {}", event_id, pdu, e);
back_off(event_id);
back_off(event_id).await;
return Err(Error::BadServerResponse("Event failed verification."));
}
@ -1109,9 +1108,8 @@ pub(crate) async fn invite_helper(
if user_id.server_name() != services().globals.server_name() {
let (pdu, pdu_json, invite_room_state) = {
let mutex_state = Arc::clone(
services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default(),
);
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
let content = to_raw_value(&RoomMemberEventContent {
@ -1229,7 +1227,7 @@ pub(crate) async fn invite_helper(
}
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -1314,7 +1312,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
.await?;
} else {
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
let member_event =

View file

@ -33,7 +33,7 @@ pub async fn send_message_event_route(
let sender_device = body.sender_device.as_deref();
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
// Forbid m.room.encrypted if encryption is disabled
@ -160,7 +160,7 @@ pub async fn get_message_events_route(
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from).await?;
let limit = u64::from(body.limit).min(100) as usize;
@ -276,7 +276,7 @@ pub async fn get_message_events_route(
&body.room_id,
lazy_loaded,
next_token,
);
).await;
}
*/

View file

@ -65,7 +65,7 @@ pub async fn set_displayname_route(
for (pdu_builder, room_id) in all_rooms_joined {
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
@ -176,7 +176,7 @@ pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Re
for (pdu_builder, room_id) in all_joined_rooms {
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;

View file

@ -18,7 +18,7 @@ pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result
let body = body.body;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
let event_id = services()

View file

@ -111,7 +111,7 @@ pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<c
services().rooms.short.get_or_create_shortroomid(&room_id)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
let alias: Option<OwnedRoomAliasId> = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| {
@ -610,7 +610,7 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
// Send a m.room.tombstone event to the old room to indicate that it is not
@ -640,7 +640,7 @@ pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result
// Change lock to replacement room
drop(state_lock);
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(replacement_room.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(replacement_room.clone()).or_default());
let state_lock = mutex_state.lock().await;
// Get the old room creation event

View file

@ -230,7 +230,7 @@ async fn send_state_event_for_key_helper(
}
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
let event_id = services()

View file

@ -82,7 +82,7 @@ pub async fn sync_events_route(
let body = body.body;
let mut rx =
match services().globals.sync_receivers.write().unwrap().entry((sender_user.clone(), sender_device.clone())) {
match services().globals.sync_receivers.write().await.entry((sender_user.clone(), sender_device.clone())) {
Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(None);
@ -132,7 +132,7 @@ async fn sync_helper_wrapper(
if let Ok((_, caching_allowed)) = r {
if !caching_allowed {
match services().globals.sync_receivers.write().unwrap().entry((sender_user, sender_device)) {
match services().globals.sync_receivers.write().await.entry((sender_user, sender_device)) {
Entry::Occupied(o) => {
// Only remove if the device didn't start a different /sync already
if o.get().0 == since {
@ -233,7 +233,7 @@ async fn sync_helper(
{
// Get and drop the lock to wait for remaining operations to finish
let mutex_insert =
Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default());
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
};
@ -339,7 +339,7 @@ async fn sync_helper(
{
// Get and drop the lock to wait for remaining operations to finish
let mutex_insert =
Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default());
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
};
@ -485,7 +485,7 @@ async fn load_joined_room(
// Get and drop the lock to wait for remaining operations to finish
// This will make sure the we have all events until next_batch
let mutex_insert =
Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.to_owned()).or_default());
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
};
@ -500,7 +500,7 @@ async fn load_joined_room(
timeline_users.insert(event.sender.as_str().to_owned());
}
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount)?;
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount).await?;
// Database queries:
@ -653,13 +653,11 @@ async fn load_joined_room(
// The state_events above should contain all timeline_users, let's mark them as
// lazy loaded.
services().rooms.lazy_loading.lazy_load_mark_sent(
sender_user,
sender_device,
room_id,
lazy_loaded,
next_batchcount,
);
services()
.rooms
.lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount)
.await;
(heroes, joined_member_count, invited_member_count, true, state_events)
} else {
@ -721,13 +719,11 @@ async fn load_joined_room(
}
}
services().rooms.lazy_loading.lazy_load_mark_sent(
sender_user,
sender_device,
room_id,
lazy_loaded,
next_batchcount,
);
services()
.rooms
.lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount)
.await;
let encrypted_room = services()
.rooms

View file

@ -6,7 +6,7 @@ use std::{
fmt::Debug,
mem,
net::{IpAddr, SocketAddr},
sync::{Arc, RwLock},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
@ -50,6 +50,7 @@ use ruma::{
OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use trust_dns_resolver::{error::ResolveError, lookup::SrvLookup};
@ -157,7 +158,7 @@ where
let mut write_destination_to_cache = false;
let cached_result = services().globals.actual_destination_cache.read().unwrap().get(destination).cloned();
let cached_result = services().globals.actual_destination_cache.read().await.get(destination).cloned();
let (actual_destination, host) = if let Some(result) = cached_result {
result
@ -276,7 +277,7 @@ where
.globals
.actual_destination_cache
.write()
.unwrap()
.await
.insert(OwnedServerName::from(destination), (actual_destination, host));
}
@ -291,7 +292,7 @@ where
// well-knowns
if !write_destination_to_cache {
info!("Evicting {destination} from our true destination cache due to failed request.");
services().globals.actual_destination_cache.write().unwrap().remove(destination);
services().globals.actual_destination_cache.write().await.remove(destination);
}
Err(Error::FederationError(
@ -767,7 +768,7 @@ pub async fn send_transaction_message_route(
for (event_id, value, room_id) in parsed_pdus {
let mutex =
Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default());
let mutex_lock = mutex.lock().await;
let start_time = Instant::now();
resolved_map.insert(
@ -1264,7 +1265,7 @@ pub async fn create_join_event_template_route(
services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
// TODO: Conduit does not implement restricted join rules yet, we always reject
@ -1413,7 +1414,7 @@ async fn create_join_event(
services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?;
let mutex =
Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.to_owned()).or_default());
let mutex_lock = mutex.lock().await;
let pdu_id: Vec<u8> = services()
.rooms

View file

@ -4,6 +4,10 @@ mod database;
mod service;
mod utils;
// Not async due to services() being used in many closures, and async closures
// are not stable as of writing This is the case for every other occurence of
// sync Mutex/RwLock, except for database related ones, where the current
// maintainer (Timo) has asked to not modify those
use std::sync::RwLock;
pub use api::ruma_wrapper::{Ruma, RumaResponse};

View file

@ -1,9 +1,4 @@
use std::{
collections::BTreeMap,
fmt::Write as _,
sync::{Arc, RwLock},
time::Instant,
};
use std::{collections::BTreeMap, fmt::Write as _, sync::Arc, time::Instant};
use clap::{Parser, Subcommand};
use regex::Regex;
@ -29,7 +24,7 @@ use ruma::{
ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex};
use tokio::sync::{mpsc, Mutex, RwLock};
use tracing::{debug, error, info, warn};
use super::pdu::PduBuilder;
@ -473,7 +468,7 @@ impl Service {
services().globals
.roomid_mutex_state
.write()
.unwrap()
.await
.entry(conduit_room.clone())
.or_default(),
);
@ -1773,7 +1768,7 @@ impl Service {
RoomMessageEventContent::text_plain("Room enabled.")
},
FederationCommand::IncomingFederation => {
let map = services().globals.roomid_federationhandletime.read().unwrap();
let map = services().globals.roomid_federationhandletime.read().await;
let mut msg: String = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() {
@ -1818,7 +1813,7 @@ impl Service {
.fetch_required_signing_keys([&value], &pub_key_map)
.await?;
let pub_key_map = pub_key_map.read().unwrap();
let pub_key_map = pub_key_map.read().await;
match ruma::signatures::verify_json(&pub_key_map, &value) {
Ok(_) => RoomMessageEventContent::text_plain("Signature correct"),
Err(e) => RoomMessageEventContent::text_plain(format!(
@ -1841,7 +1836,7 @@ impl Service {
RoomMessageEventContent::text_plain(format!("{}", services().globals.config))
},
ServerCommand::MemoryUsage => {
let response1 = services().memory_usage();
let response1 = services().memory_usage().await;
let response2 = services().globals.db.memory_usage();
RoomMessageEventContent::text_plain(format!("Services:\n{response1}\n\nDatabase:\n{response2}"))
@ -1856,7 +1851,7 @@ impl Service {
ServerCommand::ClearServiceCaches {
amount,
} => {
services().clear_caches(amount);
services().clear_caches(amount).await;
RoomMessageEventContent::text_plain("Done.")
},
@ -2047,7 +2042,7 @@ impl Service {
services().rooms.short.get_or_create_shortroomid(&room_id)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
// Create a user for the server
@ -2291,7 +2286,7 @@ impl Service {
let room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias)?.expect("Admin room must exist");
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
// Use the server user to grant the new admin's power level

View file

@ -8,7 +8,7 @@ use std::{
path::PathBuf,
sync::{
atomic::{self, AtomicBool},
Arc, Mutex, RwLock,
Arc, RwLock as StdRwLock,
},
time::{Duration, Instant},
};
@ -33,7 +33,7 @@ use ruma::{
RoomVersionId, ServerName, UserId,
};
use sha2::Digest;
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
use tracing::{error, info};
use trust_dns_resolver::TokioAsyncResolver;
@ -53,7 +53,7 @@ pub struct Service<'a> {
pub db: &'static dyn Data,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
pub tls_name_override: Arc<StdRwLock<TlsNameMap>>,
pub config: Config,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
dns_resolver: TokioAsyncResolver,
@ -68,9 +68,9 @@ pub struct Service<'a> {
pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer
pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub stateres_mutex: Arc<Mutex<()>>,
pub(crate) rotate: RotationHandler,
@ -109,11 +109,11 @@ impl Default for RotationHandler {
struct Resolver {
inner: GaiResolver,
overrides: Arc<RwLock<TlsNameMap>>,
overrides: Arc<StdRwLock<TlsNameMap>>,
}
impl Resolver {
fn new(overrides: Arc<RwLock<TlsNameMap>>) -> Self {
fn new(overrides: Arc<StdRwLock<TlsNameMap>>) -> Self {
Resolver {
inner: GaiResolver::new(),
overrides,
@ -125,7 +125,7 @@ impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving {
self.overrides
.read()
.expect("lock should not be poisoned")
.unwrap()
.get(name.as_str())
.and_then(|(override_name, port)| {
override_name.first().map(|first_name| {
@ -159,7 +159,7 @@ impl Service<'_> {
},
};
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new()));
let jwt_decoding_key =
config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));

View file

@ -1,10 +1,5 @@
mod data;
use std::{
collections::HashMap,
io::Cursor,
sync::{Arc, RwLock},
time::SystemTime,
};
use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime};
pub(crate) use data::Data;
use image::imageops::FilterType;
@ -13,7 +8,7 @@ use serde::Serialize;
use tokio::{
fs::{self, File},
io::{AsyncReadExt, AsyncWriteExt, BufReader},
sync::Mutex,
sync::{Mutex, RwLock},
};
use tracing::{debug, error};

View file

@ -1,9 +1,10 @@
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, Mutex, RwLock},
sync::{Arc, Mutex as StdMutex},
};
use lru_cache::LruCache;
use tokio::sync::{Mutex, RwLock};
use crate::{Config, Result};
@ -106,10 +107,10 @@ impl Services<'_> {
},
state_accessor: rooms::state_accessor::Service {
db,
server_visibility_cache: Mutex::new(LruCache::new(
server_visibility_cache: StdMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
user_visibility_cache: Mutex::new(LruCache::new(
user_visibility_cache: StdMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
@ -118,7 +119,7 @@ impl Services<'_> {
},
state_compressor: rooms::state_compressor::Service {
db,
stateinfo_cache: Mutex::new(LruCache::new(
stateinfo_cache: StdMutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
},
@ -146,7 +147,7 @@ impl Services<'_> {
},
users: users::Service {
db,
connections: Mutex::new(BTreeMap::new()),
connections: StdMutex::new(BTreeMap::new()),
},
account_data: account_data::Service {
db,
@ -165,13 +166,13 @@ impl Services<'_> {
})
}
fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
async fn memory_usage(&self) -> String {
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len();
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();
format!(
"\
@ -184,9 +185,9 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
)
}
fn clear_caches(&self, amount: u32) {
async fn clear_caches(&self, amount: u32) {
if amount > 0 {
self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear();
}
if amount > 1 {
self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
@ -198,10 +199,10 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
}
if amount > 4 {
self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
self.rooms.timeline.lasttimelinecount_cache.lock().await.clear();
}
if amount > 5 {
self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear();
}
}
}

View file

@ -4,7 +4,6 @@ type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use std::{
collections::{hash_map, HashSet},
pin::Pin,
sync::RwLockWriteGuard,
time::{Duration, Instant, SystemTime},
};
@ -33,12 +32,12 @@ use ruma::{
OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::Semaphore;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
use super::state_compressor::CompressedStateEvent;
use crate::{
service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock},
service::{pdu, Arc, BTreeMap, HashMap, Result},
services, Error, PduEvent,
};
@ -168,7 +167,7 @@ impl Service {
));
}
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -183,7 +182,7 @@ impl Service {
if errors >= 5 {
// Timeout other events
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
@ -205,7 +204,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) =
@ -213,7 +212,7 @@ impl Service {
{
errors += 1;
warn!("Prev event {} failed: {}", prev_id, e);
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
@ -223,7 +222,7 @@ impl Service {
}
}
let elapsed = start_time.elapsed();
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
debug!(
"Handling prev event {} took {}m{}s",
prev_id,
@ -240,14 +239,14 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = services()
.rooms
.event_handler
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
.await;
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
r
}
@ -275,11 +274,8 @@ impl Service {
let room_version_id = &create_event_content.room_version;
let room_version = RoomVersion::new(room_version_id).expect("room version is supported");
let mut val = match ruma::signatures::verify_event(
&pub_key_map.read().expect("RwLock is poisoned."),
&value,
room_version_id,
) {
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e,);
@ -306,6 +302,8 @@ impl Service {
Ok(ruma::signatures::Verified::All) => value,
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
@ -580,10 +578,13 @@ impl Service {
{
Ok(res) => {
debug!("Fetching state events at event.");
let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(
origin,
&res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(),
&collect,
create_event,
room_id,
room_version_id,
@ -680,7 +681,7 @@ impl Service {
// We start looking at current room state now, so lets lock the room
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
// Now we calculate the set of extremities this room has after the incoming
@ -876,11 +877,13 @@ impl Service {
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveCanonicalJsonVec<'a> {
Box::pin(async move {
let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
let mut events_with_auth_events = vec![];
@ -902,8 +905,7 @@ impl Service {
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id)
{
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1009,9 +1011,7 @@ impl Service {
pdus.push((local_pdu, None));
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) =
services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id)
{
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1205,10 +1205,7 @@ impl Service {
while let Some(fetch_res) = server_keys.next().await {
match fetch_res {
Ok((signature_server, keys)) => {
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(signature_server.clone(), keys);
pub_key_map.write().await.insert(signature_server.clone(), keys);
},
Err((signature_server, e)) => {
warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
@ -1222,7 +1219,7 @@ impl Service {
// Gets a list of servers for which we don't have the signing key yet. We go
// over the PDUs and either cache the key or add it to the list that needs to be
// retrieved.
fn get_server_keys_from_cache(
async fn get_server_keys_from_cache(
&self, pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
room_version: &RoomVersionId,
@ -1239,7 +1236,7 @@ impl Service {
);
let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1310,7 +1307,7 @@ impl Service {
.await
{
debug!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
@ -1365,10 +1362,7 @@ impl Service {
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
pub_key_map.write().await.insert(origin.to_string(), result);
}
}
debug!("Done handling Future result");
@ -1384,15 +1378,15 @@ impl Service {
let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new();
{
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in &event.room_state.state {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
}
for pdu in &event.room_state.auth_chain {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
}
drop(pkm);
@ -1491,18 +1485,13 @@ impl Service {
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.unwrap()
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit =
services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1))));
s.acquire_owned()
@ -1510,14 +1499,16 @@ impl Service {
}
.await;
let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) {
let back_off = |id| async {
match services().globals.bad_signature_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) {
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1591,7 +1582,7 @@ impl Service {
drop(permit);
back_off(signature_ids);
back_off(signature_ids).await;
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse("Failed to find public key for server"))

View file

@ -1,21 +1,18 @@
mod data;
use std::{
collections::{HashMap, HashSet},
sync::Mutex,
};
use std::collections::{HashMap, HashSet};
pub use data::Data;
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use super::timeline::PduCount;
use crate::Result;
type LazyLoadWaitingMutex = Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>;
pub struct Service {
pub db: &'static dyn Data,
pub lazy_load_waiting: LazyLoadWaitingMutex,
#[allow(clippy::type_complexity)]
pub lazy_load_waiting: Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
}
impl Service {
@ -27,21 +24,21 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub fn lazy_load_mark_sent(
pub async fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
self.lazy_load_waiting
.lock()
.unwrap()
.await
.insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load);
}
#[tracing::instrument(skip(self))]
pub fn lazy_load_confirm_delivery(
pub async fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),

View file

@ -1,4 +1,4 @@
use std::sync::{Arc, Mutex};
use std::sync::Arc;
use lru_cache::LruCache;
use ruma::{
@ -25,6 +25,7 @@ use ruma::{
space::SpaceRoomJoinRule,
OwnedRoomId, RoomId, UserId,
};
use tokio::sync::Mutex;
use tracing::{debug, error, warn};
use crate::{services, Error, PduEvent, Result};
@ -70,7 +71,7 @@ impl Service {
break;
}
if let Some(cached) = self.roomid_spacechunk_cache.lock().unwrap().get_mut(&current_room.clone()).as_ref() {
if let Some(cached) = self.roomid_spacechunk_cache.lock().await.get_mut(&current_room.clone()).as_ref() {
if let Some(cached) = cached {
let allowed = match &cached.join_rule {
//CachedJoinRule::Simplified(s) => {
@ -148,7 +149,7 @@ impl Service {
.transpose()?
.unwrap_or(JoinRule::Invite);
self.roomid_spacechunk_cache.lock().unwrap().insert(
self.roomid_spacechunk_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceChunk {
chunk,
@ -223,7 +224,7 @@ impl Service {
}
}
self.roomid_spacechunk_cache.lock().unwrap().insert(
self.roomid_spacechunk_cache.lock().await.insert(
current_room.clone(),
Some(CachedSpaceChunk {
chunk,
@ -245,7 +246,7 @@ impl Service {
}
*/
} else {
self.roomid_spacechunk_cache.lock().unwrap().insert(current_room.clone(), None);
self.roomid_spacechunk_cache.lock().await.insert(current_room.clone(), None);
}
}
}

View file

@ -74,7 +74,7 @@ impl Service {
.await?;
},
TimelineEventType::SpaceChild => {
services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id);
services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id);
},
_ => continue,
}

View file

@ -3,7 +3,7 @@ pub(crate) mod data;
use std::{
cmp::Ordering,
collections::{BTreeMap, HashMap, HashSet},
sync::{Arc, Mutex, RwLock},
sync::Arc,
};
pub use data::Data;
@ -30,7 +30,7 @@ use ruma::{
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::MutexGuard;
use tokio::sync::{Mutex, MutexGuard, RwLock};
use tracing::{error, info, warn};
use super::state_compressor::CompressedStateEvent;
@ -239,7 +239,7 @@ impl Service {
services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?;
let mutex_insert =
Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(pdu.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(pdu.room_id.clone()).or_default());
let insert_lock = mutex_insert.lock().await;
let count1 = services().globals.next_count()?;
@ -398,7 +398,7 @@ impl Service {
},
TimelineEventType::SpaceChild => {
if let Some(_state_key) = &pdu.state_key {
services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id);
services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id);
}
},
TimelineEventType::RoomMember => {
@ -997,7 +997,7 @@ impl Service {
// Lock so we cannot backfill the same pdu twice at the same time
let mutex =
Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default());
let mutex_lock = mutex.lock().await;
// Skip the PDU if we already have it as a timeline event
@ -1020,7 +1020,7 @@ impl Service {
let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists");
let mutex_insert =
Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default());
let insert_lock = mutex_insert.lock().await;
let count = services().globals.next_count()?;