From face766e0f32481fd97a435f1ed8579d8cfc634c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 5 Oct 2022 12:45:54 +0200 Subject: [PATCH] messing with trait objects --- src/api/client_server/membership.rs | 6 +- src/api/client_server/room.rs | 4 +- src/api/client_server/sync.rs | 2 +- src/api/ruma_wrapper/axum.rs | 2 +- src/api/server_server.rs | 175 +---------- src/database/key_value/globals.rs | 8 +- src/database/key_value/media.rs | 4 +- src/database/key_value/pusher.rs | 2 +- src/database/key_value/rooms/alias.rs | 6 +- src/database/key_value/rooms/auth_chain.rs | 6 +- src/database/key_value/rooms/edus/presence.rs | 2 + .../key_value/rooms/edus/read_receipt.rs | 4 +- src/database/key_value/rooms/lazy_load.rs | 29 +- src/database/key_value/rooms/metadata.rs | 4 +- src/database/key_value/rooms/search.rs | 10 +- .../key_value/rooms/state_accessor.rs | 26 +- .../key_value/rooms/state_compressor.rs | 4 +- src/database/key_value/rooms/timeline.rs | 20 +- src/database/key_value/rooms/user.rs | 6 +- src/database/key_value/users.rs | 19 +- src/database/mod.rs | 296 ++++++++++-------- src/lib.rs | 14 +- src/main.rs | 17 +- src/service/account_data/mod.rs | 6 +- src/service/appservice/mod.rs | 6 +- src/service/globals/mod.rs | 8 +- src/service/key_backups/mod.rs | 6 +- src/service/media/mod.rs | 6 +- src/service/mod.rs | 44 ++- src/service/pusher/mod.rs | 6 +- src/service/rooms/alias/data.rs | 2 +- src/service/rooms/alias/mod.rs | 6 +- src/service/rooms/auth_chain/data.rs | 2 +- src/service/rooms/auth_chain/mod.rs | 6 +- src/service/rooms/directory/mod.rs | 6 +- src/service/rooms/edus/mod.rs | 8 +- src/service/rooms/edus/presence/mod.rs | 6 +- src/service/rooms/edus/read_receipt/mod.rs | 6 +- src/service/rooms/edus/typing/mod.rs | 6 +- src/service/rooms/event_handler/mod.rs | 187 ++++++++++- src/service/rooms/lazy_loading/data.rs | 2 +- src/service/rooms/lazy_loading/mod.rs | 23 +- src/service/rooms/metadata/mod.rs | 6 +- src/service/rooms/mod.rs | 34 +- src/service/rooms/outlier/mod.rs | 6 +- src/service/rooms/pdu_metadata/mod.rs | 6 +- src/service/rooms/search/data.rs | 2 +- src/service/rooms/search/mod.rs | 11 +- src/service/rooms/short/mod.rs | 6 +- src/service/rooms/state/data.rs | 3 +- src/service/rooms/state/mod.rs | 19 +- src/service/rooms/state_accessor/mod.rs | 6 +- src/service/rooms/state_cache/mod.rs | 6 +- src/service/rooms/state_compressor/data.rs | 8 +- src/service/rooms/state_compressor/mod.rs | 6 +- src/service/rooms/timeline/mod.rs | 6 +- src/service/rooms/user/mod.rs | 6 +- src/service/transaction_ids/mod.rs | 6 +- src/service/uiaa/mod.rs | 6 +- src/service/users/data.rs | 10 +- src/service/users/mod.rs | 6 +- 61 files changed, 623 insertions(+), 544 deletions(-) diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 98931f25..720c1e64 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -481,7 +481,7 @@ async fn join_room_by_id_helper( let (make_join_response, remote_server) = make_join_response_and_server?; let room_version = match make_join_response.room_version { - Some(room_version) if services().rooms.metadata.is_supported_version(&room_version) => room_version, + Some(room_version) if services().globals.supported_room_versions().contains(&room_version) => room_version, _ => return Err(Error::BadServerResponse("Room version is not supported")), }; @@ -568,7 +568,7 @@ async fn join_room_by_id_helper( let mut state = HashMap::new(); let pub_key_map = RwLock::new(BTreeMap::new()); - server_server::fetch_join_signing_keys( + services().rooms.event_handler.fetch_join_signing_keys( &send_join_response, &room_version, &pub_key_map, @@ -1048,7 +1048,7 @@ async fn remote_leave_room( let (make_leave_response, remote_server) = make_leave_response_and_server?; let room_version_id = match make_leave_response.room_version { - Some(version) if services().rooms.is_supported_version(&version) => version, + Some(version) if services().globals.supported_room_versions().contains(&version) => version, _ => return Err(Error::BadServerResponse("Room version is not supported")), }; diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index a7fa9520..939fbaa2 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -99,7 +99,7 @@ pub async fn create_room_route( let room_version = match body.room_version.clone() { Some(room_version) => { - if services().rooms.is_supported_version(&services(), &room_version) { + if services().globals.supported_room_versions().contains(&room_version) { room_version } else { return Err(Error::BadRequest( @@ -470,7 +470,7 @@ pub async fn upgrade_room_route( ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.is_supported_version(&body.new_version) { + if !services().globals.supported_room_versions().contains(&body.new_version) { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, "This server does not support that room version.", diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index e38ea600..3489a9a9 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -175,7 +175,7 @@ async fn sync_helper( services().rooms.edus.presence.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = services().watch(&sender_user, &sender_device); + let watcher = services().globals.db.watch(&sender_user, &sender_device); let next_batch = services().globals.current_count()?; let next_batch_string = next_batch.to_string(); diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index babf2a74..d926b89b 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -197,7 +197,7 @@ where request_map.insert("content".to_owned(), json_body.clone()); }; - let keys_result = server_server::fetch_signing_keys( + let keys_result = services().rooms.event_handler.fetch_signing_keys( &x_matrix.origin, vec![x_matrix.key.to_owned()], ) diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 9aa2beb9..45d749d0 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -664,7 +664,7 @@ pub async fn send_transaction_message_route( Some(id) => id, None => { // Event is invalid - resolved_map.insert(event_id, Err("Event needs a valid RoomId.".to_owned())); + resolved_map.insert(event_id, Err(Error::bad_database("Event needs a valid RoomId."))); continue; } }; @@ -707,7 +707,7 @@ pub async fn send_transaction_message_route( for pdu in &resolved_map { if let Err(e) = pdu.1 { - if e != "Room is unknown to this server." { + if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { warn!("Incoming PDU failed {:?}", pdu); } } @@ -854,170 +854,7 @@ pub async fn send_transaction_message_route( } } - Ok(send_transaction_message::v1::Response { pdus: resolved_map }) -} - -/// Search the DB for the signing keys of the given server, if we don't have them -/// fetch them from the server and save to our DB. -#[tracing::instrument(skip_all)] -pub(crate) async fn fetch_signing_keys( - origin: &ServerName, - signature_ids: Vec, -) -> Result> { - let contains_all_ids = - |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - - let permit = services() - .globals - .servername_ratelimiter - .read() - .unwrap() - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); - - let permit = match permit { - Some(p) => p, - None => { - let mut write = services().globals.servername_ratelimiter.write().unwrap(); - let s = Arc::clone( - write - .entry(origin.to_owned()) - .or_insert_with(|| Arc::new(Semaphore::new(1))), - ); - - s.acquire_owned() - } - } - .await; - - let back_off = |id| match services() - .globals - .bad_signature_ratelimiter - .write() - .unwrap() - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - if let Some((time, tries)) = services() - .globals - .bad_signature_ratelimiter - .read() - .unwrap() - .get(&signature_ids) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {:?}", signature_ids); - return Err(Error::BadServerResponse("bad signature, still backing off")); - } - } - - trace!("Loading signing keys for {}", origin); - - let mut result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if contains_all_ids(&result) { - return Ok(result); - } - - debug!("Fetching signing keys for {} over federation", origin); - - if let Some(server_key) = services() - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - services().globals.add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in services().globals.trusted_servers() { - debug!("Asking {} for {}'s signing key", server, origin); - if let Some(server_keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin, - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime to large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) - { - trace!("Got signing keys: {:?}", server_keys); - for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - drop(permit); - - back_off(signature_ids); - - warn!("Failed to find public key for server: {}", origin); - Err(Error::BadServerResponse( - "Failed to find public key for server", - )) + Ok(send_transaction_message::v1::Response { pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.to_string()))).collect() }) } #[tracing::instrument(skip(starting_events))] @@ -1050,7 +887,7 @@ pub(crate) async fn get_auth_chain<'a>( } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&chunk_key)? { + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; @@ -1062,7 +899,7 @@ pub(crate) async fn get_auth_chain<'a>( let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&[sevent_id])? { + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { @@ -1689,7 +1526,7 @@ pub async fn create_invite_route( services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; - if !services().rooms.is_supported_version(&body.room_version) { + if !services().globals.supported_room_versions().contains(&body.room_version) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: body.room_version.clone(), diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 81e6ee1f..e6652290 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -4,10 +4,10 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils}; impl service::globals::Data for KeyValueDatabase { fn load_keypair(&self) -> Result { - let keypair_bytes = self.globals.get(b"keypair")?.map_or_else( + let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { let keypair = utils::generate_keypair(); - self.globals.insert(b"keypair", &keypair)?; + self.global.insert(b"keypair", &keypair)?; Ok::<_, Error>(keypair) }, |s| Ok(s.to_vec()), @@ -33,8 +33,10 @@ impl service::globals::Data for KeyValueDatabase { Ed25519KeyPair::from_der(key, version) .map_err(|_| Error::bad_database("Private or public keys are invalid.")) }); + + keypair } fn remove_keypair(&self) -> Result<()> { - self.globals.remove(b"keypair")? + self.global.remove(b"keypair") } } diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 90a5c590..a84cbd53 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -1,3 +1,5 @@ +use ruma::api::client::error::ErrorKind; + use crate::{database::KeyValueDatabase, service, Error, utils, Result}; impl service::media::Data for KeyValueDatabase { @@ -33,7 +35,7 @@ impl service::media::Data for KeyValueDatabase { prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail prefix.push(0xff); - let (key, _) = self.mediaid_file.scan_prefix(prefix).next().ok_or(Error::NotFound)?; + let (key, _) = self.mediaid_file.scan_prefix(prefix).next().ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; let mut parts = key.rsplit(|&b| b == 0xff); diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 35c84638..b05e47be 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -55,6 +55,6 @@ impl service::pusher::Data for KeyValueDatabase { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); - self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k) + Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)) } } diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index c762defa..0aa8dd48 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -56,15 +56,15 @@ impl service::rooms::alias::Data for KeyValueDatabase { fn local_aliases_for_room( &self, room_id: &RoomId, - ) -> Result>> { + ) -> Box>>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { + Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { utils::string_from_bytes(&bytes) .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? .try_into() .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - }) + })) } } diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 585d5626..888d472d 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -3,8 +3,8 @@ use std::{collections::HashSet, mem::size_of}; use crate::{service, database::KeyValueDatabase, Result, utils}; impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result> { - self.shorteventid_authchain + fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>> { + Ok(self.shorteventid_authchain .get(&shorteventid.to_be_bytes())? .map(|chain| { chain @@ -13,7 +13,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { utils::u64_from_bytes(chunk).expect("byte length is correct") }) .collect() - }) + })) } fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()> { diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index fbbbff55..1477c28b 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -145,4 +145,6 @@ fn parse_presence_event(bytes: &[u8]) -> Result { .last_active_ago .map(|timestamp| current_timestamp - timestamp); } + + Ok(presence) } diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 42d250f7..a12e2653 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -64,7 +64,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { let mut first_possible_edu = prefix.clone(); first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - self.readreceiptid_readreceipt + Box::new(self.readreceiptid_readreceipt .iter_from(&first_possible_edu, false) .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(move |(k, v)| { @@ -91,7 +91,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { serde_json::value::to_raw_value(&json).expect("json is valid raw value"), ), )) - }) + })) } fn private_read_set( diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index aaf14dd3..133e1d04 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -25,26 +25,19 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase { user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - since: u64, + confirmed_user_ids: &mut dyn Iterator, ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xff); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xff); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xff); - for ll_id in user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } + for ll_id in confirmed_user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.lazyloadedids.insert(&key, &[])?; } Ok(()) diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 0509cbb8..db2bc69b 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,10 +1,10 @@ use ruma::RoomId; -use crate::{service, database::KeyValueDatabase, Result}; +use crate::{service, database::KeyValueDatabase, Result, services}; impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match self.get_shortroomid(room_id)? { + let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), None => return Ok(false), }; diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 15937f6d..dfbdbc64 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -2,10 +2,10 @@ use std::mem::size_of; use ruma::RoomId; -use crate::{service, database::KeyValueDatabase, utils, Result}; +use crate::{service, database::KeyValueDatabase, utils, Result, services}; impl service::rooms::search::Data for KeyValueDatabase { - fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()> { + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) @@ -27,7 +27,7 @@ impl service::rooms::search::Data for KeyValueDatabase { room_id: &RoomId, search_string: &str, ) -> Result>>, Vec)>> { - let prefix = self + let prefix = services().rooms.short .get_shortroomid(room_id)? .expect("room exists") .to_be_bytes() @@ -60,11 +60,11 @@ impl service::rooms::search::Data for KeyValueDatabase { }) .map(|iter| { ( - iter.map(move |id| { + Box::new(iter.map(move |id| { let mut pduid = prefix_clone.clone(); pduid.extend_from_slice(&id); pduid - }), + })), words, ) })) diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 037b98fc..4d5bd4a1 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,13 +1,13 @@ use std::{collections::{BTreeMap, HashMap}, sync::Arc}; -use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils, Result}; +use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils, Result, services}; use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; #[async_trait] impl service::rooms::state_accessor::Data for KeyValueDatabase { async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = self + let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") @@ -15,7 +15,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { let mut result = BTreeMap::new(); let mut i = 0; for compressed in full_state.into_iter() { - let parsed = self.parse_compressed_state_event(compressed)?; + let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; result.insert(parsed.0, parsed.1); i += 1; @@ -30,7 +30,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { &self, shortstatehash: u64, ) -> Result>> { - let full_state = self + let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") @@ -39,8 +39,8 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { let mut result = HashMap::new(); let mut i = 0; for compressed in full_state { - let (_, eventid) = self.parse_compressed_state_event(compressed)?; - if let Some(pdu) = self.get_pdu(&eventid)? { + let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; + if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { result.insert( ( pdu.kind.to_string().into(), @@ -69,11 +69,11 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { event_type: &StateEventType, state_key: &str, ) -> Result>> { - let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? { Some(s) => s, None => return Ok(None), }; - let full_state = self + let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? .pop() .expect("there is always one layer") @@ -82,7 +82,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { .into_iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) .and_then(|compressed| { - self.parse_compressed_state_event(compressed) + services().rooms.state_compressor.parse_compressed_state_event(compressed) .ok() .map(|(_, id)| id) })) @@ -96,7 +96,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { state_key: &str, ) -> Result>> { self.state_get_id(shortstatehash, event_type, state_key)? - .map_or(Ok(None), |event_id| self.get_pdu(&event_id)) + .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) } /// Returns the state hash for this pdu. @@ -122,7 +122,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { &self, room_id: &RoomId, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { self.state_full(current_shortstatehash).await } else { Ok(HashMap::new()) @@ -136,7 +136,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { self.state_get_id(current_shortstatehash, event_type, state_key) } else { Ok(None) @@ -150,7 +150,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase { event_type: &StateEventType, state_key: &str, ) -> Result>> { - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { self.state_get(current_shortstatehash, event_type, state_key) } else { Ok(None) diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index 23a7122b..aee1890c 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -39,8 +39,8 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase { } fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { - let mut value = diff.parent.to_be_bytes().to_vec(); - for new in &diff.new { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in &diff.added { value.extend_from_slice(&new[..]); } diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index c42509e0..a3b6c17d 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -3,7 +3,7 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; use ruma::{UserId, RoomId, api::client::error::ErrorKind, EventId, signatures::CanonicalJsonObject}; use tracing::error; -use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result}; +use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services}; impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { @@ -191,7 +191,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { room_id: &RoomId, since: u64, ) -> Result, PduEvent)>>>> { - let prefix = self + let prefix = services().rooms.short .get_shortroomid(room_id)? .expect("room exists") .to_be_bytes() @@ -203,7 +203,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let user_id = user_id.to_owned(); - Ok(self + Ok(Box::new(self .pduid_pdu .iter_from(&first_pdu_id, false) .take_while(move |(k, _)| k.starts_with(&prefix)) @@ -214,7 +214,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { pdu.remove_transaction_id()?; } Ok((pdu_id, pdu)) - })) + }))) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -226,7 +226,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { until: u64, ) -> Result, PduEvent)>>>> { // Create the first part of the full pdu id - let prefix = self + let prefix = services().rooms.short .get_shortroomid(room_id)? .expect("room exists") .to_be_bytes() @@ -239,7 +239,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let user_id = user_id.to_owned(); - Ok(self + Ok(Box::new(self .pduid_pdu .iter_from(current, true) .take_while(move |(k, _)| k.starts_with(&prefix)) @@ -250,7 +250,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { pdu.remove_transaction_id()?; } Ok((pdu_id, pdu)) - })) + }))) } fn pdus_after<'a>( @@ -260,7 +260,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { from: u64, ) -> Result, PduEvent)>>>> { // Create the first part of the full pdu id - let prefix = self + let prefix = services().rooms.short .get_shortroomid(room_id)? .expect("room exists") .to_be_bytes() @@ -273,7 +273,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { let user_id = user_id.to_owned(); - Ok(self + Ok(Box::new(self .pduid_pdu .iter_from(current, false) .take_while(move |(k, _)| k.starts_with(&prefix)) @@ -284,6 +284,6 @@ impl service::rooms::timeline::Data for KeyValueDatabase { pdu.remove_transaction_id()?; } Ok((pdu_id, pdu)) - })) + }))) } } diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index d49bc1d7..66681e3c 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,6 +1,6 @@ use ruma::{UserId, RoomId}; -use crate::{service, database::KeyValueDatabase, utils, Error, Result}; +use crate::{service, database::KeyValueDatabase, utils, Error, Result, services}; impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { @@ -50,7 +50,7 @@ impl service::rooms::user::Data for KeyValueDatabase { token: u64, shortstatehash: u64, ) -> Result<()> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); @@ -60,7 +60,7 @@ impl service::rooms::user::Data for KeyValueDatabase { } fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = self.get_shortroomid(room_id)?.expect("room exists"); + let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(&token.to_be_bytes()); diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 82e3bac6..338d8800 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -57,12 +57,12 @@ impl service::users::Data for KeyValueDatabase { /// Returns an iterator over all users on this homeserver. fn iter(&self) -> Box>>> { - self.userid_password.iter().map(|(bytes, _)| { + Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") })?) .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - }) + })) } /// Returns a list of local users as list of usernames. @@ -274,7 +274,7 @@ impl service::users::Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata - self.userdeviceid_metadata + Box::new(self.userdeviceid_metadata .scan_prefix(prefix) .map(|(bytes, _)| { Ok(utils::string_from_bytes( @@ -285,7 +285,7 @@ impl service::users::Data for KeyValueDatabase { ) .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? .into()) - }) + })) } /// Replaces the access token of one device. @@ -617,7 +617,7 @@ impl service::users::Data for KeyValueDatabase { let to = to.unwrap_or(u64::MAX); - self.keychangeid_userid + Box::new(self.keychangeid_userid .iter_from(&start, false) .take_while(move |(k, _)| { k.starts_with(&prefix) @@ -638,7 +638,7 @@ impl service::users::Data for KeyValueDatabase { Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") })?) .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) - }) + })) } fn mark_device_key_update( @@ -646,9 +646,10 @@ impl service::users::Data for KeyValueDatabase { user_id: &UserId, ) -> Result<()> { let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services().rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { // Don't send key updates to unencrypted rooms if services().rooms + .state_accessor .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .is_none() { @@ -882,12 +883,12 @@ impl service::users::Data for KeyValueDatabase { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); - self.userdeviceid_metadata + Box::new(self.userdeviceid_metadata .scan_prefix(key) .map(|(_, bytes)| { serde_json::from_slice::(&bytes) .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) - }) + })) } /// Creates a new sync filter. Returns the filter id. diff --git a/src/database/mod.rs b/src/database/mod.rs index 22bfef06..aa5c5839 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,7 +1,7 @@ pub mod abstraction; pub mod key_value; -use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, appservice, pusher}}; +use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms::{self, state_compressor::CompressedStateEvent}, account_data, media, key_backups, transaction_ids, sending, appservice, pusher}, services, PduEvent, Services, SERVICES}; use abstraction::KeyValueDatabaseEngine; use directories::ProjectDirs; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -9,7 +9,7 @@ use lru_cache::LruCache; use ruma::{ events::{ push_rules::PushRulesEventContent, room::message::RoomMessageEventContent, - GlobalAccountDataEvent, GlobalAccountDataEventType, + GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, }, push::Ruleset, DeviceId, EventId, RoomId, UserId, signatures::CanonicalJsonValue, @@ -151,6 +151,30 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, + + pub(super) cached_registrations: Arc>>, + pub(super) pdu_cache: Mutex, Arc>>, + pub(super) shorteventid_cache: Mutex>>, + pub(super) auth_chain_cache: Mutex, Arc>>>, + pub(super) eventidshort_cache: Mutex, u64>>, + pub(super) statekeyshort_cache: Mutex>, + pub(super) shortstatekey_cache: Mutex>, + pub(super) our_real_users_cache: RwLock, Arc>>>>, + pub(super) appservice_in_room_cache: RwLock, HashMap>>, + pub(super) lazy_load_waiting: + Mutex, Box, Box, u64), HashSet>>>, + pub(super) stateinfo_cache: Mutex< + LruCache< + u64, + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + >, + >, + pub(super) lasttimelinecount_cache: Mutex, u64>>, } impl KeyValueDatabase { @@ -214,7 +238,7 @@ impl KeyValueDatabase { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: &Config) -> Result>> { + pub async fn load_or_create(config: &Config) -> Result<()> { Self::check_db_setup(config)?; if !Path::new(&config.database_path).exists() { @@ -253,7 +277,7 @@ impl KeyValueDatabase { let (admin_sender, admin_receiver) = mpsc::unbounded_channel(); let (sending_sender, sending_receiver) = mpsc::unbounded_channel(); - let db = Self { + let db = Arc::new(Self { _db: builder.clone(), userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, @@ -345,18 +369,53 @@ impl KeyValueDatabase { senderkey_pusher: builder.open_tree("senderkey_pusher")?, global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, - }; - // TODO: do this after constructing the db + cached_registrations: Arc::new(RwLock::new(HashMap::new())), + pdu_cache: Mutex::new(LruCache::new( + config + .pdu_cache_capacity + .try_into() + .expect("pdu cache capacity fits into usize"), + )), + auth_chain_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + shorteventid_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + eventidshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + shortstatekey_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + statekeyshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + our_real_users_cache: RwLock::new(HashMap::new()), + appservice_in_room_cache: RwLock::new(HashMap::new()), + lazy_load_waiting: Mutex::new(HashMap::new()), + stateinfo_cache: Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + + }); + + let services_raw = Services::build(Arc::clone(&db)); + + // This is the first and only time we initialize the SERVICE static + *SERVICES.write().unwrap() = Some(services_raw); + // Matrix resource ownership is based on the server name; changing it // requires recreating the database from scratch. - if guard.users.count()? > 0 { + if services().users.count()? > 0 { let conduit_user = - UserId::parse_with_server_name("conduit", guard.globals.server_name()) + UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("@conduit:server_name is valid"); - if !guard.users.exists(&conduit_user)? { + if !services().users.exists(&conduit_user)? { error!( "The {} server user does not exist, and the database is not new.", conduit_user @@ -370,11 +429,10 @@ impl KeyValueDatabase { // If the database has any data, perform data migrations before starting let latest_database_version = 11; - if guard.users.count()? > 0 { - let db = &*guard; + if services().users.count()? > 0 { // MIGRATIONS - if db.globals.database_version()? < 1 { - for (roomserverid, _) in db.rooms.roomserverids.iter() { + if services().globals.database_version()? < 1 { + for (roomserverid, _) in db.roomserverids.iter() { let mut parts = roomserverid.split(|&b| b == 0xff); let room_id = parts.next().expect("split always returns one element"); let servername = match parts.next() { @@ -388,17 +446,17 @@ impl KeyValueDatabase { serverroomid.push(0xff); serverroomid.extend_from_slice(room_id); - db.rooms.serverroomids.insert(&serverroomid, &[])?; + db.serverroomids.insert(&serverroomid, &[])?; } - db.globals.bump_database_version(1)?; + services().globals.bump_database_version(1)?; warn!("Migration: 0 -> 1 finished"); } - if db.globals.database_version()? < 2 { + if services().globals.database_version()? < 2 { // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.users.userid_password.iter() { + for (userid, password) in db.userid_password.iter() { let password = utils::string_from_bytes(&password); let empty_hashed_password = password.map_or(false, |password| { @@ -406,59 +464,59 @@ impl KeyValueDatabase { }); if empty_hashed_password { - db.users.userid_password.insert(&userid, b"")?; + db.userid_password.insert(&userid, b"")?; } } - db.globals.bump_database_version(2)?; + services().globals.bump_database_version(2)?; warn!("Migration: 1 -> 2 finished"); } - if db.globals.database_version()? < 3 { + if services().globals.database_version()? < 3 { // Move media to filesystem - for (key, content) in db.media.mediaid_file.iter() { + for (key, content) in db.mediaid_file.iter() { if content.is_empty() { continue; } - let path = db.globals.get_media_file(&key); + let path = services().globals.get_media_file(&key); let mut file = fs::File::create(path)?; file.write_all(&content)?; - db.media.mediaid_file.insert(&key, &[])?; + db.mediaid_file.insert(&key, &[])?; } - db.globals.bump_database_version(3)?; + services().globals.bump_database_version(3)?; warn!("Migration: 2 -> 3 finished"); } - if db.globals.database_version()? < 4 { - // Add federated users to db as deactivated - for our_user in db.users.iter() { + if services().globals.database_version()? < 4 { + // Add federated users to services() as deactivated + for our_user in services().users.iter() { let our_user = our_user?; - if db.users.is_deactivated(&our_user)? { + if services().users.is_deactivated(&our_user)? { continue; } - for room in db.rooms.rooms_joined(&our_user) { - for user in db.rooms.room_members(&room?) { + for room in services().rooms.state_cache.rooms_joined(&our_user) { + for user in services().rooms.state_cache.room_members(&room?) { let user = user?; - if user.server_name() != db.globals.server_name() { + if user.server_name() != services().globals.server_name() { println!("Migration: Creating user {}", user); - db.users.create(&user, None)?; + services().users.create(&user, None)?; } } } } - db.globals.bump_database_version(4)?; + services().globals.bump_database_version(4)?; warn!("Migration: 3 -> 4 finished"); } - if db.globals.database_version()? < 5 { + if services().globals.database_version()? < 5 { // Upgrade user data store - for (roomuserdataid, _) in db.account_data.roomuserdataid_accountdata.iter() { + for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { let mut parts = roomuserdataid.split(|&b| b == 0xff); let room_id = parts.next().unwrap(); let user_id = parts.next().unwrap(); @@ -470,30 +528,29 @@ impl KeyValueDatabase { key.push(0xff); key.extend_from_slice(event_type); - db.account_data - .roomusertype_roomuserdataid + db.roomusertype_roomuserdataid .insert(&key, &roomuserdataid)?; } - db.globals.bump_database_version(5)?; + services().globals.bump_database_version(5)?; warn!("Migration: 4 -> 5 finished"); } - if db.globals.database_version()? < 6 { + if services().globals.database_version()? < 6 { // Set room member count - for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { + for (roomid, _) in db.roomid_shortstatehash.iter() { let string = utils::string_from_bytes(&roomid).unwrap(); let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - db.rooms.update_joined_count(room_id, &db)?; + services().rooms.state_cache.update_joined_count(room_id)?; } - db.globals.bump_database_version(6)?; + services().globals.bump_database_version(6)?; warn!("Migration: 5 -> 6 finished"); } - if db.globals.database_version()? < 7 { + if services().globals.database_version()? < 7 { // Upgrade state store let mut last_roomstates: HashMap, u64> = HashMap::new(); let mut current_sstatehash: Option = None; @@ -513,7 +570,7 @@ impl KeyValueDatabase { let states_parents = last_roomsstatehash.map_or_else( || Ok(Vec::new()), |&last_roomsstatehash| { - db.rooms.state_accessor.load_shortstatehash_info(dbg!(last_roomsstatehash)) + services().rooms.state_compressor.load_shortstatehash_info(dbg!(last_roomsstatehash)) }, )?; @@ -535,7 +592,7 @@ impl KeyValueDatabase { (current_state, HashSet::new()) }; - db.rooms.save_state_from_diff( + services().rooms.state_compressor.save_state_from_diff( dbg!(current_sstatehash), statediffnew, statediffremoved, @@ -544,7 +601,7 @@ impl KeyValueDatabase { )?; /* - let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; + let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; let state = tmp.pop().unwrap(); println!( "{}\t{}{:?}: {:?} + {:?} - {:?}", @@ -587,14 +644,13 @@ impl KeyValueDatabase { current_sstatehash = Some(sstatehash); let event_id = db - .rooms .shorteventid_eventid .get(&seventid) .unwrap() .unwrap(); let string = utils::string_from_bytes(&event_id).unwrap(); let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = db.rooms.get_pdu(event_id).unwrap().unwrap(); + let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap(); if Some(&pdu.room_id) != current_room.as_ref() { current_room = Some(pdu.room_id.clone()); @@ -615,20 +671,20 @@ impl KeyValueDatabase { )?; } - db.globals.bump_database_version(7)?; + services().globals.bump_database_version(7)?; warn!("Migration: 6 -> 7 finished"); } - if db.globals.database_version()? < 8 { + if services().globals.database_version()? < 8 { // Generate short room ids for all rooms - for (room_id, _) in db.rooms.roomid_shortstatehash.iter() { - let shortroomid = db.globals.next_count()?.to_be_bytes(); - db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?; + for (room_id, _) in db.roomid_shortstatehash.iter() { + let shortroomid = services().globals.next_count()?.to_be_bytes(); + db.roomid_shortroomid.insert(&room_id, &shortroomid)?; info!("Migration: 8"); } // Update pduids db layout - let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| { + let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { if !key.starts_with(b"!") { return None; } @@ -637,7 +693,6 @@ impl KeyValueDatabase { let count = parts.next().unwrap(); let short_room_id = db - .rooms .roomid_shortroomid .get(room_id) .unwrap() @@ -649,9 +704,9 @@ impl KeyValueDatabase { Some((new_key, v)) }); - db.rooms.pduid_pdu.insert_batch(&mut batch)?; + db.pduid_pdu.insert_batch(&mut batch)?; - let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { + let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { if !value.starts_with(b"!") { return None; } @@ -660,7 +715,6 @@ impl KeyValueDatabase { let count = parts.next().unwrap(); let short_room_id = db - .rooms .roomid_shortroomid .get(room_id) .unwrap() @@ -672,17 +726,16 @@ impl KeyValueDatabase { Some((k, new_value)) }); - db.rooms.eventid_pduid.insert_batch(&mut batch2)?; + db.eventid_pduid.insert_batch(&mut batch2)?; - db.globals.bump_database_version(8)?; + services().globals.bump_database_version(8)?; warn!("Migration: 7 -> 8 finished"); } - if db.globals.database_version()? < 9 { + if services().globals.database_version()? < 9 { // Update tokenids db layout let mut iter = db - .rooms .tokenids .iter() .filter_map(|(key, _)| { @@ -696,7 +749,6 @@ impl KeyValueDatabase { let pdu_id_count = parts.next().unwrap(); let short_room_id = db - .rooms .roomid_shortroomid .get(room_id) .unwrap() @@ -712,8 +764,7 @@ impl KeyValueDatabase { .peekable(); while iter.peek().is_some() { - db.rooms - .tokenids + db.tokenids .insert_batch(&mut iter.by_ref().take(1000))?; println!("smaller batch done"); } @@ -721,7 +772,6 @@ impl KeyValueDatabase { info!("Deleting starts"); let batch2: Vec<_> = db - .rooms .tokenids .iter() .filter_map(|(key, _)| { @@ -736,38 +786,37 @@ impl KeyValueDatabase { for key in batch2 { println!("del"); - db.rooms.tokenids.remove(&key)?; + db.tokenids.remove(&key)?; } - db.globals.bump_database_version(9)?; + services().globals.bump_database_version(9)?; warn!("Migration: 8 -> 9 finished"); } - if db.globals.database_version()? < 10 { + if services().globals.database_version()? < 10 { // Add other direction for shortstatekeys - for (statekey, shortstatekey) in db.rooms.statekey_shortstatekey.iter() { - db.rooms - .shortstatekey_statekey + for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { + db.shortstatekey_statekey .insert(&shortstatekey, &statekey)?; } // Force E2EE device list updates so we can send them over federation - for user_id in db.users.iter().filter_map(|r| r.ok()) { - db.users - .mark_device_key_update(&user_id, &db.rooms, &db.globals)?; + for user_id in services().users.iter().filter_map(|r| r.ok()) { + services().users + .mark_device_key_update(&user_id)?; } - db.globals.bump_database_version(10)?; + services().globals.bump_database_version(10)?; warn!("Migration: 9 -> 10 finished"); } - if db.globals.database_version()? < 11 { + if services().globals.database_version()? < 11 { db._db .open_tree("userdevicesessionid_uiaarequest")? .clear()?; - db.globals.bump_database_version(11)?; + services().globals.bump_database_version(11)?; warn!("Migration: 10 -> 11 finished"); } @@ -779,12 +828,12 @@ impl KeyValueDatabase { config.database_backend, latest_database_version ); } else { - guard + services() .globals .bump_database_version(latest_database_version)?; // Create the admin room and server user on first run - create_admin_room().await?; + services().admin.create_admin_room().await?; warn!( "Created new {} database with version {}", @@ -793,16 +842,16 @@ impl KeyValueDatabase { } // This data is probably outdated - guard.rooms.edus.presenceid_presence.clear()?; + db.presenceid_presence.clear()?; - guard.admin.start_handler(Arc::clone(&db), admin_receiver); + services().admin.start_handler(admin_receiver); // Set emergency access for the conduit user - match set_emergency_access(&guard) { + match set_emergency_access() { Ok(pwd_set) => { if pwd_set { warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); - guard.admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!")); + services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!")); } } Err(e) => { @@ -813,21 +862,19 @@ impl KeyValueDatabase { } }; - guard + services() .sending - .start_handler(Arc::clone(&db), sending_receiver); + .start_handler(sending_receiver); - drop(guard); + Self::start_cleanup_task(config).await; - Self::start_cleanup_task(Arc::clone(&db), config).await; - - Ok(db) + Ok(()) } #[cfg(feature = "conduit_bin")] - pub async fn on_shutdown(db: Arc>) { + pub async fn on_shutdown() { info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - db.read().await.globals.rotate.fire(); + services().globals.rotate.fire(); } pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { @@ -844,33 +891,30 @@ impl KeyValueDatabase { // Return when *any* user changed his key // TODO: only send for user they share a room with futures.push( - self.users - .todeviceid_events + self.todeviceid_events .watch_prefix(&userdeviceid_prefix), ); - futures.push(self.rooms.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); futures.push( - self.rooms - .userroomid_invitestate + self.userroomid_invitestate .watch_prefix(&userid_prefix), ); - futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); futures.push( - self.rooms - .userroomid_notificationcount + self.userroomid_notificationcount .watch_prefix(&userid_prefix), ); futures.push( - self.rooms - .userroomid_highlightcount + self.userroomid_highlightcount .watch_prefix(&userid_prefix), ); // Events for rooms we are in - for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { - let short_roomid = self + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { + let short_roomid = services() .rooms + .short .get_shortroomid(&room_id) .ok() .flatten() @@ -883,33 +927,28 @@ impl KeyValueDatabase { roomid_prefix.push(0xff); // PDUs - futures.push(self.rooms.pduid_pdu.watch_prefix(&short_roomid)); + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); // EDUs futures.push( - self.rooms - .edus - .roomid_lasttypingupdate + self.roomid_lasttypingupdate .watch_prefix(&roomid_bytes), ); futures.push( - self.rooms - .edus - .readreceiptid_readreceipt + self.readreceiptid_readreceipt .watch_prefix(&roomid_prefix), ); // Key changes - futures.push(self.users.keychangeid_userid.watch_prefix(&roomid_prefix)); + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); // Room account data let mut roomuser_prefix = roomid_prefix.clone(); roomuser_prefix.extend_from_slice(&userid_prefix); futures.push( - self.account_data - .roomusertype_roomuserdataid + self.roomusertype_roomuserdataid .watch_prefix(&roomuser_prefix), ); } @@ -918,22 +957,20 @@ impl KeyValueDatabase { globaluserdata_prefix.extend_from_slice(&userid_prefix); futures.push( - self.account_data - .roomusertype_roomuserdataid + self.roomusertype_roomuserdataid .watch_prefix(&globaluserdata_prefix), ); // More key changes (used when user is not joined to any rooms) - futures.push(self.users.keychangeid_userid.watch_prefix(&userid_prefix)); + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); // One time keys futures.push( - self.users - .userid_lastonetimekeyupdate + self.userid_lastonetimekeyupdate .watch_prefix(&userid_bytes), ); - futures.push(Box::pin(self.globals.rotate.watch())); + futures.push(Box::pin(services().globals.rotate.watch())); // Wait until one of them finds something futures.next().await; @@ -950,8 +987,8 @@ impl KeyValueDatabase { res } - #[tracing::instrument(skip(db, config))] - pub async fn start_cleanup_task(db: Arc>, config: &Config) { + #[tracing::instrument(skip(config))] + pub async fn start_cleanup_task(config: &Config) { use tokio::time::interval; #[cfg(unix)] @@ -984,7 +1021,7 @@ impl KeyValueDatabase { } let start = Instant::now(); - if let Err(e) = db.read().await._db.cleanup() { + if let Err(e) = services().globals.db._db.cleanup() { error!("cleanup: Errored: {}", e); } else { info!("cleanup: Finished in {:?}", start.elapsed()); @@ -995,26 +1032,25 @@ impl KeyValueDatabase { } /// Sets the emergency password and push rules for the @conduit account in case emergency password is set -fn set_emergency_access(db: &KeyValueDatabase) -> Result { - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) +fn set_emergency_access() -> Result { + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("@conduit:server_name is a valid UserId"); - db.users - .set_password(&conduit_user, db.globals.emergency_password().as_deref())?; + services().users + .set_password(&conduit_user, services().globals.emergency_password().as_deref())?; - let (ruleset, res) = match db.globals.emergency_password() { + let (ruleset, res) = match services().globals.emergency_password() { Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), None => (Ruleset::new(), Ok(false)), }; - db.account_data.update( + services().account_data.update( None, &conduit_user, GlobalAccountDataEventType::PushRules.to_string().into(), &GlobalAccountDataEvent { content: PushRulesEventContent { global: ruleset }, }, - &db.globals, )?; res diff --git a/src/lib.rs b/src/lib.rs index 72399003..75cf6c7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,22 +13,16 @@ mod service; pub mod api; mod utils; -use std::{cell::Cell, sync::RwLock}; +use std::{cell::Cell, sync::{RwLock, Arc}}; pub use config::Config; pub use utils::error::{Error, Result}; pub use service::{Services, pdu::PduEvent}; pub use api::ruma_wrapper::{Ruma, RumaResponse}; -use crate::database::KeyValueDatabase; +pub static SERVICES: RwLock>> = RwLock::new(None); -pub static SERVICES: RwLock> = RwLock::new(None); - -enum ServicesEnum { - Rocksdb(Services) -} - -pub fn services<'a>() -> &'a Services { - &SERVICES.read().unwrap() +pub fn services<'a>() -> Arc { + Arc::clone(&SERVICES.read().unwrap()) } diff --git a/src/main.rs b/src/main.rs index 543b953e..d5b2731e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -69,19 +69,14 @@ async fn main() { config.warn_deprecated(); - let db = match KeyValueDatabase::load_or_create(&config).await { - Ok(db) => db, - Err(e) => { - eprintln!( - "The database couldn't be loaded or created. The following error occured: {}", - e - ); - std::process::exit(1); - } + if let Err(e) = KeyValueDatabase::load_or_create(&config).await { + eprintln!( + "The database couldn't be loaded or created. The following error occured: {}", + e + ); + std::process::exit(1); }; - SERVICES.set(db).expect("this is the first and only time we initialize the SERVICE static"); - let start = async { run_server().await.unwrap(); }; diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index c56c69d2..35ca1495 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -17,11 +17,11 @@ use tracing::error; use crate::{service::*, services, utils, Error, Result}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] pub fn update( diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 63fa3afe..1a5ce50c 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -3,11 +3,11 @@ pub use data::Data; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Registers an appservice and returns the ID to the caller pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { self.db.register_appservice(yaml) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 6cfeab81..48d7b064 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -36,8 +36,8 @@ type SyncHandle = ( Receiver>>, // rx ); -pub struct Service { - pub db: D, +pub struct Service { + pub db: Box, pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, @@ -92,9 +92,9 @@ impl Default for RotationHandler { } -impl Service { +impl Service { pub fn load( - db: D, + db: Box, config: Config, ) -> Result { let keypair = db.load_keypair(); diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index ce867fb5..4bd9efd3 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -12,11 +12,11 @@ use ruma::{ }; use std::{collections::BTreeMap, sync::Arc}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { pub fn create_backup( &self, user_id: &UserId, diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 5037809c..d61292bb 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -15,11 +15,11 @@ pub struct FileMeta { pub file: Vec, } -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Uploads a file. pub async fn create( &self, diff --git a/src/service/mod.rs b/src/service/mod.rs index 4364c72e..47d4651d 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + pub mod account_data; pub mod admin; pub mod appservice; @@ -12,18 +14,36 @@ pub mod transaction_ids; pub mod uiaa; pub mod users; -pub struct Services -{ - pub appservice: appservice::Service, - pub pusher: pusher::Service, - pub rooms: rooms::Service, - pub transaction_ids: transaction_ids::Service, - pub uiaa: uiaa::Service, - pub users: users::Service, - pub account_data: account_data::Service, +pub struct Services { + pub appservice: appservice::Service, + pub pusher: pusher::Service, + pub rooms: rooms::Service, + pub transaction_ids: transaction_ids::Service, + pub uiaa: uiaa::Service, + pub users: users::Service, + pub account_data: account_data::Service, pub admin: admin::Service, - pub globals: globals::Service, - pub key_backups: key_backups::Service, - pub media: media::Service, + pub globals: globals::Service, + pub key_backups: key_backups::Service, + pub media: media::Service, pub sending: sending::Service, } + +impl Services { + pub fn build(db: Arc) { + Self { + appservice: appservice::Service { db: Arc::clone(&db) }, + pusher: appservice::Service { db: Arc::clone(&db) }, + rooms: appservice::Service { db: Arc::clone(&db) }, + transaction_ids: appservice::Service { db: Arc::clone(&db) }, + uiaa: appservice::Service { db: Arc::clone(&db) }, + users: appservice::Service { db: Arc::clone(&db) }, + account_data: appservice::Service { db: Arc::clone(&db) }, + admin: appservice::Service { db: Arc::clone(&db) }, + globals: appservice::Service { db: Arc::clone(&db) }, + key_backups: appservice::Service { db: Arc::clone(&db) }, + media: appservice::Service { db: Arc::clone(&db) }, + sending: appservice::Service { db: Arc::clone(&db) }, + } + } +} diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 64c7f1fa..af30ca47 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -23,11 +23,11 @@ use ruma::{ use std::{fmt::Debug, mem}; use tracing::{error, info, warn}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { self.db.set_pusher(sender, pusher) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index c5d45e36..81022096 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -25,5 +25,5 @@ pub trait Data { fn local_aliases_for_room( &self, room_id: &RoomId, - ) -> Result>>; + ) -> Box>>>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index abe299d4..ef5888fc 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -4,11 +4,11 @@ pub use data::Data; use ruma::{RoomAliasId, RoomId}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { #[tracing::instrument(skip(self))] pub fn set_alias( &self, diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 5177d6d6..e4e8550b 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -2,6 +2,6 @@ use std::collections::HashSet; use crate::Result; pub trait Data { - fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>; + fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result>>; fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 9ea4763e..26a3f3f0 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -5,11 +5,11 @@ pub use data::Data; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { #[tracing::instrument(skip(self))] pub fn get_cached_eventid_authchain<'a>( &'a self, diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 68535057..fb289941 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -4,11 +4,11 @@ use ruma::RoomId; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { #[tracing::instrument(skip(self))] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index dbe1b6e8..8552363e 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -4,8 +4,8 @@ pub mod typing; pub trait Data: presence::Data + read_receipt::Data + typing::Data {} -pub struct Service { - pub presence: presence::Service, - pub read_receipt: read_receipt::Service, - pub typing: typing::Service, +pub struct Service { + pub presence: presence::Service, + pub read_receipt: read_receipt::Service, + pub typing: typing::Service, } diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 646cf549..73b7b5a5 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -6,11 +6,11 @@ use ruma::{RoomId, UserId, events::presence::PresenceEvent}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Adds a presence event which will be saved until a new event replaces it. /// /// Note: This method takes a RoomId because presence updates are always bound to rooms to diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index 3f0b1476..2a4c0b7f 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -4,11 +4,11 @@ pub use data::Data; use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Replaces the previous read receipt. pub fn readreceipt_update( &self, diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 00cfdecb..16a135f8 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -4,11 +4,11 @@ use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is /// called. pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 8a8725b8..e2291126 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,14 +1,16 @@ /// An async function that can recursively call itself. type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; +use ruma::{RoomVersionId, signatures::CanonicalJsonObject, api::federation::discovery::{get_server_keys, get_remote_server_keys}}; +use tokio::sync::Semaphore; use std::{ collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet}, pin::Pin, - sync::{Arc, RwLock}, - time::{Duration, Instant}, + sync::{Arc, RwLock, RwLockWriteGuard}, + time::{Duration, Instant, SystemTime}, }; -use futures_util::{Future, stream::FuturesUnordered}; +use futures_util::{Future, stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ client::error::ErrorKind, @@ -22,7 +24,7 @@ use ruma::{ uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, ServerSigningKeyId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use tracing::{error, info, trace, warn}; +use tracing::{error, info, trace, warn, debug}; use crate::{service::*, services, Result, Error, PduEvent}; @@ -53,7 +55,7 @@ impl Service { /// it /// 14. Use state resolution to find new room state // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively - #[tracing::instrument(skip(value, is_timeline_event, pub_key_map))] + #[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))] pub(crate) async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, @@ -64,10 +66,11 @@ impl Service { pub_key_map: &'a RwLock>>, ) -> Result>> { if !services().rooms.metadata.exists(room_id)? { - return Error::BadRequest( + return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server", - )}; + )); + } services() .rooms @@ -732,7 +735,7 @@ impl Service { &incoming_pdu.sender, incoming_pdu.state_key.as_deref(), &incoming_pdu.content, - )? + )?; let soft_fail = !state_res::event_auth::auth_check( &room_version, @@ -821,7 +824,7 @@ impl Service { let shortstatekey = services() .rooms .short - .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)? + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); } @@ -1236,7 +1239,7 @@ impl Service { let signature_ids = signature_object.keys().cloned().collect::>(); - let fetch_res = fetch_signing_keys( + let fetch_res = self.fetch_signing_keys( signature_server.as_str().try_into().map_err(|_| { Error::BadServerResponse("Invalid servername in signatures of server response pdu.") })?, @@ -1481,4 +1484,168 @@ impl Service { )) } } + + /// Search the DB for the signing keys of the given server, if we don't have them + /// fetch them from the server and save to our DB. + #[tracing::instrument(skip_all)] + pub async fn fetch_signing_keys( + &self, + origin: &ServerName, + signature_ids: Vec, + ) -> Result> { + let contains_all_ids = + |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + + let permit = services() + .globals + .servername_ratelimiter + .read() + .unwrap() + .get(origin) + .map(|s| Arc::clone(s).acquire_owned()); + + let permit = match permit { + Some(p) => p, + None => { + let mut write = services().globals.servername_ratelimiter.write().unwrap(); + let s = Arc::clone( + write + .entry(origin.to_owned()) + .or_insert_with(|| Arc::new(Semaphore::new(1))), + ); + + s.acquire_owned() + } + } + .await; + + let back_off = |id| match services() + .globals + .bad_signature_ratelimiter + .write() + .unwrap() + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + if let Some((time, tries)) = services() + .globals + .bad_signature_ratelimiter + .read() + .unwrap() + .get(&signature_ids) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {:?}", signature_ids); + return Err(Error::BadServerResponse("bad signature, still backing off")); + } + } + + trace!("Loading signing keys for {}", origin); + + let mut result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + if contains_all_ids(&result) { + return Ok(result); + } + + debug!("Fetching signing keys for {} over federation", origin); + + if let Some(server_key) = services() + .sending + .send_federation_request(origin, get_server_keys::v2::Request::new()) + .await + .ok() + .and_then(|resp| resp.server_key.deserialize().ok()) + { + services().globals.add_signing_key(origin, server_key.clone())?; + + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + + if contains_all_ids(&result) { + return Ok(result); + } + } + + for server in services().globals.trusted_servers() { + debug!("Asking {} for {}'s signing key", server, origin); + if let Some(server_keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys::v2::Request::new( + origin, + MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime to large"), + ) + .expect("time is valid"), + ), + ) + .await + .ok() + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::>() + }) + { + trace!("Got signing keys: {:?}", server_keys); + for k in server_keys { + services().globals.add_signing_key(origin, k.clone())?; + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + } + + if contains_all_ids(&result) { + return Ok(result); + } + } + } + + drop(permit); + + back_off(signature_ids); + + warn!("Failed to find public key for server: {}", origin); + Err(Error::BadServerResponse( + "Failed to find public key for server", + )) + } } diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 5fefd3f8..f1019c13 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -15,7 +15,7 @@ pub trait Data { user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - since: u64, + confirmed_user_ids: &mut dyn Iterator, ) -> Result<()>; fn lazy_load_reset( diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 283d45af..90dad21c 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,16 +1,18 @@ mod data; -use std::collections::HashSet; +use std::{collections::{HashSet, HashMap}, sync::Mutex}; pub use data::Data; use ruma::{DeviceId, UserId, RoomId}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, + + lazy_load_waiting: Mutex, Box, Box, u64), HashSet>>>, } -impl Service { +impl Service { #[tracing::instrument(skip(self))] pub fn lazy_load_was_sent_before( &self, @@ -50,7 +52,18 @@ impl Service { room_id: &RoomId, since: u64, ) -> Result<()> { - self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, since) + if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + since, + )) { + self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|&u| &*u))?; + } else { + // Ignore + } + + Ok(()) } #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 1bdb78d6..3c21dd19 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -4,11 +4,11 @@ use ruma::RoomId; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Checks if a room exists. #[tracing::instrument(skip(self))] pub fn exists(&self, room_id: &RoomId) -> Result { diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 4da42236..f1b0badf 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -18,22 +18,22 @@ pub mod user; pub trait Data: alias::Data + auth_chain::Data + directory::Data + edus::Data + lazy_loading::Data + metadata::Data + outlier::Data + pdu_metadata::Data + search::Data + short::Data + state::Data + state_accessor::Data + state_cache::Data + state_compressor::Data + timeline::Data + user::Data {} -pub struct Service { - pub alias: alias::Service, - pub auth_chain: auth_chain::Service, - pub directory: directory::Service, - pub edus: edus::Service, +pub struct Service { + pub alias: alias::Service, + pub auth_chain: auth_chain::Service, + pub directory: directory::Service, + pub edus: edus::Service, pub event_handler: event_handler::Service, - pub lazy_loading: lazy_loading::Service, - pub metadata: metadata::Service, - pub outlier: outlier::Service, - pub pdu_metadata: pdu_metadata::Service, - pub search: search::Service, - pub short: short::Service, - pub state: state::Service, - pub state_accessor: state_accessor::Service, - pub state_cache: state_cache::Service, - pub state_compressor: state_compressor::Service, - pub timeline: timeline::Service, - pub user: user::Service, + pub lazy_loading: lazy_loading::Service, + pub metadata: metadata::Service, + pub outlier: outlier::Service, + pub pdu_metadata: pdu_metadata::Service, + pub search: search::Service, + pub short: short::Service, + pub state: state::Service, + pub state_accessor: state_accessor::Service, + pub state_cache: state_cache::Service, + pub state_compressor: state_compressor::Service, + pub timeline: timeline::Service, + pub user: user::Service, } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index a495db8f..5493ce48 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -4,11 +4,11 @@ use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{Result, PduEvent}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Returns the pdu from the outlier tree. pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu_json(event_id) diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index c57c1a28..a81d05c1 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -6,11 +6,11 @@ use ruma::{RoomId, EventId}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { #[tracing::instrument(skip(self, room_id, event_ids))] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { self.db.mark_as_referenced(room_id, event_ids) diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index c0fd2a37..b62904c1 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -2,7 +2,7 @@ use ruma::RoomId; use crate::Result; pub trait Data { - fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: u64, message_body: String) -> Result<()>; + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()>; fn search_pdus<'a>( &'a self, diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index b7023f32..dc571910 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -4,11 +4,16 @@ pub use data::Data; use crate::Result; use ruma::RoomId; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { + #[tracing::instrument(skip(self))] + pub fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { + self.db.index_pdu(shortroomid, pdu_id, message_body) + } + #[tracing::instrument(skip(self))] pub fn search_pdus<'a>( &'a self, diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 1eb891e6..a024dc67 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -6,11 +6,11 @@ use ruma::{EventId, events::StateEventType, RoomId}; use crate::{Result, Error, utils, services}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { pub fn get_or_create_shorteventid( &self, event_id: &EventId, diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index fd0de282..7008d86f 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,6 +1,5 @@ use std::sync::Arc; use std::{sync::MutexGuard, collections::HashSet}; -use std::fmt::Debug; use crate::Result; use ruma::{EventId, RoomId}; @@ -22,7 +21,7 @@ pub trait Data { /// Replace the forward extremities of the room. fn set_forward_extremities<'a>(&self, room_id: &RoomId, - event_ids: impl IntoIterator + Debug, + event_ids: &dyn Iterator, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()>; } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index a26ed46b..979060d9 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -10,11 +10,11 @@ use crate::{Result, services, PduEvent, Error, utils::calculate_hash}; use super::state_compressor::CompressedStateEvent; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Set the room to the given statehash and update caches. pub fn force_state( &self, @@ -23,6 +23,15 @@ impl Service { statediffnew: HashSet, statediffremoved: HashSet, ) -> Result<()> { + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; for event_id in statediffnew.into_iter().filter_map(|new| { services().rooms.state_compressor.parse_compressed_state_event(new) @@ -70,7 +79,9 @@ impl Service { services().room.state_cache.update_joined_count(room_id)?; - self.db.set_room_state(room_id, shortstatehash); + self.db.set_room_state(room_id, shortstatehash, &state_lock); + + drop(state_lock); Ok(()) } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 5d6886d9..1911e52f 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -6,11 +6,11 @@ use ruma::{events::StateEventType, RoomId, EventId}; use crate::{Result, PduEvent}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c3b4eb91..18d1123e 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -7,11 +7,11 @@ use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::Room use crate::{Result, services, utils, Error}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] pub fn update_membership( diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 17689364..cd872422 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,10 +1,12 @@ +use std::collections::HashSet; + use super::CompressedStateEvent; use crate::Result; pub struct StateDiff { - parent: Option, - added: Vec, - removed: Vec, + pub parent: Option, + pub added: HashSet, + pub removed: HashSet, } pub trait Data { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 619e4cf5..ab9f4275 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -8,13 +8,13 @@ use crate::{Result, utils, services}; use self::data::StateDiff; -pub struct Service { - db: D, +pub struct Service { + db: Box, } pub type CompressedStateEvent = [u8; 2 * size_of::()]; -impl Service { +impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. #[tracing::instrument(skip(self))] pub fn load_shortstatehash_info( diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 7669b0b3..e8f42053 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -20,11 +20,11 @@ use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduE use super::state_compressor::CompressedStateEvent; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /* /// Checks if a room exists. #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 729887c3..7c7dfae6 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -4,11 +4,11 @@ use ruma::{RoomId, UserId}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.reset_notification_counts(user_id, room_id) } diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index ea923722..a9c516cf 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -4,11 +4,11 @@ pub use data::Data; use ruma::{UserId, DeviceId, TransactionId}; use crate::Result; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { pub fn add_txnid( &self, user_id: &UserId, diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index ffdbf356..01c0d2f6 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -6,11 +6,11 @@ use tracing::error; use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 3f87589c..7eb0cebd 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::Result; use ruma::{UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, DeviceKeys, CrossSigningKey}, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}, MxcUri}; -pub trait Data { +pub trait Data: Send + Sync { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result; @@ -138,16 +138,16 @@ pub trait Data { device_id: &DeviceId, ) -> Result>>; - fn get_master_key bool>( + fn get_master_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>>; - fn get_self_signing_key bool>( + fn get_self_signing_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>>; fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index dfe6c7fb..8adc9366 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -6,11 +6,11 @@ use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTi use crate::{Result, Error, services}; -pub struct Service { - db: D, +pub struct Service { + db: Box, } -impl Service { +impl Service { /// Check if a user has an account on this homeserver. pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id)