From 62f90ed04c51d6e577f4f97e671dae98333b2e0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 1 Jul 2021 19:55:26 +0200 Subject: [PATCH] improvement: efficient /sync, mutex for federation transactions --- src/appservice_server.rs | 5 +- src/client_server/membership.rs | 44 ++-- src/client_server/sync.rs | 382 +++++++++++++++++--------------- src/database/globals.rs | 11 +- src/database/rooms.rs | 10 +- src/server_server.rs | 198 +++++++++++------ 6 files changed, 374 insertions(+), 276 deletions(-) diff --git a/src/appservice_server.rs b/src/appservice_server.rs index 42918577..109671fd 100644 --- a/src/appservice_server.rs +++ b/src/appservice_server.rs @@ -46,7 +46,10 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); let url = reqwest_request.url().clone(); - let mut response = globals.reqwest_client().execute(reqwest_request).await?; + let mut response = globals + .reqwest_client() + .execute(reqwest_request) + .await?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index c42d779f..8b131b77 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -898,19 +898,37 @@ pub async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = - server_server::handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + + let pdu_id = server_server::handle_incoming_pdu( + &origin, + &event_id, + &room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map_err(|_| { + Error::BadRequest( + ErrorKind::InvalidParam, + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + drop(mutex_lock); for server in db .rooms diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index a7babd7e..5e7193b7 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -224,13 +224,16 @@ async fn sync_helper( // Database queries: - let current_shortstatehash = db.rooms.current_shortstatehash(&room_id)?; + let current_shortstatehash = db + .rooms + .current_shortstatehash(&room_id)? + .expect("All rooms have state"); - // These type is Option>. The outer Option is None when there is no event between - // since and the current room state, meaning there should be no updates. - // The inner Option is None when there is an event, but there is no state hash associated - // with it. This can happen for the RoomCreate event, so all updates should arrive. - let first_pdu_before_since = db.rooms.pdus_until(&sender_user, &room_id, since).next(); + let first_pdu_before_since = db + .rooms + .pdus_until(&sender_user, &room_id, since) + .next() + .transpose()?; let pdus_after_since = db .rooms @@ -238,11 +241,78 @@ async fn sync_helper( .next() .is_some(); - let since_shortstatehash = first_pdu_before_since.as_ref().map(|pdu| { - db.rooms - .pdu_shortstatehash(&pdu.as_ref().ok()?.1.event_id) - .ok()? - }); + let since_shortstatehash = first_pdu_before_since + .as_ref() + .map(|pdu| { + db.rooms + .pdu_shortstatehash(&pdu.1.event_id) + .transpose() + .expect("all pdus have state") + }) + .transpose()?; + + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || { + let joined_member_count = db.rooms.room_members(&room_id).count(); + let invited_member_count = db.rooms.room_members_invited(&room_id).count(); + + // Recalculate heroes (first 5 members) + let mut heroes = Vec::new(); + + if joined_member_count + invited_member_count <= 5 { + // Go through all PDUs and for each member event, check if the user is still joined or + // invited until we have 5 or we reach the end + + for hero in db + .rooms + .all_pdus(&sender_user, &room_id) + .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus + .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) + .map(|(_, pdu)| { + let content = serde_json::from_value::< + ruma::events::room::member::MemberEventContent, + >(pdu.content.clone()) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::try_from(state_key.clone()).map_err(|_| { + Error::bad_database("Invalid UserId in member PDU.") + })?; + + // The membership was and still is invite or join + if matches!( + content.membership, + MembershipState::Join | MembershipState::Invite + ) && (db.rooms.is_joined(&user_id, &room_id)? + || db.rooms.is_invited(&user_id, &room_id)?) + { + Ok::<_, Error>(Some(state_key.clone())) + } else { + Ok(None) + } + } else { + Ok(None) + } + }) + // Filter out buggy users + .filter_map(|u| u.ok()) + // Filter for possible heroes + .flatten() + { + if heroes.contains(&hero) || hero == sender_user.as_str() { + continue; + } + + heroes.push(hero); + } + } + + ( + Some(joined_member_count), + Some(invited_member_count), + heroes, + ) + }; let ( heroes, @@ -250,63 +320,107 @@ async fn sync_helper( invited_member_count, joined_since_last_sync, state_events, - ) = if pdus_after_since && Some(current_shortstatehash) != since_shortstatehash { - let current_state = db.rooms.room_state_full(&room_id)?; - let current_members = current_state - .iter() - .filter(|(key, _)| key.0 == EventType::RoomMember) - .map(|(key, value)| (&key.1, value)) // Only keep state key - .collect::>(); - let encrypted_room = current_state - .get(&(EventType::RoomEncryption, "".to_owned())) - .is_some(); - let since_state = since_shortstatehash - .as_ref() - .map(|since_shortstatehash| { - since_shortstatehash - .map(|since_shortstatehash| db.rooms.state_full(since_shortstatehash)) - .transpose() - }) - .transpose()?; + ) = if since_shortstatehash.is_none() { + // Probably since = 0, we will do an initial sync + let (joined_member_count, invited_member_count, heroes) = calculate_counts(); - let since_encryption = since_state.as_ref().map(|state| { - state - .as_ref() - .map(|state| state.get(&(EventType::RoomEncryption, "".to_owned()))) - }); + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + let state_events = current_state_ids + .iter() + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) + .collect::>(); + + ( + heroes, + joined_member_count, + invited_member_count, + true, + state_events + ) + } else if !pdus_after_since || since_shortstatehash == Some(current_shortstatehash) { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.unwrap(); + + let since_sender_member = db + .rooms + .state_get( + since_shortstatehash, + &EventType::RoomMember, + sender_user.as_str(), + )? + .and_then(|pdu| { + serde_json::from_value::>( + pdu.content.clone(), + ) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); + + let joined_since_last_sync = since_sender_member + .map_or(true, |member| member.membership != MembershipState::Join); + + let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; + + let since_state_ids = db.rooms.state_full_ids(since_shortstatehash)?; + + let state_events = if joined_since_last_sync { + current_state_ids + .iter() + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) + .collect::>() + } else { + current_state_ids + .difference(&since_state_ids) + .filter(|id| { + !timeline_pdus + .iter() + .any(|(_, timeline_pdu)| timeline_pdu.event_id == **id) + }) + .map(|id| db.rooms.get_pdu(id)) + .filter_map(|r| r.ok().flatten()) + .collect() + }; + + let encrypted_room = db + .rooms + .state_get(current_shortstatehash, &EventType::RoomEncryption, "")? + .is_some(); + + let since_encryption = + db.rooms + .state_get(since_shortstatehash, &EventType::RoomEncryption, "")?; // Calculations: - let new_encrypted_room = - encrypted_room && since_encryption.map_or(true, |encryption| encryption.is_none()); + let new_encrypted_room = encrypted_room && since_encryption.is_none(); - let send_member_count = since_state.as_ref().map_or(true, |since_state| { - since_state.as_ref().map_or(true, |since_state| { - current_members.len() - != since_state - .iter() - .filter(|(key, _)| key.0 == EventType::RoomMember) - .count() - }) - }); - - let since_sender_member = since_state.as_ref().map(|since_state| { - since_state.as_ref().and_then(|state| { - state - .get(&(EventType::RoomMember, sender_user.as_str().to_owned())) - .and_then(|pdu| { - serde_json::from_value::< - Raw, - >(pdu.content.clone()) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }) - }) - }); + let send_member_count = state_events + .iter() + .any(|event| event.kind == EventType::RoomMember); if encrypted_room { - for (user_id, current_member) in current_members { + for (user_id, current_member) in db + .rooms + .room_members(&room_id) + .filter_map(|r| r.ok()) + .filter_map(|user_id| { + db.rooms + .state_get( + current_shortstatehash, + &EventType::RoomMember, + user_id.as_str(), + ) + .ok() + .flatten() + .map(|current_member| (user_id, current_member)) + }) + { let current_membership = serde_json::from_value::< Raw, >(current_member.content.clone()) @@ -315,31 +429,23 @@ async fn sync_helper( .map_err(|_| Error::bad_database("Invalid PDU in database."))? .membership; - let since_membership = - since_state - .as_ref() - .map_or(MembershipState::Leave, |since_state| { - since_state - .as_ref() - .and_then(|since_state| { - since_state - .get(&(EventType::RoomMember, user_id.clone())) - .and_then(|since_member| { - serde_json::from_value::< - Raw, - >( - since_member.content.clone() - ) - .expect("Raw::from_value always works") - .deserialize() - .map_err(|_| { - Error::bad_database("Invalid PDU in database.") - }) - .ok() - }) - }) - .map_or(MembershipState::Leave, |member| member.membership) - }); + let since_membership = db + .rooms + .state_get( + since_shortstatehash, + &EventType::RoomMember, + user_id.as_str(), + )? + .and_then(|since_member| { + serde_json::from_value::< + Raw, + >(since_member.content.clone()) + .expect("Raw::from_value always works") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }) + .map_or(MembershipState::Leave, |member| member.membership); let user_id = UserId::try_from(user_id.clone()) .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; @@ -361,10 +467,6 @@ async fn sync_helper( } } - let joined_since_last_sync = since_sender_member.map_or(true, |member| { - member.map_or(true, |member| member.membership != MembershipState::Join) - }); - if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( @@ -384,100 +486,11 @@ async fn sync_helper( } let (joined_member_count, invited_member_count, heroes) = if send_member_count { - let joined_member_count = db.rooms.room_members(&room_id).count(); - let invited_member_count = db.rooms.room_members_invited(&room_id).count(); - - // Recalculate heroes (first 5 members) - let mut heroes = Vec::new(); - - if joined_member_count + invited_member_count <= 5 { - // Go through all PDUs and for each member event, check if the user is still joined or - // invited until we have 5 or we reach the end - - for hero in db - .rooms - .all_pdus(&sender_user, &room_id) - .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) - .map(|(_, pdu)| { - let content = serde_json::from_value::< - ruma::events::room::member::MemberEventContent, - >(pdu.content.clone()) - .map_err(|_| { - Error::bad_database("Invalid member event in database.") - })?; - - if let Some(state_key) = &pdu.state_key { - let user_id = - UserId::try_from(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; - - // The membership was and still is invite or join - if matches!( - content.membership, - MembershipState::Join | MembershipState::Invite - ) && (db.rooms.is_joined(&user_id, &room_id)? - || db.rooms.is_invited(&user_id, &room_id)?) - { - Ok::<_, Error>(Some(state_key.clone())) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - // Filter out buggy users - .filter_map(|u| u.ok()) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user.as_str() { - continue; - } - - heroes.push(hero); - } - } - - ( - Some(joined_member_count), - Some(invited_member_count), - heroes, - ) + calculate_counts() } else { (None, None, Vec::new()) }; - let state_events = if joined_since_last_sync { - current_state - .iter() - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect() - } else { - match since_state { - None => Vec::new(), - Some(Some(since_state)) => current_state - .iter() - .filter(|(key, value)| { - since_state.get(key).map(|e| &e.event_id) != Some(&value.event_id) - }) - .filter(|(_, value)| { - !timeline_pdus.iter().any(|(_, timeline_pdu)| { - timeline_pdu.kind == value.kind - && timeline_pdu.state_key == value.state_key - }) - }) - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect(), - Some(None) => current_state - .iter() - .map(|(_, pdu)| pdu.to_sync_state_event()) - .collect(), - } - }; - ( heroes, joined_member_count, @@ -485,8 +498,6 @@ async fn sync_helper( joined_since_last_sync, state_events, ) - } else { - (Vec::new(), None, None, false, Vec::new()) }; // Look for device list updates in this room @@ -577,7 +588,10 @@ async fn sync_helper( events: room_events, }, state: sync_events::State { - events: state_events, + events: state_events + .iter() + .map(|pdu| pdu.to_sync_state_event()) + .collect(), }, ephemeral: sync_events::Ephemeral { events: edus }, }; diff --git a/src/database/globals.rs b/src/database/globals.rs index 4859ef4e..82a2aa28 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,12 +1,9 @@ use crate::{database::Config, utils, ConduitResult, Error, Result}; use log::{error, info}; -use ruma::{ - api::{ +use ruma::{DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, UserId, api::{ client::r0::sync::sync_events, federation::discovery::{ServerSigningKeys, VerifyKey}, - }, - DeviceId, EventId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId, -}; + }}; use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ collections::{BTreeMap, HashMap}, @@ -15,7 +12,7 @@ use std::{ sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::Semaphore; +use tokio::sync::{Mutex, Semaphore}; use trust_dns_resolver::TokioAsyncResolver; use super::abstraction::Tree; @@ -38,6 +35,7 @@ pub struct Globals { pub bad_event_ratelimiter: Arc>>, pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, + pub roomid_mutex: RwLock>>>, pub sync_receivers: RwLock< BTreeMap< (UserId, Box), @@ -165,6 +163,7 @@ impl Globals { bad_event_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), + roomid_mutex: RwLock::new(BTreeMap::new()), sync_receivers: RwLock::new(BTreeMap::new()), }; diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 23b87fa0..550974df 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -20,12 +20,7 @@ use ruma::{ state_res::{self, Event, RoomVersion, StateMap}, uint, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - convert::{TryFrom, TryInto}, - mem, - sync::{Arc, RwLock}, -}; +use std::{collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, mem, sync::{Arc, RwLock}}; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -89,7 +84,7 @@ pub struct Rooms { impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. - pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { + pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { Ok(self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) @@ -1217,6 +1212,7 @@ impl Rooms { state_key, redacts, } = pdu_builder; + // TODO: Make sure this isn't called twice in parallel let prev_events = self .get_pdu_leaves(&room_id)? diff --git a/src/server_server.rs b/src/server_server.rs index 80340521..969a2610 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -624,13 +624,44 @@ pub async fn send_transaction_message_route( } }; + // 0. Check the server is in the room + let room_id = match value + .get("room_id") + .and_then(|id| RoomId::try_from(id.as_str()?).ok()) + { + Some(id) => id, + None => { + // Event is invalid + resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_string())); + continue; + } + }; + + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; let start_time = Instant::now(); resolved_map.insert( event_id.clone(), - handle_incoming_pdu(&body.origin, &event_id, value, true, &db, &pub_key_map) - .await - .map(|_| ()), + handle_incoming_pdu( + &body.origin, + &event_id, + &room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map(|_| ()), ); + drop(mutex_lock); let elapsed = start_time.elapsed(); if elapsed > Duration::from_secs(1) { @@ -782,8 +813,8 @@ pub async fn send_transaction_message_route( type AsyncRecursiveResult<'a, T, E> = Pin> + 'a + Send>>; /// When receiving an event one needs to: -/// 0. Skip the PDU if we already know about it -/// 1. Check the server is in the room +/// 0. Check the server is in the room +/// 1. Skip the PDU if we already know about it /// 2. Check signatures, otherwise drop /// 3. Check content hash, redact if doesn't match /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not @@ -808,6 +839,7 @@ type AsyncRecursiveResult<'a, T, E> = Pin( origin: &'a ServerName, event_id: &'a EventId, + room_id: &'a RoomId, value: BTreeMap, is_timeline_event: bool, db: &'a Database, @@ -815,24 +847,6 @@ pub fn handle_incoming_pdu<'a>( ) -> AsyncRecursiveResult<'a, Option>, String> { Box::pin(async move { // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // 0. Skip the PDU if we already have it as a timeline event - if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { - return Ok(Some(pdu_id.to_vec())); - } - - // 1. Check the server is in the room - let room_id = match value - .get("room_id") - .and_then(|id| RoomId::try_from(id.as_str()?).ok()) - { - Some(id) => id, - None => { - // Event is invalid - return Err("Event needs a valid RoomId.".to_string()); - } - }; - match db.rooms.exists(&room_id) { Ok(true) => {} _ => { @@ -840,6 +854,11 @@ pub fn handle_incoming_pdu<'a>( } } + // 1. Skip the PDU if we already have it as a timeline event + if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { + return Ok(Some(pdu_id.to_vec())); + } + // We go through all the signatures we see on the value and fetch the corresponding signing // keys fetch_required_signing_keys(&value, &pub_key_map, db) @@ -899,7 +918,7 @@ pub fn handle_incoming_pdu<'a>( // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // EDIT: Step 5 is not applied anymore because it failed too often debug!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, pub_key_map) + fetch_and_handle_events(db, origin, &incoming_pdu.auth_events, &room_id, pub_key_map) .await .map_err(|e| e.to_string())?; @@ -1000,13 +1019,13 @@ pub fn handle_incoming_pdu<'a>( if incoming_pdu.prev_events.len() == 1 { let prev_event = &incoming_pdu.prev_events[0]; - let state_vec = db + let state = db .rooms .pdu_shortstatehash(prev_event) .map_err(|_| "Failed talking to db".to_owned())? .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) .flatten(); - if let Some(mut state_vec) = state_vec { + if let Some(mut state) = state { if db .rooms .get_pdu(prev_event) @@ -1016,25 +1035,31 @@ pub fn handle_incoming_pdu<'a>( .state_key .is_some() { - state_vec.push(prev_event.clone()); + state.insert(prev_event.clone()); } state_at_incoming_event = Some( - fetch_and_handle_events(db, origin, &state_vec, pub_key_map) - .await - .map_err(|_| "Failed to fetch state events locally".to_owned())? - .into_iter() - .map(|pdu| { + fetch_and_handle_events( + db, + origin, + &state.into_iter().collect::>(), + &room_id, + pub_key_map, + ) + .await + .map_err(|_| "Failed to fetch state events locally".to_owned())? + .into_iter() + .map(|pdu| { + ( ( - ( - pdu.kind.clone(), - pdu.state_key - .clone() - .expect("events from state_full_ids are state events"), - ), - pdu, - ) - }) - .collect(), + pdu.kind.clone(), + pdu.state_key + .clone() + .expect("events from state_full_ids are state events"), + ), + pdu, + ) + }) + .collect(), ); } // TODO: set incoming_auth_events? @@ -1057,12 +1082,18 @@ pub fn handle_incoming_pdu<'a>( { Ok(res) => { debug!("Fetching state events at event."); - let state_vec = - match fetch_and_handle_events(&db, origin, &res.pdu_ids, pub_key_map).await - { - Ok(state) => state, - Err(_) => return Err("Failed to fetch state events.".to_owned()), - }; + let state_vec = match fetch_and_handle_events( + &db, + origin, + &res.pdu_ids, + &room_id, + pub_key_map, + ) + .await + { + Ok(state) => state, + Err(_) => return Err("Failed to fetch state events.".to_owned()), + }; let mut state = BTreeMap::new(); for pdu in state_vec { @@ -1088,8 +1119,14 @@ pub fn handle_incoming_pdu<'a>( } debug!("Fetching auth chain events at event."); - match fetch_and_handle_events(&db, origin, &res.auth_chain_ids, pub_key_map) - .await + match fetch_and_handle_events( + &db, + origin, + &res.auth_chain_ids, + &room_id, + pub_key_map, + ) + .await { Ok(state) => state, Err(_) => return Err("Failed to fetch auth chain.".to_owned()), @@ -1219,8 +1256,14 @@ pub fn handle_incoming_pdu<'a>( for map in &fork_states { let mut state_auth = vec![]; for auth_id in map.values().flat_map(|pdu| &pdu.auth_events) { - match fetch_and_handle_events(&db, origin, &[auth_id.clone()], pub_key_map) - .await + match fetch_and_handle_events( + &db, + origin, + &[auth_id.clone()], + &room_id, + pub_key_map, + ) + .await { // This should always contain exactly one element when Ok Ok(events) => state_auth.extend_from_slice(&events), @@ -1326,6 +1369,7 @@ pub(crate) fn fetch_and_handle_events<'a>( db: &'a Database, origin: &'a ServerName, events: &'a [EventId], + room_id: &'a RoomId, pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveResult<'a, Vec>, Error> { Box::pin(async move { @@ -1379,6 +1423,7 @@ pub(crate) fn fetch_and_handle_events<'a>( match handle_incoming_pdu( origin, &event_id, + &room_id, value.clone(), false, db, @@ -1657,7 +1702,8 @@ pub(crate) fn append_incoming_pdu( .filter_map(|r| r.ok()) .any(|member| users.iter().any(|regex| regex.is_match(member.as_str()))) { - db.sending.send_pdu_appservice(&appservice.0, &pdu_id)?; + db.sending + .send_pdu_appservice(&appservice.0, &pdu_id)?; } } } @@ -1867,7 +1913,11 @@ pub fn get_room_state_ids_route( "Pdu state not found.", ))?; - let pdu_ids = db.rooms.state_full_ids(shortstatehash)?; + let pdu_ids = db + .rooms + .state_full_ids(shortstatehash)? + .into_iter() + .collect(); let mut auth_chain_ids = BTreeSet::::new(); let mut todo = BTreeSet::new(); @@ -2113,18 +2163,36 @@ pub async fn create_join_event_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = handle_incoming_pdu(&origin, &event_id, value, true, &db, &pub_key_map) - .await - .map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? - .ok_or(Error::BadRequest( + let mutex = Arc::clone( + db.globals + .roomid_mutex + .write() + .unwrap() + .entry(body.room_id.clone()) + .or_default(), + ); + let mutex_lock = mutex.lock().await; + let pdu_id = handle_incoming_pdu( + &origin, + &event_id, + &body.room_id, + value, + true, + &db, + &pub_key_map, + ) + .await + .map_err(|_| { + Error::BadRequest( ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + "Error while handling incoming PDU.", + ) + })? + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + drop(mutex_lock); let state_ids = db.rooms.state_full_ids(shortstatehash)?;