diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 8ccaa89c..0aae9959 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -670,7 +670,7 @@ async fn join_room_by_id_helper( .add_pdu_outlier(&event_id, &value)?; } - let shortstatehash = services().rooms.state.set_event_state( + let statehash_before_join = services().rooms.state.set_event_state( event_id, room_id, state @@ -684,8 +684,15 @@ async fn join_room_by_id_helper( .collect::>()?, )?; + services() + .rooms + .state + .set_room_state(room_id, statehash_before_join, &state_lock)?; + // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. + let statehash_after_join = services().rooms.state.append_to_state(&parsed_pdu)?; + services().rooms.timeline.append_pdu( &parsed_pdu, join_event, @@ -698,9 +705,7 @@ async fn join_room_by_id_helper( services() .rooms .state - .set_room_state(room_id, shortstatehash, &state_lock)?; - - let statehashid = services().rooms.state.append_to_state(&parsed_pdu)?; + .set_room_state(room_id, statehash_after_join, &state_lock)?; } else { let event = RoomMemberEventContent { membership: MembershipState::Join, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index c832b0d4..bcf893c6 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1319,7 +1319,7 @@ pub async fn create_join_event_template_route( }) .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + let (_pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( PduBuilder { event_type: RoomEventType::RoomMember, content, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index de96ace1..6abe5ba5 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -43,8 +43,8 @@ impl service::media::Data for KeyValueDatabase { ) -> Result<(Option, Option, Vec)> { let mut prefix = mxc.as_bytes().to_vec(); prefix.push(0xff); - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail - prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail + prefix.extend_from_slice(&width.to_be_bytes()); + prefix.extend_from_slice(&height.to_be_bytes()); prefix.push(0xff); let (key, _) = self diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index efb85509..c4496af8 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -7,7 +7,7 @@ mod media; //mod pdu; mod pusher; mod rooms; -//mod sending; +mod sending; mod transaction_ids; mod uiaa; mod users; diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 15f4e26e..1468a553 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -3,7 +3,7 @@ use ruma::{ UserId, }; -use crate::{database::KeyValueDatabase, service, Error, Result}; +use crate::{database::KeyValueDatabase, service, Error, Result, utils}; impl service::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { @@ -28,9 +28,13 @@ impl service::pusher::Data for KeyValueDatabase { Ok(()) } - fn get_pusher(&self, senderkey: &[u8]) -> Result> { + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + let mut senderkey = sender.as_bytes().to_vec(); + senderkey.push(0xff); + senderkey.extend_from_slice(pushkey.as_bytes()); + self.senderkey_pusher - .get(senderkey)? + .get(&senderkey)? .map(|push| { serde_json::from_slice(&*push) .map_err(|_| Error::bad_database("Invalid Pusher in db.")) @@ -51,10 +55,17 @@ impl service::pusher::Data for KeyValueDatabase { .collect() } - fn get_pusher_senderkeys<'a>(&'a self, sender: &UserId) -> Box>> { + fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)) + Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { + let mut parts = k.splitn(2, |&b| b == 0xff); + let _senderkey = parts.next(); + let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; + let push_key_string = utils::string_from_bytes(push_key).map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; + + Ok(push_key_string) + })) } } diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 112d6eb0..f3de89da 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -43,10 +43,10 @@ impl service::rooms::alias::Data for KeyValueDatabase { .transpose() } - fn local_aliases_for_room( - &self, + fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, - ) -> Box>>> { + ) -> Box>> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 661c202d..212ced91 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -15,7 +15,7 @@ impl service::rooms::directory::Data for KeyValueDatabase { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - fn public_rooms(&self) -> Box>>> { + fn public_rooms<'a>(&'a self) -> Box>> + 'a> { Box::new(self.publicroomids.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes).map_err(|_| { diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index c78f0f51..19c1ced7 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -59,7 +59,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { u64, Raw, )>, - >, + > + 'a, > { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 63a6b1aa..2ec18bed 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,6 +1,6 @@ use ruma::RoomId; -use crate::{database::KeyValueDatabase, service, services, Result}; +use crate::{database::KeyValueDatabase, service, services, Result, utils, Error}; impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { @@ -18,6 +18,18 @@ impl service::rooms::metadata::Data for KeyValueDatabase { .is_some()) } + fn iter_ids<'a>(&'a self) -> Box>> + 'a> { + Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Room ID in publicroomids is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + })) + + } + fn is_disabled(&self, room_id: &RoomId) -> Result { Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) } diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 79e6a326..8aa7a639 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -5,7 +5,7 @@ use ruma::RoomId; use crate::{database::KeyValueDatabase, service, services, utils, Result}; impl service::rooms::search::Data for KeyValueDatabase { - fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) @@ -26,7 +26,7 @@ impl service::rooms::search::Data for KeyValueDatabase { &'a self, room_id: &RoomId, search_string: &str, - ) -> Result>>, Vec)>> { + ) -> Result>+ 'a>, Vec)>> { let prefix = services() .rooms .short diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 5d684a1b..1660a9ec 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -235,7 +235,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result, PduEvent)>>>> { + ) -> Result, PduEvent)>> + 'a>> { let prefix = services() .rooms .short @@ -272,7 +272,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> Result, PduEvent)>>>> { + ) -> Result, PduEvent)>> + 'a>> { // Create the first part of the full pdu id let prefix = services() .rooms @@ -309,7 +309,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>>>> { + ) -> Result, PduEvent)>> + 'a>> { // Create the first part of the full pdu id let prefix = services() .rooms @@ -347,8 +347,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase { notifies: Vec>, highlights: Vec>, ) -> Result<()> { - let notifies_batch = Vec::new(); - let highlights_batch = Vec::new(); + let mut notifies_batch = Vec::new(); + let mut highlights_batch = Vec::new(); for user in notifies { let mut userroom_id = user.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 78c78e19..9230e611 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -86,7 +86,7 @@ impl service::rooms::user::Data for KeyValueDatabase { fn get_shared_rooms<'a>( &'a self, users: Vec>, - ) -> Result>>>> { + ) -> Result>> + 'a>> { let iterators = users.into_iter().map(move |user_id| { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs new file mode 100644 index 00000000..d84bd494 --- /dev/null +++ b/src/database/key_value/sending.rs @@ -0,0 +1,203 @@ +use ruma::{ServerName, UserId}; + +use crate::{ + database::KeyValueDatabase, + service::{ + self, + sending::{OutgoingKind, SendingEventType}, + }, + utils, Error, Result, +}; + +impl service::sending::Data for KeyValueDatabase { + fn active_requests<'a>( + &'a self, + ) -> Box, OutgoingKind, SendingEventType)>> + 'a> { + Box::new( + self.servercurrentevent_data + .iter() + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), + ) + } + + fn active_requests_for<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box, SendingEventType)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + Box::new( + self.servercurrentevent_data + .scan_prefix(prefix) + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), + ) + } + + fn delete_active_request(&self, key: Vec) -> Result<()> { + self.servercurrentevent_data.remove(&key) + } + + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { + self.servercurrentevent_data.remove(&key)?; + } + + Ok(()) + } + + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { + self.servercurrentevent_data.remove(&key).unwrap(); + } + + for (key, _) in self.servernameevent_data.scan_prefix(prefix.clone()) { + self.servernameevent_data.remove(&key).unwrap(); + } + + Ok(()) + } + + fn queue_requests( + &self, + requests: &[(&OutgoingKind, SendingEventType)], + ) -> Result>> { + let mut batch = Vec::new(); + let mut keys = Vec::new(); + for (outgoing_kind, event) in requests { + let mut key = outgoing_kind.get_prefix(); + key.push(0xff); + key.extend_from_slice(if let SendingEventType::Pdu(value) = &event { + &**value + } else { + &[] + }); + let value = if let SendingEventType::Edu(value) = &event { + &**value + } else { + &[] + }; + batch.push((key.clone(), value.to_owned())); + keys.push(key); + } + self.servernameevent_data + .insert_batch(&mut batch.into_iter())?; + Ok(keys) + } + + fn queued_requests<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + return Box::new( + self.servernameevent_data + .scan_prefix(prefix.clone()) + .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), + ); + } + + fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()> { + for (e, key) in events { + let value = if let SendingEventType::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value)?; + self.servernameevent_data.remove(key)?; + } + + Ok(()) + } + + fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + self.servername_educount + .insert(server_name.as_bytes(), &last_count.to_be_bytes()) + } + + fn get_latest_educount(&self, server_name: &ServerName) -> Result { + self.servername_educount + .get(server_name.as_bytes())? + .map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + }) + } +} + +#[tracing::instrument(skip(key))] +fn parse_servercurrentevent( + key: &[u8], + value: Vec, +) -> Result<(OutgoingKind, SendingEventType)> { + // Appservices start with a plus + Ok::<_, Error>(if key.starts_with(b"+") { + let mut parts = key[1..].splitn(2, |&b| b == 0xff); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let server = utils::string_from_bytes(server).map_err(|_| { + Error::bad_database("Invalid server bytes in server_currenttransaction") + })?; + + ( + OutgoingKind::Appservice(server), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + } else if key.starts_with(b"$") { + let mut parts = key[1..].splitn(3, |&b| b == 0xff); + + let user = parts.next().expect("splitn always returns one element"); + let user_string = utils::string_from_bytes(&user) + .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; + let user_id = UserId::parse(user_string) + .map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; + + let pushkey = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let pushkey_string = utils::string_from_bytes(pushkey) + .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; + + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + ( + OutgoingKind::Push(user_id, pushkey_string), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + // I'm pretty sure this should never be called + SendingEventType::Edu(value) + }, + ) + } else { + let mut parts = key.splitn(2, |&b| b == 0xff); + + let server = parts.next().expect("splitn always returns one element"); + let event = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let server = utils::string_from_bytes(server).map_err(|_| { + Error::bad_database("Invalid server bytes in server_currenttransaction") + })?; + + ( + OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { + Error::bad_database("Invalid server string in server_currenttransaction") + })?), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + }) +} diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 791e2498..86689f85 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -67,7 +67,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns an iterator over all users on this homeserver. - fn iter(&self) -> Box>>> { + fn iter<'a>(&'a self) -> Box>> + 'a> { Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") @@ -83,33 +83,11 @@ impl service::users::Data for KeyValueDatabase { let users: Vec = self .userid_password .iter() - .filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw)) + .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) .collect(); Ok(users) } - /// Will only return with Some(username) if the password was not empty and the - /// username could be successfully parsed. - /// If utils::string_from_bytes(...) returns an error that username will be skipped - /// and the error will be logged. - fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!( - "Failed to parse username while calling get_local_users(): {}", - e.to_string() - ); - None - } - } - } - } - /// Returns the password hash for the given user. fn password_hash(&self, user_id: &UserId) -> Result> { self.userid_password @@ -281,7 +259,7 @@ impl service::users::Data for KeyValueDatabase { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> Box>>> { + ) -> Box>> + 'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); // All devices have metadata @@ -626,7 +604,7 @@ impl service::users::Data for KeyValueDatabase { user_or_room_id: &str, from: u64, to: Option, - ) -> Box>>> { + ) -> Box>> + 'a> { let mut prefix = user_or_room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -906,7 +884,7 @@ impl service::users::Data for KeyValueDatabase { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> Box>> { + ) -> Box> + 'a> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -956,3 +934,26 @@ impl service::users::Data for KeyValueDatabase { } } } + +/// Will only return with Some(username) if the password was not empty and the +/// username could be successfully parsed. +/// If utils::string_from_bytes(...) returns an error that username will be skipped +/// and the error will be logged. +fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!( + "Failed to parse username while calling get_local_users(): {}", + e.to_string() + ); + None + } + } + } +} + diff --git a/src/database/mod.rs b/src/database/mod.rs index c4e64af8..191cd62f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -166,19 +166,6 @@ pub struct KeyValueDatabase { pub(super) shortstatekey_cache: Mutex>, pub(super) our_real_users_cache: RwLock, Arc>>>>, pub(super) appservice_in_room_cache: RwLock, HashMap>>, - pub(super) lazy_load_waiting: - Mutex, Box, Box, u64), HashSet>>>, - pub(super) stateinfo_cache: Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - HashSet, // full state - HashSet, // added - HashSet, // removed - )>, - >, - >, pub(super) lasttimelinecount_cache: Mutex, u64>>, } @@ -279,10 +266,7 @@ impl KeyValueDatabase { eprintln!("ERROR: Max request size is less than 1KB. Please increase it."); } - let (admin_sender, admin_receiver) = mpsc::unbounded_channel(); - let (sending_sender, sending_receiver) = mpsc::unbounded_channel(); - - let db = Arc::new(Self { + let db_raw = Box::new(Self { _db: builder.clone(), userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, @@ -399,14 +383,12 @@ impl KeyValueDatabase { )), our_real_users_cache: RwLock::new(HashMap::new()), appservice_in_room_cache: RwLock::new(HashMap::new()), - lazy_load_waiting: Mutex::new(HashMap::new()), - stateinfo_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), lasttimelinecount_cache: Mutex::new(HashMap::new()), }); - let services_raw = Box::new(Services::build(Arc::clone(&db), config)?); + let db = Box::leak(db_raw); + + let services_raw = Box::new(Services::build(db, config)?); // This is the first and only time we initialize the SERVICE static *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); @@ -851,8 +833,6 @@ impl KeyValueDatabase { // This data is probably outdated db.presenceid_presence.clear()?; - services().admin.start_handler(admin_receiver); - // Set emergency access for the conduit user match set_emergency_access() { Ok(pwd_set) => { @@ -869,19 +849,11 @@ impl KeyValueDatabase { } }; - services().sending.start_handler(sending_receiver); - Self::start_cleanup_task().await; Ok(()) } - #[cfg(feature = "conduit_bin")] - pub async fn on_shutdown() { - info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - services().globals.rotate.fire(); - } - #[tracing::instrument(skip(self))] pub fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); diff --git a/src/lib.rs b/src/lib.rs index 0afc75f1..9c397c08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ pub use api::ruma_wrapper::{Ruma, RumaResponse}; pub use config::Config; pub use service::{pdu::PduEvent, Services}; pub use utils::error::{Error, Result}; +pub use database::KeyValueDatabase; pub static SERVICES: RwLock> = RwLock::new(None); diff --git a/src/main.rs b/src/main.rs index d5b2731e..71eaa660 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,7 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; +use conduit::api::{client_server, server_server}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -34,7 +35,7 @@ use tower_http::{ trace::TraceLayer, ServiceBuilderExt as _, }; -use tracing::warn; +use tracing::{warn, info}; use tracing_subscriber::{prelude::*, EnvFilter}; pub use conduit::*; // Re-export everything from the library crate @@ -69,7 +70,7 @@ async fn main() { config.warn_deprecated(); - if let Err(e) = KeyValueDatabase::load_or_create(&config).await { + if let Err(e) = KeyValueDatabase::load_or_create(config).await { eprintln!( "The database couldn't be loaded or created. The following error occured: {}", e @@ -77,6 +78,8 @@ async fn main() { std::process::exit(1); }; + let config = &services().globals.config; + let start = async { run_server().await.unwrap(); }; @@ -119,7 +122,7 @@ async fn main() { } async fn run_server() -> io::Result<()> { - let config = DB.globals.config; + let config = &services().globals.config; let addr = SocketAddr::from((config.address, config.port)); let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -156,8 +159,7 @@ async fn run_server() -> io::Result<()> { header::AUTHORIZATION, ]) .max_age(Duration::from_secs(86400)), - ) - .add_extension(db.clone()); + ); let app = routes().layer(middlewares).into_make_service(); let handle = ServerHandle::new(); @@ -174,8 +176,9 @@ async fn run_server() -> io::Result<()> { } } - // After serve exits and before exiting, shutdown the DB - Database::on_shutdown(db).await; + // On shutdown + info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); + services().globals.rotate.fire(); Ok(()) } diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 60a53080..975c8203 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -13,7 +13,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 8725e674..2c776611 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -172,74 +172,82 @@ pub struct Service { } impl Service { - pub fn start_handler(&self, mut receiver: mpsc::UnboundedReceiver) { - tokio::spawn(async move { - // TODO: Use futures when we have long admin commands - //let mut futures = FuturesUnordered::new(); + pub fn build() -> Arc { + let (sender, receiver) = mpsc::unbounded_channel(); + let self1 = Arc::new(Self { sender }); + let self2 = Arc::clone(&self1); - let conduit_user = - UserId::parse(format!("@conduit:{}", services().globals.server_name())) - .expect("@conduit:server_name is valid"); + tokio::spawn(async move { self2.start_handler(receiver).await; }); - let conduit_room = services() + self1 + } + + async fn start_handler(&self, mut receiver: mpsc::UnboundedReceiver) { + // TODO: Use futures when we have long admin commands + //let mut futures = FuturesUnordered::new(); + + let conduit_user = + UserId::parse(format!("@conduit:{}", services().globals.server_name())) + .expect("@conduit:server_name is valid"); + + let conduit_room = services() + .rooms + .alias + .resolve_local_alias( + format!("#admins:{}", services().globals.server_name()) + .as_str() + .try_into() + .expect("#admins:server_name is a valid room alias"), + ) + .expect("Database data for admin room alias must be valid") + .expect("Admin room must exist"); + + let send_message = |message: RoomMessageEventContent, + mutex_lock: &MutexGuard<'_, ()>| { + services() .rooms - .alias - .resolve_local_alias( - format!("#admins:{}", services().globals.server_name()) - .as_str() - .try_into() - .expect("#admins:server_name is a valid room alias"), + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMessage, + content: to_raw_value(&message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + mutex_lock, ) - .expect("Database data for admin room alias must be valid") - .expect("Admin room must exist"); + .unwrap(); + }; - let send_message = |message: RoomMessageEventContent, - mutex_lock: &MutexGuard<'_, ()>| { - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMessage, - content: to_raw_value(&message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - mutex_lock, - ) - .unwrap(); - }; + loop { + tokio::select! { + Some(event) = receiver.recv() => { + let message_content = match event { + AdminRoomEvent::SendMessage(content) => content, + AdminRoomEvent::ProcessMessage(room_message) => self.process_admin_message(room_message).await + }; - loop { - tokio::select! { - Some(event) = receiver.recv() => { - let message_content = match event { - AdminRoomEvent::SendMessage(content) => content, - AdminRoomEvent::ProcessMessage(room_message) => self.process_admin_message(room_message).await - }; + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(conduit_room.to_owned()) + .or_default(), + ); - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .unwrap() - .entry(conduit_room.to_owned()) - .or_default(), - ); + let state_lock = mutex_state.lock().await; - let state_lock = mutex_state.lock().await; + send_message(message_content, &state_lock); - send_message(message_content, &state_lock); - - drop(state_lock); - } + drop(state_lock); } } - }); + } } pub fn process_message(&self, room_message: String) { @@ -382,9 +390,7 @@ impl Service { } } AdminCommand::ListRooms => { - todo!(); - /* - let room_ids = services().rooms.iter_ids(); + let room_ids = services().rooms.metadata.iter_ids(); let output = format!( "Rooms:\n{}", room_ids @@ -393,6 +399,7 @@ impl Service { + "\tMembers: " + &services() .rooms + .state_cache .room_joined_count(&id) .ok() .flatten() @@ -402,7 +409,6 @@ impl Service { .join("\n") ); RoomMessageEventContent::text_plain(output) - */ } AdminCommand::ListLocalUsers => match services().users.list_local_users() { Ok(users) => { @@ -648,11 +654,11 @@ impl Service { )) } AdminCommand::DisableRoom { room_id } => { - services().rooms.metadata.disable_room(&room_id, true); + services().rooms.metadata.disable_room(&room_id, true)?; RoomMessageEventContent::text_plain("Room disabled.") } AdminCommand::EnableRoom { room_id } => { - services().rooms.metadata.disable_room(&room_id, false); + services().rooms.metadata.disable_room(&room_id, false)?; RoomMessageEventContent::text_plain("Room enabled.") } AdminCommand::DeactivateUser { diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index ad5ab4aa..20ba08ad 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -6,7 +6,7 @@ pub use data::Data; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 6e03c156..477b269d 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -35,7 +35,7 @@ type SyncHandle = ( ); pub struct Service { - pub db: Arc, + pub db: &'static dyn Data, pub actual_destination_cache: Arc>, // actual_destination, host pub tls_name_override: Arc>, @@ -90,14 +90,14 @@ impl Default for RotationHandler { } impl Service { - pub fn load(db: Arc, config: Config) -> Result { + pub fn load(db: &'static dyn Data, config: Config) -> Result { let keypair = db.load_keypair(); let keypair = match keypair { Ok(k) => k, Err(e) => { error!("Keypair invalid. Deleting..."); - db.remove_keypair(); + db.remove_keypair()?; return Err(e); } }; diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 31652d2f..5d0ad599 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -12,7 +12,7 @@ use ruma::{ use std::{collections::BTreeMap, sync::Arc}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 61a733a7..29648577 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -16,7 +16,7 @@ pub struct FileMeta { } pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/mod.rs b/src/service/mod.rs index e1c6f7a4..e8696e79 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -29,11 +29,11 @@ pub struct Services { pub uiaa: uiaa::Service, pub users: users::Service, pub account_data: account_data::Service, - pub admin: admin::Service, + pub admin: Arc, pub globals: globals::Service, pub key_backups: key_backups::Service, pub media: media::Service, - pub sending: sending::Service, + pub sending: Arc, } impl Services { @@ -47,60 +47,60 @@ impl Services { + account_data::Data + globals::Data + key_backups::Data - + media::Data, + + media::Data + + sending::Data + + 'static >( - db: Arc, + db: &'static D, config: Config, ) -> Result { Ok(Self { - appservice: appservice::Service { db: db.clone() }, - pusher: pusher::Service { db: db.clone() }, + appservice: appservice::Service { db }, + pusher: pusher::Service { db }, rooms: rooms::Service { - alias: rooms::alias::Service { db: db.clone() }, - auth_chain: rooms::auth_chain::Service { db: db.clone() }, - directory: rooms::directory::Service { db: db.clone() }, + alias: rooms::alias::Service { db }, + auth_chain: rooms::auth_chain::Service { db }, + directory: rooms::directory::Service { db }, edus: rooms::edus::Service { - presence: rooms::edus::presence::Service { db: db.clone() }, - read_receipt: rooms::edus::read_receipt::Service { db: db.clone() }, - typing: rooms::edus::typing::Service { db: db.clone() }, + presence: rooms::edus::presence::Service { db }, + read_receipt: rooms::edus::read_receipt::Service { db }, + typing: rooms::edus::typing::Service { db }, }, event_handler: rooms::event_handler::Service, lazy_loading: rooms::lazy_loading::Service { - db: db.clone(), + db, lazy_load_waiting: Mutex::new(HashMap::new()), }, - metadata: rooms::metadata::Service { db: db.clone() }, - outlier: rooms::outlier::Service { db: db.clone() }, - pdu_metadata: rooms::pdu_metadata::Service { db: db.clone() }, - search: rooms::search::Service { db: db.clone() }, - short: rooms::short::Service { db: db.clone() }, - state: rooms::state::Service { db: db.clone() }, - state_accessor: rooms::state_accessor::Service { db: db.clone() }, - state_cache: rooms::state_cache::Service { db: db.clone() }, + metadata: rooms::metadata::Service { db }, + outlier: rooms::outlier::Service { db }, + pdu_metadata: rooms::pdu_metadata::Service { db }, + search: rooms::search::Service { db }, + short: rooms::short::Service { db }, + state: rooms::state::Service { db }, + state_accessor: rooms::state_accessor::Service { db }, + state_cache: rooms::state_cache::Service { db }, state_compressor: rooms::state_compressor::Service { - db: db.clone(), + db, stateinfo_cache: Mutex::new(LruCache::new( (100.0 * config.conduit_cache_capacity_modifier) as usize, )), }, timeline: rooms::timeline::Service { - db: db.clone(), + db, lasttimelinecount_cache: Mutex::new(HashMap::new()), }, - user: rooms::user::Service { db: db.clone() }, - }, - transaction_ids: transaction_ids::Service { db: db.clone() }, - uiaa: uiaa::Service { db: db.clone() }, - users: users::Service { db: db.clone() }, - account_data: account_data::Service { db: db.clone() }, - admin: admin::Service { sender: todo!() }, - globals: globals::Service::load(db.clone(), config)?, - key_backups: key_backups::Service { db: db.clone() }, - media: media::Service { db: db.clone() }, - sending: sending::Service { - maximum_requests: todo!(), - sender: todo!(), + user: rooms::user::Service { db }, }, + transaction_ids: transaction_ids::Service { db }, + uiaa: uiaa::Service { db }, + users: users::Service { db }, + account_data: account_data::Service { db }, + admin: admin::Service::build(), + key_backups: key_backups::Service { db }, + media: media::Service { db }, + sending: sending::Service::build(db, &config), + + globals: globals::Service::load(db, config)?, }) } } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 243b77f7..cb8768d8 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -7,9 +7,9 @@ use ruma::{ pub trait Data: Send + Sync { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; - fn get_pusher(&self, senderkey: &[u8]) -> Result>; + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; fn get_pushers(&self, sender: &UserId) -> Result>; - fn get_pusher_senderkeys<'a>(&'a self, sender: &UserId) -> Box>>; + fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a>; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 78d5f26c..3b12f38b 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -26,7 +26,7 @@ use std::{fmt::Debug, mem}; use tracing::{error, info, warn}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -34,19 +34,19 @@ impl Service { self.db.set_pusher(sender, pusher) } - pub fn get_pusher(&self, senderkey: &[u8]) -> Result> { - self.db.get_pusher(senderkey) + pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + self.db.get_pusher(sender, pushkey) } pub fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } - pub fn get_pusher_senderkeys<'a>( + pub fn get_pushkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a { - self.db.get_pusher_senderkeys(sender) + ) -> Box>> { + self.db.get_pushkeys(sender) } #[tracing::instrument(skip(self, destination, request))] diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 90205f93..6299add7 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -12,8 +12,8 @@ pub trait Data: Send + Sync { fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>>; /// Returns all local aliases that point to the given room - fn local_aliases_for_room( - &self, + fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, - ) -> Box>>>; + ) -> Box>> + 'a>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 6a3cf4e0..e76589ab 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -7,7 +7,7 @@ use crate::Result; use ruma::{RoomAliasId, RoomId}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -30,7 +30,7 @@ impl Service { pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, - ) -> impl Iterator>> + 'a { + ) -> Box>> + 'a> { self.db.local_aliases_for_room(room_id) } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index ed06385d..d3b6e401 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -11,7 +11,7 @@ use tracing::log::warn; use crate::{services, Error, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index fb523cf8..320c6db1 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -12,5 +12,5 @@ pub trait Data: Send + Sync { fn is_public_room(&self, room_id: &RoomId) -> Result; /// Returns the unsorted public room directory - fn public_rooms(&self) -> Box>>>; + fn public_rooms<'a>(&'a self) -> Box>> + 'a>; } diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index e85afef6..9e5e8156 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -7,7 +7,7 @@ use ruma::RoomId; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index 8552363e..cf7a3591 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -2,7 +2,7 @@ pub mod presence; pub mod read_receipt; pub mod typing; -pub trait Data: presence::Data + read_receipt::Data + typing::Data {} +pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {} pub struct Service { pub presence: presence::Service, diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 636bd910..9cce9d8c 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -7,7 +7,7 @@ use ruma::{events::presence::PresenceEvent, RoomId, UserId}; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 734c68d5..9a02ee40 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -11,8 +11,8 @@ pub trait Data: Send + Sync { ) -> Result<()>; /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. - fn readreceipts_since( - &self, + fn readreceipts_since<'a>( + &'a self, room_id: &RoomId, since: u64, ) -> Box< @@ -22,7 +22,7 @@ pub trait Data: Send + Sync { u64, Raw, )>, - >, + > + 'a, >; /// Sets a private read marker at `count`. diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index 35fee1a5..8d6eaafd 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -7,7 +7,7 @@ use crate::Result; use ruma::{events::receipt::ReceiptEvent, serde::Raw, RoomId, UserId}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 91892df6..fc06fe4a 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -7,7 +7,7 @@ use ruma::{events::SyncEphemeralRoomEvent, RoomId, UserId}; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 12320388..0c0bd2ce 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -256,7 +256,7 @@ impl Service { #[tracing::instrument(skip(self, create_event, value, pub_key_map))] fn handle_outlier_pdu<'a>( - &self, + &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, @@ -1015,7 +1015,7 @@ impl Service { /// d. TODO: Ask other servers over federation? #[tracing::instrument(skip_all)] pub(crate) fn fetch_and_handle_outliers<'a>( - &self, + &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index a01ce9ba..2ed0bed0 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -10,9 +10,9 @@ use ruma::{DeviceId, RoomId, UserId}; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, - lazy_load_waiting: + pub lazy_load_waiting: Mutex, Box, Box, u64), HashSet>>>, } @@ -67,7 +67,7 @@ impl Service { user_id, device_id, room_id, - &mut user_ids.iter().map(|&u| &*u), + &mut user_ids.iter().map(|u| &**u), )?; } else { // Ignore diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 27e7eb98..df416dac 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -3,6 +3,7 @@ use ruma::RoomId; pub trait Data: Send + Sync { fn exists(&self, room_id: &RoomId) -> Result; + fn iter_ids<'a>(&'a self) -> Box>> + 'a>; fn is_disabled(&self, room_id: &RoomId) -> Result; fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index b6cccd15..df9f40a8 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -7,7 +7,7 @@ use ruma::RoomId; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -17,6 +17,10 @@ impl Service { self.db.exists(room_id) } + pub fn iter_ids<'a>(&'a self) -> Box>> + 'a> { + self.db.iter_ids() + } + pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 6404d8a1..443abd19 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -7,7 +7,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 70443389..b816678c 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -7,7 +7,7 @@ use ruma::{EventId, RoomId}; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 59652e02..bd7d61bb 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -2,11 +2,11 @@ use crate::Result; use ruma::RoomId; pub trait Data: Send + Sync { - fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()>; + fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, - ) -> Result>>, Vec)>>; + ) -> Result>+ 'a>, Vec)>>; } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 0ef96342..1d8d01e1 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -7,7 +7,7 @@ use crate::Result; use ruma::RoomId; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -16,7 +16,7 @@ impl Service { &self, shortroomid: u64, pdu_id: &[u8], - message_body: String, + message_body: &str, ) -> Result<()> { self.db.index_pdu(shortroomid, pdu_id, message_body) } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index efa4362a..d847dea2 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -7,7 +7,7 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::{Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 2dff4b71..614236ca 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -23,7 +23,7 @@ use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; use super::state_compressor::CompressedStateEvent; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -33,7 +33,7 @@ impl Service { room_id: &RoomId, shortstatehash: u64, statediffnew: HashSet, - statediffremoved: HashSet, + _statediffremoved: HashSet, ) -> Result<()> { let mutex_state = Arc::clone( services() @@ -102,7 +102,7 @@ impl Service { services().rooms.state_cache.update_joined_count(room_id)?; - self.db.set_room_state(room_id, shortstatehash, &state_lock); + self.db.set_room_state(room_id, shortstatehash, &state_lock)?; drop(state_lock); diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index e179d70f..1a9c4a9e 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -10,7 +10,7 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::{PduEvent, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 608dbcaa..cf4c6655 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -17,7 +17,7 @@ use ruma::{ use crate::{services, Error, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -112,7 +112,7 @@ impl Service { }; // Copy direct chat flag - if let Some(mut direct_event) = services() + if let Some(direct_event) = services() .account_data .get( None, @@ -125,7 +125,7 @@ impl Service { }) }) { - let direct_event = direct_event?; + let mut direct_event = direct_event?; let mut room_ids_updated = false; for room_ids in direct_event.content.0.values_mut() { diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index f7c6dba0..b927cb72 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -14,7 +14,7 @@ use crate::{services, utils, Result}; use self::data::StateDiff; pub struct Service { - db: Arc, + pub db: &'static dyn Data, pub stateinfo_cache: Mutex< LruCache< diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 4ae8ce96..095731ca 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -60,7 +60,7 @@ pub trait Data: Send + Sync { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> Result, PduEvent)>>>>; + ) -> Result, PduEvent)>> + 'a>>; /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. @@ -69,14 +69,14 @@ pub trait Data: Send + Sync { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> Result, PduEvent)>>>>; + ) -> Result, PduEvent)>> + 'a>>; fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: u64, - ) -> Result, PduEvent)>>>>; + ) -> Result, PduEvent)>> + 'a>>; fn increment_notification_counts( &self, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 73f1451b..01c54a3a 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -36,9 +36,9 @@ use crate::{ use super::state_compressor::CompressedStateEvent; pub struct Service { - db: Arc, + pub db: &'static dyn Data, - pub(super) lasttimelinecount_cache: Mutex, u64>>, + pub lasttimelinecount_cache: Mutex, u64>>, } impl Service { @@ -253,10 +253,10 @@ impl Service { .rooms .state_cache .get_our_real_users(&pdu.room_id)? - .into_iter() + .iter() { // Don't notify the user of their own events - if &user == &pdu.sender { + if user == &pdu.sender { continue; } @@ -297,20 +297,20 @@ impl Service { } if notify { - notifies.push(user); + notifies.push(user.clone()); } if highlight { - highlights.push(user); + highlights.push(user.clone()); } - for senderkey in services().pusher.get_pusher_senderkeys(&user) { - services().sending.send_push_pdu(&*pdu_id, senderkey)?; + for push_key in services().pusher.get_pushkeys(&user) { + services().sending.send_push_pdu(&*pdu_id, &user, push_key?)?; } } self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights); + .increment_notification_counts(&pdu.room_id, notifies, highlights)?; match pdu.kind { RoomEventType::RoomRedaction => { @@ -365,7 +365,7 @@ impl Service { services() .rooms .search - .index_pdu(shortroomid, &pdu_id, body)?; + .index_pdu(shortroomid, &pdu_id, &body)?; let admin_room = services().rooms.alias.resolve_local_alias( <&RoomAliasId>::try_from( @@ -398,7 +398,7 @@ impl Service { { services() .sending - .send_pdu_appservice(&appservice.0, &pdu_id)?; + .send_pdu_appservice(appservice.0, pdu_id.clone())?; continue; } @@ -422,7 +422,7 @@ impl Service { if state_key_uid == &appservice_uid { services() .sending - .send_pdu_appservice(&appservice.0, &pdu_id)?; + .send_pdu_appservice(appservice.0, pdu_id.clone())?; continue; } } @@ -475,7 +475,7 @@ impl Service { { services() .sending - .send_pdu_appservice(&appservice.0, &pdu_id)?; + .send_pdu_appservice(appservice.0, pdu_id.clone())?; } } } @@ -565,7 +565,7 @@ impl Service { } } - let pdu = PduEvent { + let mut pdu = PduEvent { event_id: ruma::event_id!("$thiswillbefilledinlater").into(), room_id: room_id.to_owned(), sender: sender.to_owned(), diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index fcaff5ac..7b7841fb 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -20,5 +20,5 @@ pub trait Data: Send + Sync { fn get_shared_rooms<'a>( &'a self, users: Vec>, - ) -> Result>>>>; + ) -> Result>> + 'a>>; } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 1caa4b3f..0148399b 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -7,7 +7,7 @@ use ruma::{RoomId, UserId}; use crate::Result; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs new file mode 100644 index 00000000..2e574e23 --- /dev/null +++ b/src/service/sending/data.rs @@ -0,0 +1,29 @@ +use ruma::ServerName; + +use crate::Result; + +use super::{OutgoingKind, SendingEventType}; + +pub trait Data: Send + Sync { + fn active_requests<'a>( + &'a self, + ) -> Box, OutgoingKind, SendingEventType)>> + 'a>; + fn active_requests_for<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box, SendingEventType)>> + 'a>; + fn delete_active_request(&self, key: Vec) -> Result<()>; + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn queue_requests( + &self, + requests: &[(&OutgoingKind, SendingEventType)], + ) -> Result>>; + fn queued_requests<'a>( + &'a self, + outgoing_kind: &OutgoingKind, + ) -> Box)>> + 'a>; + fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()>; + fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; + fn get_latest_educount(&self, server_name: &ServerName) -> Result; +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index e5e8cffd..cb16e70d 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,15 +1,19 @@ +mod data; + +pub use data::Data; + use std::{ collections::{BTreeMap, HashMap, HashSet}, fmt::Debug, sync::Arc, - time::{Duration, Instant}, + time::{Duration, Instant}, iter, }; use crate::{ api::{appservice_server, server_server}, services, utils::{self, calculate_hash}, - Error, PduEvent, Result, + Error, PduEvent, Result, Config, }; use federation::transactions::send_transaction_message; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -40,7 +44,7 @@ use tracing::{error, warn}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum OutgoingKind { Appservice(String), - Push(Vec, Vec), // user and pushkey + Push(Box, String), // user and pushkey Normal(Box), } @@ -55,9 +59,9 @@ impl OutgoingKind { } OutgoingKind::Push(user, pushkey) => { let mut p = b"$".to_vec(); - p.extend_from_slice(user); + p.extend_from_slice(user.as_bytes()); p.push(0xff); - p.extend_from_slice(pushkey); + p.extend_from_slice(pushkey.as_bytes()); p } OutgoingKind::Normal(server) => { @@ -74,14 +78,16 @@ impl OutgoingKind { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SendingEventType { - Pdu(Vec), - Edu(Vec), + Pdu(Vec), // pduid + Edu(Vec), // pdu json } pub struct Service { + db: &'static dyn Data, + /// The state for a given state hash. pub(super) maximum_requests: Arc, - pub sender: mpsc::UnboundedSender<(Vec, Vec)>, + pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, } enum TransactionStatus { @@ -91,131 +97,113 @@ enum TransactionStatus { } impl Service { - pub fn start_handler(&self, mut receiver: mpsc::UnboundedReceiver<(Vec, Vec)>) { + pub fn build(db: &'static dyn Data, config: &Config) -> Arc { + let (sender, receiver) = mpsc::unbounded_channel(); + + let self1 = Arc::new(Self { db, sender, maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)) }); + let self2 = Arc::clone(&self1); + tokio::spawn(async move { - let mut futures = FuturesUnordered::new(); + self2.start_handler(receiver).await.unwrap(); + }); - let mut current_transaction_status = HashMap::, TransactionStatus>::new(); + self1 + } - // Retry requests we could not finish yet - let mut initial_transactions = HashMap::>::new(); + async fn start_handler(&self, mut receiver: mpsc::UnboundedReceiver<(OutgoingKind, SendingEventType, Vec)>) -> Result<()> { + let mut futures = FuturesUnordered::new(); - for (key, outgoing_kind, event) in services() - .sending - .servercurrentevent_data - .iter() - .filter_map(|(key, v)| { - Self::parse_servercurrentevent(&key, v) - .ok() - .map(|(k, e)| (key, k, e)) - }) - { - let entry = initial_transactions - .entry(outgoing_kind.clone()) - .or_insert_with(Vec::new); + let mut current_transaction_status = HashMap::::new(); - if entry.len() > 30 { - warn!( - "Dropping some current events: {:?} {:?} {:?}", - key, outgoing_kind, event - ); - services() - .sending - .servercurrentevent_data - .remove(&key) - .unwrap(); - continue; - } + // Retry requests we could not finish yet + let mut initial_transactions = HashMap::>::new(); - entry.push(event); + for (key, outgoing_kind, event) in self.db.active_requests().filter_map(|r| r.ok()) + { + let entry = initial_transactions + .entry(outgoing_kind.clone()) + .or_insert_with(Vec::new); + + if entry.len() > 30 { + warn!( + "Dropping some current events: {:?} {:?} {:?}", + key, outgoing_kind, event + ); + self.db.delete_active_request(key)?; + continue; } - for (outgoing_kind, events) in initial_transactions { - current_transaction_status - .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events)); - } + entry.push(event); + } - loop { - select! { - Some(response) = futures.next() => { - match response { - Ok(outgoing_kind) => { - let prefix = outgoing_kind.get_prefix(); - for (key, _) in services().sending.servercurrentevent_data - .scan_prefix(prefix.clone()) - { - services().sending.servercurrentevent_data.remove(&key).unwrap(); - } + for (outgoing_kind, events) in initial_transactions { + current_transaction_status + .insert(outgoing_kind.clone(), TransactionStatus::Running); + futures.push(Self::handle_events(outgoing_kind.clone(), events)); + } - // Find events that have been added since starting the last request - let new_events: Vec<_> = services().sending.servernameevent_data - .scan_prefix(prefix.clone()) - .filter_map(|(k, v)| { - Self::parse_servercurrentevent(&k, v).ok().map(|ev| (ev, k)) - }) - .take(30) - .collect(); + loop { + select! { + Some(response) = futures.next() => { + match response { + Ok(outgoing_kind) => { + self.db.delete_all_active_requests_for(&outgoing_kind)?; - // TODO: find edus + // Find events that have been added since starting the last request + let new_events = self.db.queued_requests(&outgoing_kind).filter_map(|r| r.ok()).take(30).collect::>(); - if !new_events.is_empty() { - // Insert pdus we found - for (e, key) in &new_events { - let value = if let SendingEventType::Edu(value) = &e.1 { &**value } else { &[] }; - services().sending.servercurrentevent_data.insert(key, value).unwrap(); - services().sending.servernameevent_data.remove(key).unwrap(); - } + // TODO: find edus - futures.push( - Self::handle_events( - outgoing_kind.clone(), - new_events.into_iter().map(|(event, _)| event.1).collect(), - ) - ); - } else { - current_transaction_status.remove(&prefix); - } - } - Err((outgoing_kind, _)) => { - current_transaction_status.entry(outgoing_kind.get_prefix()).and_modify(|e| *e = match e { - TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), - TransactionStatus::Failed(_, _) => { - error!("Request that was not even running failed?!"); - return - }, - }); - } - }; - }, - Some((key, value)) = receiver.recv() => { - if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key, value) { - if let Ok(Some(events)) = Self::select_events( - &outgoing_kind, - vec![(event, key)], - &mut current_transaction_status, - ) { - futures.push(Self::handle_events(outgoing_kind, events)); + if !new_events.is_empty() { + // Insert pdus we found + self.db.mark_as_active(&new_events)?; + + futures.push( + Self::handle_events( + outgoing_kind.clone(), + new_events.into_iter().map(|(event, _)| event).collect(), + ) + ); + } else { + current_transaction_status.remove(&outgoing_kind); } } + Err((outgoing_kind, _)) => { + current_transaction_status.entry(outgoing_kind).and_modify(|e| *e = match e { + TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), + TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), + TransactionStatus::Failed(_, _) => { + error!("Request that was not even running failed?!"); + return + }, + }); + } + }; + }, + Some((outgoing_kind, event, key)) = receiver.recv() => { + if let Ok(Some(events)) = self.select_events( + &outgoing_kind, + vec![(event, key)], + &mut current_transaction_status, + ) { + futures.push(Self::handle_events(outgoing_kind, events)); } } } - }); + } } - #[tracing::instrument(skip(outgoing_kind, new_events, current_transaction_status))] + #[tracing::instrument(skip(self, outgoing_kind, new_events, current_transaction_status))] fn select_events( + &self, outgoing_kind: &OutgoingKind, new_events: Vec<(SendingEventType, Vec)>, // Events we want to send: event and full key - current_transaction_status: &mut HashMap, TransactionStatus>, + current_transaction_status: &mut HashMap, ) -> Result>> { let mut retry = false; let mut allow = true; - let prefix = outgoing_kind.get_prefix(); - let entry = current_transaction_status.entry(prefix.clone()); + let entry = current_transaction_status.entry(outgoing_kind.clone()); entry .and_modify(|e| match e { @@ -247,42 +235,20 @@ impl Service { if retry { // We retry the previous transaction - for (key, value) in services() - .sending - .servercurrentevent_data - .scan_prefix(prefix) - { - if let Ok((_, e)) = Self::parse_servercurrentevent(&key, value) { - events.push(e); - } + for (_, e) in self.db.active_requests_for(outgoing_kind).filter_map(|r| r.ok()) { + events.push(e); } } else { - for (e, full_key) in new_events { - let value = if let SendingEventType::Edu(value) = &e { - &**value - } else { - &[][..] - }; - services() - .sending - .servercurrentevent_data - .insert(&full_key, value)?; - - // If it was a PDU we have to unqueue it - // TODO: don't try to unqueue EDUs - services().sending.servernameevent_data.remove(&full_key)?; - + self.db.mark_as_active(&new_events)?; + for (e, _) in new_events { events.push(e); } if let OutgoingKind::Normal(server_name) = outgoing_kind { - if let Ok((select_edus, last_count)) = Self::select_edus(server_name) { + if let Ok((select_edus, last_count)) = self.select_edus(server_name) { events.extend(select_edus.into_iter().map(SendingEventType::Edu)); - services() - .sending - .servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes())?; + self.db.set_latest_educount(server_name, last_count)?; } } } @@ -290,22 +256,15 @@ impl Service { Ok(Some(events)) } - #[tracing::instrument(skip(server))] - pub fn select_edus(server: &ServerName) -> Result<(Vec>, u64)> { + #[tracing::instrument(skip(self, server_name))] + pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { // u64: count of last edu - let since = services() - .sending - .servername_educount - .get(server.as_bytes())? - .map_or(Ok(0), |&bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - })?; + let since = self.db.get_latest_educount(server_name)?; let mut events = Vec::new(); let mut max_edu_count = since; let mut device_list_changes = HashSet::new(); - 'outer: for room_id in services().rooms.server_rooms(server) { + 'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { let room_id = room_id?; // Look for device list updates in this room device_list_changes.extend( @@ -317,7 +276,7 @@ impl Service { ); // Look for read receipts in this room - for r in services().rooms.edus.readreceipts_since(&room_id, since) { + for r in services().rooms.edus.read_receipt.readreceipts_since(&room_id, since) { let (user_id, count, read_receipt) = r?; if count > max_edu_count { @@ -395,14 +354,12 @@ impl Service { Ok((events, max_edu_count)) } - #[tracing::instrument(skip(self, pdu_id, senderkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Vec) -> Result<()> { - let mut key = b"$".to_vec(); - key.extend_from_slice(&senderkey); - key.push(0xff); - key.extend_from_slice(pdu_id); - self.servernameevent_data.insert(&key, &[])?; - self.sender.send((key, vec![])).unwrap(); + #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); + let event = SendingEventType::Pdu(pdu_id.to_owned()); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); Ok(()) } @@ -413,17 +370,11 @@ impl Service { servers: I, pdu_id: &[u8], ) -> Result<()> { - let mut batch = servers.map(|server| { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(pdu_id); - - self.sender.send((key.clone(), vec![])).unwrap(); - - (key, Vec::new()) - }); - - self.servernameevent_data.insert_batch(&mut batch)?; + let requests = servers.into_iter().map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))).collect::>(); + let keys = self.db.queue_requests(&requests.iter().map(|(o, e)| (o, e.clone())).collect::>())?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender.send((outgoing_kind.to_owned(), event, key)).unwrap(); + } Ok(()) } @@ -435,23 +386,20 @@ impl Service { serialized: Vec, id: u64, ) -> Result<()> { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&id.to_be_bytes()); - self.servernameevent_data.insert(&key, &serialized)?; - self.sender.send((key, serialized)).unwrap(); + let outgoing_kind = OutgoingKind::Normal(server.to_owned()); + let event = SendingEventType::Edu(serialized); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); Ok(()) } #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: &str, pdu_id: &[u8]) -> Result<()> { - let mut key = b"+".to_vec(); - key.extend_from_slice(appservice_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(pdu_id); - self.servernameevent_data.insert(&key, &[])?; - self.sender.send((key, vec![])).unwrap(); + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + let outgoing_kind = OutgoingKind::Appservice(appservice_id); + let event = SendingEventType::Pdu(pdu_id); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); Ok(()) } @@ -460,18 +408,8 @@ impl Service { /// Used for instance after we remove an appservice registration /// #[tracing::instrument(skip(self))] - pub fn cleanup_events(&self, key_id: &str) -> Result<()> { - let mut prefix = b"+".to_vec(); - prefix.extend_from_slice(key_id.as_bytes()); - prefix.push(0xff); - - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } - - for (key, _) in self.servernameevent_data.scan_prefix(prefix.clone()) { - self.servernameevent_data.remove(&key).unwrap(); - } + pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + self.db.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; Ok(()) } @@ -488,7 +426,7 @@ impl Service { for event in &events { match event { SendingEventType::Pdu(pdu_id) => { - pdu_jsons.push(services().rooms + pdu_jsons.push(services().rooms.timeline .get_pdu_from_id(pdu_id) .map_err(|e| (kind.clone(), e))? .ok_or_else(|| { @@ -525,7 +463,7 @@ impl Service { appservice::event::push_events::v1::Request { events: &pdu_jsons, txn_id: (&*base64::encode_config( - Self::calculate_hash( + calculate_hash( &events .iter() .map(|e| match e { @@ -546,7 +484,7 @@ impl Service { response } - OutgoingKind::Push(user, pushkey) => { + OutgoingKind::Push(userid, pushkey) => { let mut pdus = Vec::new(); for event in &events { @@ -554,6 +492,7 @@ impl Service { SendingEventType::Pdu(pdu_id) => { pdus.push( services().rooms + .timeline .get_pdu_from_id(pdu_id) .map_err(|e| (kind.clone(), e))? .ok_or_else(|| { @@ -584,27 +523,10 @@ impl Service { } } - let userid = UserId::parse(utils::string_from_bytes(user).map_err(|_| { - ( - kind.clone(), - Error::bad_database("Invalid push user string in db."), - ) - })?) - .map_err(|_| { - ( - kind.clone(), - Error::bad_database("Invalid push user id in db."), - ) - })?; - - let mut senderkey = user.clone(); - senderkey.push(0xff); - senderkey.extend_from_slice(pushkey); - let pusher = match services() .pusher - .get_pusher(&senderkey) - .map_err(|e| (OutgoingKind::Push(user.clone(), pushkey.clone()), e))? + .get_pusher(&userid, pushkey) + .map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? { Some(pusher) => pusher, None => continue, @@ -618,11 +540,13 @@ impl Service { GlobalAccountDataEventType::PushRules.to_string().into(), ) .unwrap_or_default() + .and_then(|event| serde_json::from_str::(event.get()).ok()) .map(|ev: PushRulesEvent| ev.content.global) .unwrap_or_else(|| push::Ruleset::server_default(&userid)); let unread: UInt = services() .rooms + .user .notification_count(&userid, &pdu.room_id) .map_err(|e| (kind.clone(), e))? .try_into() @@ -639,7 +563,7 @@ impl Service { drop(permit); } - Ok(OutgoingKind::Push(user.clone(), pushkey.clone())) + Ok(OutgoingKind::Push(userid.clone(), pushkey.clone())) } OutgoingKind::Normal(server) => { let mut edu_jsons = Vec::new(); @@ -651,6 +575,7 @@ impl Service { // TODO: check room version and remove event_id if needed let raw = PduEvent::convert_to_outgoing_federation_event( services().rooms + .timeline .get_pdu_json_from_id(pdu_id) .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? .ok_or_else(|| { @@ -713,72 +638,6 @@ impl Service { } } - #[tracing::instrument(skip(key))] - fn parse_servercurrentevent( - key: &[u8], - value: Vec, - ) -> Result<(OutgoingKind, SendingEventType)> { - // Appservices start with a plus - Ok::<_, Error>(if key.starts_with(b"+") { - let mut parts = key[1..].splitn(2, |&b| b == 0xff); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; - - ( - OutgoingKind::Appservice(server), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - } else if key.starts_with(b"$") { - let mut parts = key[1..].splitn(3, |&b| b == 0xff); - - let user = parts.next().expect("splitn always returns one element"); - let pushkey = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - ( - OutgoingKind::Push(user.to_vec(), pushkey.to_vec()), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - } else { - let mut parts = key.splitn(2, |&b| b == 0xff); - - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; - - ( - OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { - Error::bad_database("Invalid server string in server_currenttransaction") - })?), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - }) - } #[tracing::instrument(skip(self, destination, request))] pub async fn send_federation_request( diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index a473e2b1..509b65c0 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -7,7 +7,7 @@ use crate::Result; use ruma::{DeviceId, TransactionId, UserId}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 8f3b3b8b..f8addcc5 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -16,7 +16,7 @@ use tracing::error; use crate::{api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 9f315d3b..9537ed2a 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -22,19 +22,13 @@ pub trait Data: Send + Sync { fn find_from_token(&self, token: &str) -> Result, String)>>; /// Returns an iterator over all users on this homeserver. - fn iter(&self) -> Box>>>; + fn iter<'a>(&'a self) -> Box>> + 'a>; /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. fn list_local_users(&self) -> Result>; - /// Will only return with Some(username) if the password was not empty and the - /// username could be successfully parsed. - /// If utils::string_from_bytes(...) returns an error that username will be skipped - /// and the error will be logged. - fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option; - /// Returns the password hash for the given user. fn password_hash(&self, user_id: &UserId) -> Result>; @@ -75,7 +69,7 @@ pub trait Data: Send + Sync { fn all_device_ids<'a>( &'a self, user_id: &UserId, - ) -> Box>>>; + ) -> Box>> + 'a>; /// Replaces the access token of one device. fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; @@ -131,7 +125,7 @@ pub trait Data: Send + Sync { user_or_room_id: &str, from: u64, to: Option, - ) -> Box>>>; + ) -> Box>> + 'a>; fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; @@ -193,7 +187,7 @@ pub trait Data: Send + Sync { fn all_devices_metadata<'a>( &'a self, user_id: &UserId, - ) -> Box>>; + ) -> Box> + 'a>; /// Creates a new sync filter. Returns the filter id. fn create_filter(&self, user_id: &UserId, filter: &IncomingFilterDefinition) -> Result; diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 0b83460c..e3419e75 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -13,7 +13,7 @@ use ruma::{ use crate::{services, Error, Result}; pub struct Service { - db: Arc, + pub db: &'static dyn Data, } impl Service { @@ -72,14 +72,6 @@ impl Service { self.db.list_local_users() } - /// Will only return with Some(username) if the password was not empty and the - /// username could be successfully parsed. - /// If utils::string_from_bytes(...) returns an error that username will be skipped - /// and the error will be logged. - fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option { - self.db.get_username_with_valid_password(username, password) - } - /// Returns the password hash for the given user. pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) @@ -275,7 +267,7 @@ impl Service { user_id: &UserId, device_id: &DeviceId, ) -> Result>> { - self.get_to_device_events(user_id, device_id) + self.db.get_to_device_events(user_id, device_id) } pub fn remove_to_device_events( @@ -302,7 +294,7 @@ impl Service { user_id: &UserId, device_id: &DeviceId, ) -> Result> { - self.get_device_metadata(user_id, device_id) + self.db.get_device_metadata(user_id, device_id) } pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> {