diff --git a/src/admin/query/account_data.rs b/src/admin/query/account_data.rs index 9f5b1482..ce9caedb 100644 --- a/src/admin/query/account_data.rs +++ b/src/admin/query/account_data.rs @@ -14,7 +14,6 @@ pub(super) async fn account_data(subcommand: AccountData) -> Result Result, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()>; - - /// Searches the account data for a specific kind. - fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result>>; - - /// Returns all changes to the account data that happened after `since`. - fn changes_since( - &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, - ) -> Result>>; +pub(super) struct Data { + roomuserdataid_accountdata: Arc, + roomusertype_roomuserdataid: Arc, } -impl Data for KeyValueDatabase { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + roomuserdataid_accountdata: db.roomuserdataid_accountdata.clone(), + roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(), + } + } + /// Places one event in the account data of the user and removes the /// previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( - &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + pub(super) fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: &RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { let mut prefix = room_id @@ -80,9 +72,8 @@ impl Data for KeyValueDatabase { } /// Searches the account data for a specific kind. - #[tracing::instrument(skip(self, room_id, user_id, kind))] - fn get( - &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, + pub(super) fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: &RoomAccountDataEventType, ) -> Result>> { let mut key = room_id .map(ToString::to_string) @@ -107,8 +98,7 @@ impl Data for KeyValueDatabase { } /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip(self, room_id, user_id, since))] - fn changes_since( + pub(super) fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result>> { let mut userdata = HashMap::new(); diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index 7e17e145..fc330ead 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -2,40 +2,46 @@ mod data; use std::{collections::HashMap, sync::Arc}; -pub use data::Data; +use conduit::{Result, Server}; +use data::Data; +use database::KeyValueDatabase; use ruma::{ events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, serde::Raw, RoomId, UserId, }; -use crate::Result; - pub struct Service { - pub db: Arc, + db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Places one event in the account data of the user and removes the /// previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + #[allow(clippy::needless_pass_by_value)] pub fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &serde_json::Value, ) -> Result<()> { - self.db.update(room_id, user_id, event_type, data) + self.db.update(room_id, user_id, &event_type, data) } /// Searches the account data for a specific kind. 
- #[tracing::instrument(skip(self, room_id, user_id, event_type))] + #[allow(clippy::needless_pass_by_value)] pub fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, ) -> Result>> { - self.db.get(room_id, user_id, event_type) + self.db.get(room_id, user_id, &event_type) } /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip(self, room_id, user_id, since))] + #[tracing::instrument(skip_all, name = "since")] pub fn changes_since( &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, ) -> Result>> { diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 730163ac..5a1c161c 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + pub mod console; mod create; mod grant; @@ -44,17 +47,16 @@ pub struct Command { } impl Service { - #[must_use] - pub fn build() -> Arc { + pub fn build(_server: &Arc, _db: &Arc) -> Result> { let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT); - Arc::new(Self { + Ok(Arc::new(Self { sender, receiver: Mutex::new(receiver), handler_join: Mutex::new(None), handle: Mutex::new(None), #[cfg(feature = "console")] console: console::Console::new(), - }) + })) } pub async fn start_handler(self: &Arc) { diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index e81cb7ac..59fabe96 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,28 +1,23 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::api::appservice::Registration; use crate::{utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: Registration) -> Result; - - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - fn unregister_appservice(&self, service_name: &str) -> Result<()>; - - fn get_registration(&self, id: &str) -> Result>; - - fn iter_ids<'a>(&'a self) -> Result> + 'a>>; - - fn all(&self) -> Result>; +pub struct Data { + id_appserviceregistrations: Arc, } -impl Data for KeyValueDatabase { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + id_appserviceregistrations: db.id_appserviceregistrations.clone(), + } + } + /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: Registration) -> Result { + pub(super) fn register_appservice(&self, yaml: Registration) -> Result { let id = yaml.id.as_str(); self.id_appserviceregistrations .insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; @@ -35,13 +30,13 @@ impl Data for KeyValueDatabase { /// # Arguments /// /// * `service_name` - the name you send to register the service previously - fn unregister_appservice(&self, service_name: &str) -> Result<()> { + pub(super) fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations .remove(service_name.as_bytes())?; Ok(()) } - fn get_registration(&self, id: &str) -> Result> { + pub fn get_registration(&self, id: &str) -> Result> { self.id_appserviceregistrations .get(id.as_bytes())? 
.map(|bytes| { @@ -51,14 +46,14 @@ impl Data for KeyValueDatabase { .transpose() } - fn iter_ids<'a>(&'a self) -> Result> + 'a>> { + pub(super) fn iter_ids<'a>(&'a self) -> Result> + 'a>> { Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) }))) } - fn all(&self) -> Result> { + pub fn all(&self) -> Result> { self.iter_ids()? .filter_map(Result::ok) .map(move |id| { diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 94133396..83023151 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -113,13 +113,17 @@ impl TryFrom for RegistrationInfo { } pub struct Service { - pub db: Arc, + pub db: Data, registration_info: RwLock>, } +use conduit::Server; +use database::KeyValueDatabase; + impl Service { - pub fn build(db: Arc) -> Result { + pub fn build(_server: &Arc, db: &Arc) -> Result { let mut registration_info = BTreeMap::new(); + let db = Data::new(db); // Inserting registrations into cache for appservice in db.all()? { registration_info.insert( diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 9851bc0a..1645c8d6 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -1,6 +1,9 @@ -use std::collections::{BTreeMap, HashMap}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; -use async_trait::async_trait; +use database::{Cork, KeyValueDatabase, KvTree}; use futures_util::{stream::FuturesUnordered, StreamExt}; use lru_cache::LruCache; use ruma::{ @@ -10,56 +13,60 @@ use ruma::{ }; use tracing::trace; -use crate::{database::Cork, services, utils, Error, KeyValueDatabase, Result}; +use crate::{services, utils, Error, Result}; const COUNTER: &[u8] = b"c"; const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; -#[async_trait] -pub trait Data: Send + Sync { - fn next_count(&self) -> Result; - fn current_count(&self) -> Result; - fn last_check_for_updates_id(&self) -> Result; - fn update_check_for_updates_id(&self, id: u64) -> Result<()>; - #[allow(unused_qualifications)] // async traits - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - fn cleanup(&self) -> Result<()>; - - fn cork(&self) -> Cork; - fn cork_and_flush(&self) -> Cork; - - fn memory_usage(&self) -> String; - fn clear_caches(&self, amount: u32); - fn load_keypair(&self) -> Result; - fn remove_keypair(&self) -> Result<()>; - fn add_signing_key( - &self, origin: &ServerName, new_keys: ServerSigningKeys, - ) -> Result>; - - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found - /// for the server. 
- fn signing_keys_for(&self, origin: &ServerName) -> Result>; - fn database_version(&self) -> Result; - fn bump_database_version(&self, new_version: u64) -> Result<()>; - fn backup(&self) -> Result<(), Box> { unimplemented!() } - fn backup_list(&self) -> Result { Ok(String::new()) } - fn file_list(&self) -> Result { Ok(String::new()) } +pub struct Data { + global: Arc, + todeviceid_events: Arc, + userroomid_joined: Arc, + userroomid_invitestate: Arc, + userroomid_leftstate: Arc, + userroomid_notificationcount: Arc, + userroomid_highlightcount: Arc, + pduid_pdu: Arc, + keychangeid_userid: Arc, + roomusertype_roomuserdataid: Arc, + server_signingkeys: Arc, + readreceiptid_readreceipt: Arc, + userid_lastonetimekeyupdate: Arc, + pub(super) db: Arc, } -#[async_trait] -impl Data for KeyValueDatabase { - fn next_count(&self) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + global: db.global.clone(), + todeviceid_events: db.todeviceid_events.clone(), + userroomid_joined: db.userroomid_joined.clone(), + userroomid_invitestate: db.userroomid_invitestate.clone(), + userroomid_leftstate: db.userroomid_leftstate.clone(), + userroomid_notificationcount: db.userroomid_notificationcount.clone(), + userroomid_highlightcount: db.userroomid_highlightcount.clone(), + pduid_pdu: db.pduid_pdu.clone(), + keychangeid_userid: db.keychangeid_userid.clone(), + roomusertype_roomuserdataid: db.roomusertype_roomuserdataid.clone(), + server_signingkeys: db.server_signingkeys.clone(), + readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(), + userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(), + db: db.clone(), + } + } + + pub fn next_count(&self) -> Result { utils::u64_from_bytes(&self.global.increment(COUNTER)?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) } - fn current_count(&self) -> Result { + pub fn current_count(&self) -> Result { self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) }) } - fn last_check_for_updates_id(&self) -> Result { + pub fn last_check_for_updates_id(&self) -> Result { self.global .get(LAST_CHECK_FOR_UPDATES_COUNT)? 
.map_or(Ok(0_u64), |bytes| { @@ -68,16 +75,15 @@ impl Data for KeyValueDatabase { }) } - fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.global .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; Ok(()) } - #[allow(unused_qualifications)] // async traits #[tracing::instrument(skip(self))] - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let userid_bytes = user_id.as_bytes().to_vec(); let mut userid_prefix = userid_bytes.clone(); userid_prefix.push(0xFF); @@ -177,20 +183,20 @@ impl Data for KeyValueDatabase { Ok(()) } - fn cleanup(&self) -> Result<()> { self.db.cleanup() } + pub fn cleanup(&self) -> Result<()> { self.db.db.cleanup() } - fn cork(&self) -> Cork { Cork::new(&self.db, false, false) } + pub fn cork(&self) -> Cork { Cork::new(&self.db.db, false, false) } - fn cork_and_flush(&self) -> Cork { Cork::new(&self.db, true, false) } + pub fn cork_and_flush(&self) -> Cork { Cork::new(&self.db.db, true, false) } - fn memory_usage(&self) -> String { - let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); - let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); - let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); + pub fn memory_usage(&self) -> String { + let auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().len(); + let appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().len(); + let lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().len(); - let max_auth_chain_cache = self.auth_chain_cache.lock().unwrap().capacity(); - let max_appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().capacity(); - let max_lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().capacity(); + let max_auth_chain_cache = self.db.auth_chain_cache.lock().unwrap().capacity(); + let max_appservice_in_room_cache = self.db.appservice_in_room_cache.read().unwrap().capacity(); + let max_lasttimelinecount_cache = self.db.lasttimelinecount_cache.lock().unwrap().capacity(); format!( "\ @@ -198,26 +204,26 @@ auth_chain_cache: {auth_chain_cache} / {max_auth_chain_cache} appservice_in_room_cache: {appservice_in_room_cache} / {max_appservice_in_room_cache} lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cache}\n\n {}", - self.db.memory_usage().unwrap_or_default() + self.db.db.memory_usage().unwrap_or_default() ) } - fn clear_caches(&self, amount: u32) { + pub fn clear_caches(&self, amount: u32) { if amount > 1 { - let c = &mut *self.auth_chain_cache.lock().unwrap(); + let c = &mut *self.db.auth_chain_cache.lock().unwrap(); *c = LruCache::new(c.capacity()); } if amount > 2 { - let c = &mut *self.appservice_in_room_cache.write().unwrap(); + let c = &mut *self.db.appservice_in_room_cache.write().unwrap(); *c = HashMap::new(); } if amount > 3 { - let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); + let c = &mut *self.db.lasttimelinecount_cache.lock().unwrap(); *c = HashMap::new(); } } - fn load_keypair(&self) -> Result { + pub fn load_keypair(&self) -> Result { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { let keypair = utils::generate_keypair(); @@ -249,9 +255,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach }) } - fn remove_keypair(&self) -> Result<()> { 
self.global.remove(b"keypair") } + pub fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } - fn add_signing_key( + pub fn add_signing_key( &self, origin: &ServerName, new_keys: ServerSigningKeys, ) -> Result> { // Not atomic, but this is not critical @@ -290,8 +296,9 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found /// for the server. - fn signing_keys_for(&self, origin: &ServerName) -> Result> { + pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { let signingkeys = self + .db .server_signingkeys .get(origin.as_bytes())? .and_then(|bytes| serde_json::from_slice(&bytes).ok()) @@ -308,20 +315,20 @@ lasttimelinecount_cache: {lasttimelinecount_cache} / {max_lasttimelinecount_cach Ok(signingkeys) } - fn database_version(&self) -> Result { + pub fn database_version(&self) -> Result { self.global.get(b"version")?.map_or(Ok(0), |version| { utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) }) } - fn bump_database_version(&self, new_version: u64) -> Result<()> { + pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.global.insert(b"version", &new_version.to_be_bytes())?; Ok(()) } - fn backup(&self) -> Result<(), Box> { self.db.backup() } + pub fn backup(&self) -> Result<(), Box> { self.db.db.backup() } - fn backup_list(&self) -> Result { self.db.backup_list() } + pub fn backup_list(&self) -> Result { self.db.db.backup_list() } - fn file_list(&self) -> Result { self.db.file_list() } + pub fn file_list(&self) -> Result { self.db.db.file_list() } } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 2e106030..9a6aef18 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,3 +1,5 @@ +use conduit::Server; + mod client; mod data; pub(super) mod emerg_access; @@ -33,12 +35,12 @@ use tracing::{error, trace}; use url::Url; use utils::MutexMap; -use crate::{services, Config, Result}; +use crate::{services, Config, KeyValueDatabase, Result}; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries pub struct Service { - pub db: Arc, + pub db: Data, pub config: Config, pub cidr_range_denylist: Vec, @@ -62,7 +64,9 @@ pub struct Service { } impl Service { - pub fn load(db: Arc, config: &Config) -> Result { + pub fn build(server: &Arc, db: &Arc) -> Result { + let config = &server.config; + let db = Data::new(db); let keypair = db.load_keypair(); let keypair = match keypair { diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index 491fd0e4..78042ccf 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; use ruma::{ api::client::{ @@ -9,48 +9,24 @@ use ruma::{ OwnedRoomId, RoomId, UserId, }; -use crate::{services, utils, Error, KeyValueDatabase, Result}; +use crate::{services, utils, Error, KeyValueDatabase, KvTree, Result}; -pub(crate) trait Data: Send + Sync { - fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result; - - fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; - - fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result; - - fn get_latest_backup_version(&self, user_id: &UserId) -> Result>; - - fn get_latest_backup(&self, user_id: &UserId) -> Result)>>; - - fn get_backup(&self, user_id: 
&UserId, version: &str) -> Result>>; - - fn add_key( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, - ) -> Result<()>; - - fn count_keys(&self, user_id: &UserId, version: &str) -> Result; - - fn get_etag(&self, user_id: &UserId, version: &str) -> Result; - - fn get_all(&self, user_id: &UserId, version: &str) -> Result>; - - fn get_room( - &self, user_id: &UserId, version: &str, room_id: &RoomId, - ) -> Result>>; - - fn get_session( - &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, - ) -> Result>>; - - fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; - - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; - - fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>; +pub(crate) struct Data { + backupid_algorithm: Arc, + backupid_etag: Arc, + backupkeyid_backup: Arc, } -impl Data for KeyValueDatabase { - fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + backupid_algorithm: db.backupid_algorithm.clone(), + backupid_etag: db.backupid_etag.clone(), + backupkeyid_backup: db.backupkeyid_backup.clone(), + } + } + + pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { let version = services().globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); @@ -66,7 +42,7 @@ impl Data for KeyValueDatabase { Ok(version) } - fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + pub(super) fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -83,7 +59,9 @@ impl Data for KeyValueDatabase { Ok(()) } - fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result { + pub(super) fn update_backup( + &self, user_id: &UserId, version: &str, backup_metadata: &Raw, + ) -> Result { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -99,7 +77,7 @@ impl Data for KeyValueDatabase { Ok(version.to_owned()) } - fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + pub(super) fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); let mut last_possible_key = prefix.clone(); @@ -120,7 +98,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { + pub(super) fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); let mut last_possible_key = prefix.clone(); @@ -147,7 +125,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + pub(super) fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -160,7 +138,7 @@ impl Data for KeyValueDatabase { }) } - fn add_key( + pub(super) fn add_key( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); @@ -185,7 +163,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn count_keys(&self, user_id: &UserId, version: &str) -> Result { + pub(super) 
fn count_keys(&self, user_id: &UserId, version: &str) -> Result { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); prefix.extend_from_slice(version.as_bytes()); @@ -193,7 +171,7 @@ impl Data for KeyValueDatabase { Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) } - fn get_etag(&self, user_id: &UserId, version: &str) -> Result { + pub(super) fn get_etag(&self, user_id: &UserId, version: &str) -> Result { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -208,7 +186,7 @@ impl Data for KeyValueDatabase { .to_string()) } - fn get_all(&self, user_id: &UserId, version: &str) -> Result> { + pub(super) fn get_all(&self, user_id: &UserId, version: &str) -> Result> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); prefix.extend_from_slice(version.as_bytes()); @@ -257,7 +235,7 @@ impl Data for KeyValueDatabase { Ok(rooms) } - fn get_room( + pub(super) fn get_room( &self, user_id: &UserId, version: &str, room_id: &RoomId, ) -> Result>> { let mut prefix = user_id.as_bytes().to_vec(); @@ -289,7 +267,7 @@ impl Data for KeyValueDatabase { .collect()) } - fn get_session( + pub(super) fn get_session( &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); @@ -309,7 +287,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + pub(super) fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -322,7 +300,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + pub(super) fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); @@ -337,7 +315,9 @@ impl Data for KeyValueDatabase { Ok(()) } - fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { + pub(super) fn delete_room_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, + ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(version.as_bytes()); diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 9b88e293..19b5d53d 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,20 +1,28 @@ +use conduit::Server; + mod data; use std::{collections::BTreeMap, sync::Arc}; +use conduit::Result; use data::Data; +use database::KeyValueDatabase; use ruma::{ api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, serde::Raw, OwnedRoomId, RoomId, UserId, }; -use crate::Result; - pub struct Service { - pub(super) db: Arc, + pub(super) db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { self.db.create_backup(user_id, backup_metadata) } diff --git a/src/service/media/data.rs b/src/service/media/data.rs index f2a83a53..7b8d0b5a 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,37 +1,28 @@ +use std::sync::Arc; + use conduit::debug_info; +use database::{KeyValueDatabase, KvTree}; use 
ruma::api::client::error::ErrorKind; use tracing::debug; -use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, KeyValueDatabase, Result}; +use crate::{media::UrlPreviewData, utils::string_from_bytes, Error, Result}; -pub(crate) trait Data: Send + Sync { - fn create_file_metadata( - &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, - content_type: Option<&str>, - ) -> Result>; - - fn delete_file_mxc(&self, mxc: String) -> Result<()>; - - /// Returns content_disposition, content_type and the metadata key. - fn search_file_metadata( - &self, mxc: String, width: u32, height: u32, - ) -> Result<(Option, Option, Vec)>; - - fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>>; - - fn get_all_media_keys(&self) -> Vec>; - - // TODO: use this - #[allow(dead_code)] - fn remove_url_preview(&self, url: &str) -> Result<()>; - - fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()>; - - fn get_url_preview(&self, url: &str) -> Option; +pub struct Data { + mediaid_file: Arc, + mediaid_user: Arc, + url_previews: Arc, } -impl Data for KeyValueDatabase { - fn create_file_metadata( +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + mediaid_file: db.mediaid_file.clone(), + mediaid_user: db.mediaid_user.clone(), + url_previews: db.url_previews.clone(), + } + } + + pub(super) fn create_file_metadata( &self, sender_user: Option<&str>, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, ) -> Result> { @@ -65,7 +56,7 @@ impl Data for KeyValueDatabase { Ok(key) } - fn delete_file_mxc(&self, mxc: String) -> Result<()> { + pub(super) fn delete_file_mxc(&self, mxc: String) -> Result<()> { debug!("MXC URI: {:?}", mxc); let mut prefix = mxc.as_bytes().to_vec(); @@ -91,7 +82,7 @@ impl Data for KeyValueDatabase { } /// Searches for all files with the given MXC - fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { + pub(super) fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { debug!("MXC URI: {:?}", mxc); let mut prefix = mxc.as_bytes().to_vec(); @@ -114,7 +105,7 @@ impl Data for KeyValueDatabase { Ok(keys) } - fn search_file_metadata( + pub(super) fn search_file_metadata( &self, mxc: String, width: u32, height: u32, ) -> Result<(Option, Option, Vec)> { let mut prefix = mxc.as_bytes().to_vec(); @@ -156,11 +147,13 @@ impl Data for KeyValueDatabase { /// Gets all the media keys in our database (this includes all the metadata /// associated with it such as width, height, content-type, etc) - fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } + pub(crate) fn get_all_media_keys(&self) -> Vec> { self.mediaid_file.iter().map(|(key, _)| key).collect() } - fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } + pub(super) fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } - fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> { + pub(super) fn set_url_preview( + &self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration, + ) -> Result<()> { let mut value = Vec::::new(); value.extend_from_slice(×tamp.as_secs().to_be_bytes()); value.push(0xFF); @@ -194,7 +187,7 @@ impl Data for KeyValueDatabase { self.url_previews.insert(url.as_bytes(), &value) } - fn get_url_preview(&self, url: &str) -> Option { + pub(super) fn 
get_url_preview(&self, url: &str) -> Option { let values = self.url_previews.get(url.as_bytes()).ok()??; let mut values = values.split(|&b| b == 0xFF); diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 17a1a6dc..a6e2a2a9 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -44,17 +44,17 @@ pub struct UrlPreviewData { pub struct Service { server: Arc, - pub(super) db: Arc, + pub db: Data, pub url_preview_mutex: RwLock>>>, } impl Service { - pub fn build(server: &Arc, db: &Arc) -> Self { - Self { + pub fn build(server: &Arc, db: &Arc) -> Result { + Ok(Self { server: server.clone(), - db: db.clone(), + db: Data::new(db), url_preview_mutex: RwLock::new(HashMap::new()), - } + }) } /// Uploads a file. diff --git a/src/service/mod.rs b/src/service/mod.rs index 2178316a..d1db2c25 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -20,7 +20,7 @@ extern crate conduit_database as database; use std::sync::{Arc, RwLock}; pub(crate) use conduit::{config, debug_error, debug_info, debug_warn, utils, Config, Error, PduCount, Result, Server}; -pub(crate) use database::KeyValueDatabase; +pub(crate) use database::{KeyValueDatabase, KvTree}; pub use globals::{server_is_ours, user_is_local}; pub use pdu::PduEvent; pub use services::Services; diff --git a/src/service/presence/data.rs b/src/service/presence/data.rs index bf730398..ade891a9 100644 --- a/src/service/presence/data.rs +++ b/src/service/presence/data.rs @@ -1,4 +1,7 @@ +use std::sync::Arc; + use conduit::debug_warn; +use database::KvTree; use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId}; use crate::{ @@ -8,26 +11,20 @@ use crate::{ Error, KeyValueDatabase, Result, }; -pub trait Data: Send + Sync { - /// Returns the latest presence event for the given user. - fn get_presence(&self, user_id: &UserId) -> Result>; - - /// Adds a presence event which will be saved until a new event replaces it. - fn set_presence( - &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, - last_active_ago: Option, status_msg: Option, - ) -> Result<()>; - - /// Removes the presence record for the given user from the database. - fn remove_presence(&self, user_id: &UserId) -> Result<()>; - - /// Returns the most recent presence updates that happened after the event - /// with id `since`. - fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a>; +pub struct Data { + presenceid_presence: Arc, + userid_presenceid: Arc, } -impl Data for KeyValueDatabase { - fn get_presence(&self, user_id: &UserId) -> Result> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + presenceid_presence: db.presenceid_presence.clone(), + userid_presenceid: db.userid_presenceid.clone(), + } + } + + pub fn get_presence(&self, user_id: &UserId) -> Result> { if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? { let count = utils::u64_from_bytes(&count_bytes) .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; @@ -44,7 +41,7 @@ impl Data for KeyValueDatabase { } } - fn set_presence( + pub(super) fn set_presence( &self, user_id: &UserId, presence_state: &PresenceState, currently_active: Option, last_active_ago: Option, status_msg: Option, ) -> Result<()> { @@ -105,7 +102,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn remove_presence(&self, user_id: &UserId) -> Result<()> { + pub(super) fn remove_presence(&self, user_id: &UserId) -> Result<()> { if let Some(count_bytes) = self.userid_presenceid.get(user_id.as_bytes())? 
{ let count = utils::u64_from_bytes(&count_bytes) .map_err(|_e| Error::bad_database("No 'count' bytes in presence key"))?; @@ -117,7 +114,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { + pub fn presence_since<'a>(&'a self, since: u64) -> Box)> + 'a> { Box::new( self.presenceid_presence .iter() diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index de5030f6..046614c9 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -1,3 +1,5 @@ +use conduit::Server; + mod data; use std::{sync::Arc, time::Duration}; @@ -14,9 +16,10 @@ use tokio::{sync::Mutex, task::JoinHandle, time::sleep}; use tracing::{debug, error}; use crate::{ + database::KeyValueDatabase, services, user_is_local, utils::{self}, - Config, Error, Result, + Error, Result, }; /// Represents data required to be kept in order to implement the presence @@ -77,7 +80,7 @@ impl Presence { } pub struct Service { - pub db: Arc, + pub db: Data, pub timer_sender: loole::Sender<(OwnedUserId, Duration)>, timer_receiver: Mutex>, handler_join: Mutex>>, @@ -85,15 +88,16 @@ pub struct Service { } impl Service { - pub fn build(db: Arc, config: &Config) -> Arc { + pub fn build(server: &Arc, db: &Arc) -> Result> { + let config = &server.config; let (timer_sender, timer_receiver) = loole::unbounded(); - Arc::new(Self { - db, + Ok(Arc::new(Self { + db: Data::new(db), timer_sender, timer_receiver: Mutex::new(timer_receiver), handler_join: Mutex::new(None), timeout_remote_users: config.presence_timeout_remote_users, - }) + })) } pub async fn start_handler(self: &Arc) { diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index d8c75f4d..0892073f 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,22 +1,25 @@ +use std::sync::Arc; + +use database::{KeyValueDatabase, KvTree}; use ruma::{ api::client::push::{set_pusher, Pusher}, UserId, }; -use crate::{utils, Error, KeyValueDatabase, Result}; +use crate::{utils, Error, Result}; -pub(crate) trait Data: Send + Sync { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; - - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; - - fn get_pushers(&self, sender: &UserId) -> Result>; - - fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a>; +pub struct Data { + senderkey_pusher: Arc, } -impl Data for KeyValueDatabase { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + senderkey_pusher: db.senderkey_pusher.clone(), + } + } + + pub(super) fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { match &pusher { set_pusher::v3::PusherAction::Post(data) => { let mut key = sender.as_bytes().to_vec(); @@ -35,7 +38,7 @@ impl Data for KeyValueDatabase { } } - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + pub(super) fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { let mut senderkey = sender.as_bytes().to_vec(); senderkey.push(0xFF); senderkey.extend_from_slice(pushkey.as_bytes()); @@ -46,7 +49,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn get_pushers(&self, sender: &UserId) -> Result> { + pub(super) fn get_pushers(&self, sender: &UserId) -> Result> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xFF); @@ -56,7 +59,7 @@ impl Data for KeyValueDatabase { .collect() } - fn get_pushkeys<'a>(&'a self, sender: &UserId) -> 
Box> + 'a> { + pub(super) fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xFF); diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 5736ef1d..ff167918 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{fmt::Debug, mem, sync::Arc}; @@ -25,10 +28,16 @@ use tracing::{info, trace, warn}; use crate::{debug_info, services, Error, PduEvent, Result}; pub struct Service { - pub(super) db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { self.db.set_pusher(sender, pusher) } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index e0f4335e..bc0c9cc6 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,31 +1,26 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId}; use crate::{services, utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - /// Creates or updates the alias to the given room id. - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()>; - - /// Forgets about an alias. Returns an error if the alias did not exist. - fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; - - /// Looks up the roomid for the given alias. - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>; - - /// Finds the user who assigned the given alias to a room - fn who_created_alias(&self, alias: &RoomAliasId) -> Result>; - - /// Returns all local aliases that point to the given room - fn local_aliases_for_room<'a>( - &'a self, room_id: &RoomId, - ) -> Box> + 'a>; - - /// Returns all local aliases on the server - fn all_local_aliases<'a>(&'a self) -> Box> + 'a>; +pub struct Data { + alias_userid: Arc, + alias_roomid: Arc, + aliasid_alias: Arc, } -impl Data for KeyValueDatabase { - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + alias_userid: db.alias_userid.clone(), + alias_roomid: db.alias_roomid.clone(), + aliasid_alias: db.aliasid_alias.clone(), + } + } + + pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { // Comes first as we don't want a stuck alias self.alias_userid .insert(alias.alias().as_bytes(), user_id.as_bytes())?; @@ -41,7 +36,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { + pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { let mut prefix = room_id; prefix.push(0xFF); @@ -60,7 +55,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { self.alias_roomid .get(alias.alias().as_bytes())? 
.map(|bytes| { @@ -73,7 +68,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn who_created_alias(&self, alias: &RoomAliasId) -> Result> { + pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result> { self.alias_userid .get(alias.alias().as_bytes())? .map(|bytes| { @@ -86,7 +81,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn local_aliases_for_room<'a>( + pub fn local_aliases_for_room<'a>( &'a self, room_id: &RoomId, ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); @@ -100,7 +95,7 @@ impl Data for KeyValueDatabase { })) } - fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { + pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { Box::new( self.alias_roomid .iter() diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 80f125dc..5c4f5cdb 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -15,10 +18,16 @@ use ruma::{ use crate::{services, Error, Result}; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + #[tracing::instrument(skip(self))] pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> { if alias == services().globals.admin_alias && user_id != services().globals.server_user { diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index c3f046fc..bb6e6516 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,16 +1,25 @@ use std::{mem::size_of, sync::Arc}; +use database::KvTree; + use crate::{utils, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>; - fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc<[u64]>) -> Result<()>; +pub(super) struct Data { + shorteventid_authchain: Arc, + db: Arc, } -impl Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + shorteventid_authchain: db.shorteventid_authchain.clone(), + db: db.clone(), + } + } + + pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + if let Some(result) = self.db.auth_chain_cache.lock().unwrap().get_mut(key) { return Ok(Some(Arc::clone(result))); } @@ -29,7 +38,8 @@ impl Data for KeyValueDatabase { if let Some(chain) = chain { // Cache in RAM - self.auth_chain_cache + self.db + .auth_chain_cache .lock() .unwrap() .insert(vec![key[0]], Arc::clone(&chain)); @@ -41,7 +51,7 @@ impl Data for KeyValueDatabase { Ok(None) } - fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { // Only persist single events in db if key.len() == 1 { self.shorteventid_authchain.insert( @@ -54,7 +64,8 @@ impl Data for KeyValueDatabase { } // Cache in RAM - self.auth_chain_cache + self.db + .auth_chain_cache .lock() .unwrap() .insert(key, auth_chain); diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index b93ea8c5..6cadee0f 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use 
database::KeyValueDatabase; + mod data; use std::{ collections::{BTreeSet, HashSet}, @@ -11,10 +14,16 @@ use tracing::{debug, error, trace, warn}; use crate::{services, Error, Result}; pub struct Service { - pub db: Arc, + db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub async fn event_ids_iter<'a>( &self, room_id: &RoomId, starting_events_: Vec>, ) -> Result> + 'a> { diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 67afbb6c..f4c453eb 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,31 +1,34 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{OwnedRoomId, RoomId}; use crate::{utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - /// Adds the room to the public room directory - fn set_public(&self, room_id: &RoomId) -> Result<()>; - - /// Removes the room from the public room directory. - fn set_not_public(&self, room_id: &RoomId) -> Result<()>; - - /// Returns true if the room is in the public room directory. - fn is_public_room(&self, room_id: &RoomId) -> Result; - - /// Returns the unsorted public room directory - fn public_rooms<'a>(&'a self) -> Box> + 'a>; +pub(super) struct Data { + publicroomids: Arc, } -impl Data for KeyValueDatabase { - fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + publicroomids: db.publicroomids.clone(), + } + } - fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } + pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.insert(room_id.as_bytes(), &[]) + } - fn is_public_room(&self, room_id: &RoomId) -> Result { + pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> { + self.publicroomids.remove(room_id.as_bytes()) + } + + pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result { Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) } - fn public_rooms<'a>(&'a self) -> Box> + 'a> { + pub(super) fn public_rooms<'a>(&'a self) -> Box> + 'a> { Box::new(self.publicroomids.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes) diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 85909c74..a0c4caf7 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; pub struct Service { - pub db: Arc, + db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + #[tracing::instrument(skip(self))] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 813b8d4a..c7e819ac 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,6 +1,8 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod parse_incoming_pdu; mod signing_keys; -pub struct Service; use std::{ cmp, @@ -32,6 +34,8 @@ use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; use crate::{debug_error, debug_info, pdu, services, Error, PduEvent, 
Result}; +pub struct Service; + // We use some AsyncRecursiveType hacks here so we can call async funtion // recursively. type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; @@ -41,6 +45,8 @@ type AsyncRecursiveCanonicalJsonResult<'a> = AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; impl Service { + pub fn build(_server: &Arc, _db: &Arc) -> Result { Ok(Self {}) } + /// When receiving an event one needs to: /// 0. Check the server is in the room /// 1. Skip the PDU if we already know about it diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 7c237901..63823c4a 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,22 +1,22 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{DeviceId, RoomId, UserId}; use crate::{KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - fn lazy_load_was_sent_before( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, - ) -> Result; - - fn lazy_load_confirm_delivery( - &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()>; - - fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()>; +pub struct Data { + lazyloadedids: Arc, } -impl Data for KeyValueDatabase { - fn lazy_load_was_sent_before( +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + lazyloadedids: db.lazyloadedids.clone(), + } + } + + pub(super) fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, ) -> Result { let mut key = user_id.as_bytes().to_vec(); @@ -29,7 +29,7 @@ impl Data for KeyValueDatabase { Ok(self.lazyloadedids.get(&key)?.is_some()) } - fn lazy_load_confirm_delivery( + pub(super) fn lazy_load_confirm_delivery( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, confirmed_user_ids: &mut dyn Iterator, ) -> Result<()> { @@ -49,7 +49,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { + pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index 283a03c1..dba78500 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{ collections::{HashMap, HashSet}, @@ -11,13 +14,20 @@ use tokio::sync::Mutex; use crate::{PduCount, Result}; pub struct Service { - pub db: Arc, + pub db: Data, #[allow(clippy::type_complexity)] pub lazy_load_waiting: Mutex>>, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + lazy_load_waiting: Mutex::new(HashMap::new()), + }) + } + #[tracing::instrument(skip(self))] pub fn lazy_load_was_sent_before( &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index a6bb701e..b5f2fcc3 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,20 +1,29 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{OwnedRoomId, RoomId}; use tracing::error; use 
crate::{services, utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - fn exists(&self, room_id: &RoomId) -> Result; - fn iter_ids<'a>(&'a self) -> Box> + 'a>; - fn is_disabled(&self, room_id: &RoomId) -> Result; - fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; - fn is_banned(&self, room_id: &RoomId) -> Result; - fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()>; - fn list_banned_rooms<'a>(&'a self) -> Box> + 'a>; +pub struct Data { + disabledroomids: Arc, + bannedroomids: Arc, + roomid_shortroomid: Arc, + pduid_pdu: Arc, } -impl Data for KeyValueDatabase { - fn exists(&self, room_id: &RoomId) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + disabledroomids: db.disabledroomids.clone(), + bannedroomids: db.bannedroomids.clone(), + roomid_shortroomid: db.roomid_shortroomid.clone(), + pduid_pdu: db.pduid_pdu.clone(), + } + } + + pub(super) fn exists(&self, room_id: &RoomId) -> Result { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), None => return Ok(false), @@ -29,7 +38,7 @@ impl Data for KeyValueDatabase { .is_some()) } - fn iter_ids<'a>(&'a self) -> Box> + 'a> { + pub(super) fn iter_ids<'a>(&'a self) -> Box> + 'a> { Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { RoomId::parse( utils::string_from_bytes(&bytes) @@ -39,11 +48,11 @@ impl Data for KeyValueDatabase { })) } - fn is_disabled(&self, room_id: &RoomId) -> Result { + pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result { Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) } - fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { if disabled { self.disabledroomids.insert(room_id.as_bytes(), &[])?; } else { @@ -53,9 +62,11 @@ impl Data for KeyValueDatabase { Ok(()) } - fn is_banned(&self, room_id: &RoomId) -> Result { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } + pub(super) fn is_banned(&self, room_id: &RoomId) -> Result { + Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) + } - fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { + pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { if banned { self.bannedroomids.insert(room_id.as_bytes(), &[])?; } else { @@ -65,7 +76,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { + pub(super) fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { Box::new(self.bannedroomids.iter().map( |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { let room_id = utils::string_from_bytes(&room_id_bytes) diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 2a6f7724..21f946b6 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Checks if a room exists. 
#[tracing::instrument(skip(self))] pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index a278161c..f89e9fb4 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -1,15 +1,22 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{CanonicalJsonObject, EventId}; use crate::{Error, KeyValueDatabase, PduEvent, Result}; -pub trait Data: Send + Sync { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; - fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; +pub struct Data { + eventid_outlierpdu: Arc, } -impl Data for KeyValueDatabase { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + eventid_outlierpdu: db.eventid_outlierpdu.clone(), + } + } + + pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? .map_or(Ok(None), |pdu| { @@ -17,7 +24,7 @@ impl Data for KeyValueDatabase { }) } - fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { + pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? .map_or(Ok(None), |pdu| { @@ -25,7 +32,7 @@ impl Data for KeyValueDatabase { }) } - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { self.eventid_outlierpdu.insert( event_id.as_bytes(), &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index 3e8b4ed9..ee1f9466 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Returns the pdu from the outlier tree. 
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu_json(event_id) diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index e0309172..e652c9fe 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,30 +1,33 @@ use std::{mem::size_of, sync::Arc}; +use database::KvTree; use ruma::{EventId, RoomId, UserId}; use crate::{services, utils, Error, KeyValueDatabase, PduCount, PduEvent, Result}; -pub trait Data: Send + Sync { - fn add_relation(&self, from: u64, to: u64) -> Result<()>; - #[allow(clippy::type_complexity)] - fn relations_until<'a>( - &'a self, user_id: &'a UserId, room_id: u64, target: u64, until: PduCount, - ) -> Result> + 'a>>; - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; - fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; - fn is_event_soft_failed(&self, event_id: &EventId) -> Result; +pub(super) struct Data { + tofrom_relation: Arc, + referencedevents: Arc, + softfailedeventids: Arc, } -impl Data for KeyValueDatabase { - fn add_relation(&self, from: u64, to: u64) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + tofrom_relation: db.tofrom_relation.clone(), + referencedevents: db.referencedevents.clone(), + softfailedeventids: db.softfailedeventids.clone(), + } + } + + pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> { let mut key = to.to_be_bytes().to_vec(); key.extend_from_slice(&from.to_be_bytes()); self.tofrom_relation.insert(&key, &[])?; Ok(()) } - fn relations_until<'a>( + pub(super) fn relations_until<'a>( &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, ) -> Result> + 'a>> { let prefix = target.to_be_bytes().to_vec(); @@ -63,7 +66,7 @@ impl Data for KeyValueDatabase { )) } - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(prev.as_bytes()); @@ -73,17 +76,17 @@ impl Data for KeyValueDatabase { Ok(()) } - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { let mut key = room_id.as_bytes().to_vec(); key.extend_from_slice(event_id.as_bytes()); Ok(self.referencedevents.get(&key)?.is_some()) } - fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.softfailedeventids.insert(event_id.as_bytes(), &[]) } - fn is_event_soft_failed(&self, event_id: &EventId) -> Result { + pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.softfailedeventids .get(event_id.as_bytes()) .map(|o| o.is_some()) diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index c4f45848..488cf03e 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -13,7 +16,7 @@ use serde::Deserialize; use crate::{services, PduCount, PduEvent, Result}; pub struct Service { - pub db: Arc, + db: Data, } #[derive(Clone, Debug, Deserialize)] @@ -27,6 +30,12 
@@ struct ExtractRelatesToEventId { } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + #[tracing::instrument(skip(self, from, to))] pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { match (from, to) { diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 702efd98..cf9d0ee3 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -1,5 +1,6 @@ -use std::mem::size_of; +use std::{mem::size_of, sync::Arc}; +use database::KvTree; use ruma::{ events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, serde::Raw, @@ -11,32 +12,22 @@ use crate::{services, utils, Error, KeyValueDatabase, Result}; type AnySyncEphemeralRoomEventIter<'a> = Box)>> + 'a>; -pub trait Data: Send + Sync { - /// Replaces the previous read receipt. - fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>; - - /// Returns an iterator over the most recent read_receipts in a room that - /// happened after the event with id `since`. - fn readreceipts_since(&self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'_>; - - /// Sets a private read marker at `count`. - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; - - /// Returns the private read marker. - /// - /// TODO: use this? - #[allow(dead_code)] - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - - /// Returns the count of the last typing update in this room. - /// - /// TODO: use this? - #[allow(dead_code)] - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; +pub(super) struct Data { + roomuserid_privateread: Arc, + roomuserid_lastprivatereadupdate: Arc, + readreceiptid_readreceipt: Arc, } -impl Data for KeyValueDatabase { - fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + roomuserid_privateread: db.roomuserid_privateread.clone(), + roomuserid_lastprivatereadupdate: db.roomuserid_lastprivatereadupdate.clone(), + readreceiptid_readreceipt: db.readreceiptid_readreceipt.clone(), + } + } + + pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -71,9 +62,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn readreceipts_since<'a>( - &'a self, room_id: &RoomId, since: u64, - ) -> Box)>> + 'a> { + pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); let prefix2 = prefix.clone(); @@ -107,7 +96,7 @@ impl Data for KeyValueDatabase { ) } - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); @@ -119,7 +108,7 @@ impl Data for KeyValueDatabase { .insert(&key, &services().globals.next_count()?.to_be_bytes()) } - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); 
key.extend_from_slice(user_id.as_bytes()); @@ -133,7 +122,7 @@ impl Data for KeyValueDatabase { }) } - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index 9afc1fd2..56401fe4 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserI use crate::{services, Result}; pub struct Service { - pub db: Arc, + db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Replaces the previous read receipt. pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { self.db.readreceipt_update(user_id, room_id, event)?; diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 7fa49989..ff42cc9f 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,30 +1,24 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::RoomId; use crate::{services, utils, KeyValueDatabase, Result}; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; -pub trait Data: Send + Sync { - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; - - fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; - - fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; +pub struct Data { + tokenids: Arc, } -/// Splits a string into tokens used as keys in the search inverted index -/// -/// This may be used to tokenize both message bodies (for indexing) or search -/// queries (for querying). 
-fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
-	body.split_terminator(|c: char| !c.is_alphanumeric())
-		.filter(|s| !s.is_empty())
-		.filter(|word| word.len() <= 50)
-		.map(str::to_lowercase)
-}
+impl Data {
+	pub(super) fn new(db: &Arc<KeyValueDatabase>) -> Self {
+		Self {
+			tokenids: db.tokenids.clone(),
+		}
+	}
 
-impl Data for KeyValueDatabase {
-	fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+	pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
 		let mut batch = tokenize(message_body).map(|word| {
 			let mut key = shortroomid.to_be_bytes().to_vec();
 			key.extend_from_slice(word.as_bytes());
@@ -36,7 +30,7 @@ impl Data for KeyValueDatabase {
 		self.tokenids.insert_batch(&mut batch)
 	}
 
-	fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
+	pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
 		let batch = tokenize(message_body).map(|word| {
 			let mut key = shortroomid.to_be_bytes().to_vec();
 			key.extend_from_slice(word.as_bytes());
@@ -52,7 +46,7 @@ impl Data for KeyValueDatabase {
 		Ok(())
 	}
 
-	fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
+	pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
 		let prefix = services()
 			.rooms
 			.short
@@ -88,3 +82,14 @@ impl Data for KeyValueDatabase {
 		Ok(Some((Box::new(common_elements), words)))
 	}
 }
+
+/// Splits a string into tokens used as keys in the search inverted index
+///
+/// This may be used to tokenize both message bodies (for indexing) or search
+/// queries (for querying).
+fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
+	body.split_terminator(|c: char| !c.is_alphanumeric())
+		.filter(|s| !s.is_empty())
+		.filter(|word| word.len() <= 50)
+		.map(str::to_lowercase)
+}
diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs
index 9e7f14c7..9762c8fc 100644
--- a/src/service/rooms/search/mod.rs
+++ b/src/service/rooms/search/mod.rs
@@ -1,3 +1,6 @@
+use conduit::Server;
+use database::KeyValueDatabase;
+
 mod data;
 
 use std::sync::Arc;
@@ -8,10 +11,16 @@ use ruma::RoomId;
 use crate::Result;
 
 pub struct Service {
-	pub db: Arc<dyn Data>,
+	pub db: Data,
 }
 
 impl Service {
+	pub fn build(_server: &Arc<Server>, db: &Arc<KeyValueDatabase>) -> Result<Self> {
+		Ok(Self {
+			db: Data::new(db),
+		})
+	}
+
 	#[tracing::instrument(skip(self))]
 	pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
 		self.db.index_pdu(shortroomid, pdu_id, message_body)
diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs
index 11ac39b5..13088758 100644
--- a/src/service/rooms/short/data.rs
+++ b/src/service/rooms/short/data.rs
@@ -1,33 +1,33 @@
 use std::sync::Arc;
 
+use conduit::warn;
+use database::{KeyValueDatabase, KvTree};
 use ruma::{events::StateEventType, EventId, RoomId};
-use tracing::warn;
 
-use crate::{services, utils, Error, KeyValueDatabase, Result};
+use crate::{services, utils, Error, Result};
 
-pub trait Data: Send + Sync {
-	fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
-
-	fn multi_get_or_create_shorteventid(&self, event_id: &[&EventId]) -> Result<Vec<u64>>;
-
-	fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>>;
-
-	fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64>;
-
-	fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>;
-
-	fn get_statekey_from_short(&self, shortstatekey: u64) ->
Result<(StateEventType, String)>; - - /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; - - fn get_shortroomid(&self, room_id: &RoomId) -> Result>; - - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; +pub(super) struct Data { + eventid_shorteventid: Arc, + shorteventid_eventid: Arc, + statekey_shortstatekey: Arc, + shortstatekey_statekey: Arc, + roomid_shortroomid: Arc, + statehash_shortstatehash: Arc, } -impl Data for KeyValueDatabase { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + eventid_shorteventid: db.eventid_shorteventid.clone(), + shorteventid_eventid: db.shorteventid_eventid.clone(), + statekey_shortstatekey: db.statekey_shortstatekey.clone(), + shortstatekey_statekey: db.shortstatekey_statekey.clone(), + roomid_shortroomid: db.roomid_shortroomid.clone(), + statehash_shortstatehash: db.statehash_shortstatehash.clone(), + } + } + + pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? { utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? } else { @@ -42,7 +42,7 @@ impl Data for KeyValueDatabase { Ok(short) } - fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { + pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result> { let mut ret: Vec = Vec::with_capacity(event_ids.len()); let keys = event_ids .iter() @@ -74,7 +74,7 @@ impl Data for KeyValueDatabase { Ok(ret) } - fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { + pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); statekey_vec.push(0xFF); statekey_vec.extend_from_slice(state_key.as_bytes()); @@ -90,7 +90,7 @@ impl Data for KeyValueDatabase { Ok(short) } - fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); statekey_vec.push(0xFF); statekey_vec.extend_from_slice(state_key.as_bytes()); @@ -109,7 +109,7 @@ impl Data for KeyValueDatabase { Ok(short) } - fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { let bytes = self .shorteventid_eventid .get(&shorteventid.to_be_bytes())? @@ -124,7 +124,7 @@ impl Data for KeyValueDatabase { Ok(event_id) } - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { let bytes = self .shortstatekey_statekey .get(&shortstatekey.to_be_bytes())? @@ -150,7 +150,7 @@ impl Data for KeyValueDatabase { } /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? 
{ ( utils::u64_from_bytes(&shortstatehash) @@ -165,14 +165,14 @@ impl Data for KeyValueDatabase { }) } - fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.roomid_shortroomid .get(room_id.as_bytes())? .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) .transpose() } - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? { utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? } else { diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 2e994c3c..0cd676a9 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -7,10 +10,16 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::Result; pub struct Service { - pub db: Arc, + db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { self.db.get_or_create_shorteventid(event_id) } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index ffe1935d..e11bade0 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,8 +1,11 @@ use std::{ fmt::{Display, Formatter}, str::FromStr, + sync::Arc, }; +use conduit::Server; +use database::KeyValueDatabase; use lru_cache::LruCache; use ruma::{ api::{ @@ -330,6 +333,16 @@ impl From for SpaceHierarchyRoomsChunk { } impl Service { + pub fn build(server: &Arc, _db: &Arc) -> Result { + let config = &server.config; + Ok(Self { + roomid_spacehierarchy_cache: Mutex::new(LruCache::new( + (f64::from(config.roomid_spacehierarchy_cache_capacity) * config.conduit_cache_capacity_modifier) + as usize, + )), + }) + } + ///Gets the response for the space hierarchy over federation request /// ///Panics if the room does not exist, so a check if the room exists should diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index f8c7f6a6..3a0f93b5 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,39 +1,27 @@ use std::{collections::HashSet, sync::Arc}; use conduit::utils::mutex_map; +use database::KvTree; use ruma::{EventId, OwnedEventId, RoomId}; use crate::{utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - /// Returns the last state hash key added to the db for the given room. - fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; - - /// Set the state hash to a new version, but does not update `state_cache`. - fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()>; - - /// Associates a state with an event. - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; - - /// Returns all events we would send as the `prev_events` of the next event. - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>>; - - /// Replace the forward extremities of the room. 
- fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - _mutex_lock: &mutex_map::Guard<()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()>; +pub struct Data { + shorteventid_shortstatehash: Arc, + roomid_pduleaves: Arc, + roomid_shortstatehash: Arc, } -impl Data for KeyValueDatabase { - fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + shorteventid_shortstatehash: db.shorteventid_shortstatehash.clone(), + roomid_pduleaves: db.roomid_pduleaves.clone(), + roomid_shortstatehash: db.roomid_shortstatehash.clone(), + } + } + + pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -43,7 +31,7 @@ impl Data for KeyValueDatabase { }) } - fn set_room_state( + pub(super) fn set_room_state( &self, room_id: &RoomId, new_shortstatehash: u64, @@ -54,13 +42,13 @@ impl Data for KeyValueDatabase { Ok(()) } - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { self.shorteventid_shortstatehash .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -76,7 +64,7 @@ impl Data for KeyValueDatabase { .collect() } - fn set_forward_extremities( + pub(super) fn set_forward_extremities( &self, room_id: &RoomId, event_ids: Vec, diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 199ee87f..e0d75c80 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{ collections::{HashMap, HashSet}, @@ -22,10 +25,16 @@ use super::state_compressor::CompressedStateEvent; use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Set the room to the given statehash and update caches. pub async fn force_state( &self, diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index a7b39e8c..4b6dd131 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,56 +1,25 @@ use std::{collections::HashMap, sync::Arc}; -use async_trait::async_trait; +use database::KvTree; use ruma::{events::StateEventType, EventId, RoomId}; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; -#[async_trait] -pub trait Data: Send + Sync { - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - #[allow(unused_qualifications)] // async traits - async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; - - #[allow(unused_qualifications)] // async traits - async fn state_full(&self, shortstatehash: u64) -> Result>>; - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). 
- fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>>; - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, - ) -> Result>>; - - /// Returns the state hash for this pdu. - fn pdu_shortstatehash(&self, event_id: &EventId) -> Result>; - - /// Returns the full room state. - #[allow(unused_qualifications)] // async traits - async fn room_state_full(&self, room_id: &RoomId) -> Result>>; - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn room_state_get_id( - &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>>; - - /// Returns a single PDU from `room_id` with key (`event_type`, - /// `state_key`). - fn room_state_get( - &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, - ) -> Result>>; +pub struct Data { + eventid_shorteventid: Arc, + shorteventid_shortstatehash: Arc, } -#[async_trait] -impl Data for KeyValueDatabase { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + eventid_shorteventid: db.eventid_shorteventid.clone(), + shorteventid_shortstatehash: db.shorteventid_shortstatehash.clone(), + } + } + #[allow(unused_qualifications)] // async traits - async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = services() .rooms .state_compressor @@ -76,7 +45,9 @@ impl Data for KeyValueDatabase { } #[allow(unused_qualifications)] // async traits - async fn state_full(&self, shortstatehash: u64) -> Result>> { + pub(super) async fn state_full( + &self, shortstatehash: u64, + ) -> Result>> { let full_state = services() .rooms .state_compressor @@ -116,7 +87,8 @@ impl Data for KeyValueDatabase { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - fn state_get_id( + #[allow(clippy::unused_self)] + pub(super) fn state_get_id( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { let Some(shortstatekey) = services() @@ -148,7 +120,7 @@ impl Data for KeyValueDatabase { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - fn state_get( + pub(super) fn state_get( &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, ) -> Result>> { self.state_get_id(shortstatehash, event_type, state_key)? @@ -156,7 +128,7 @@ impl Data for KeyValueDatabase { } /// Returns the state hash for this pdu. - fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.eventid_shorteventid .get(event_id.as_bytes())? .map_or(Ok(None), |shorteventid| { @@ -173,7 +145,9 @@ impl Data for KeyValueDatabase { /// Returns the full room state. #[allow(unused_qualifications)] // async traits - async fn room_state_full(&self, room_id: &RoomId) -> Result>> { + pub(super) async fn room_state_full( + &self, room_id: &RoomId, + ) -> Result>> { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { self.state_full(current_shortstatehash).await } else { @@ -183,7 +157,7 @@ impl Data for KeyValueDatabase { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). 
- fn room_state_get_id( + pub(super) fn room_state_get_id( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { @@ -195,7 +169,7 @@ impl Data for KeyValueDatabase { /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). - fn room_state_get( + pub(super) fn room_state_get( &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result>> { if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 06293ea3..e1466d75 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,3 +1,8 @@ +use std::sync::Mutex as StdMutex; + +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{ collections::HashMap, @@ -29,12 +34,25 @@ use tracing::{error, warn}; use crate::{service::pdu::PduBuilder, services, Error, PduEvent, Result}; pub struct Service { - pub db: Arc, + pub db: Data, pub server_visibility_cache: Mutex>, pub user_visibility_cache: Mutex>, } impl Service { + pub fn build(server: &Arc, db: &Arc) -> Result { + let config = &server.config; + Ok(Self { + db: Data::new(db), + server_visibility_cache: StdMutex::new(LruCache::new( + (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, + )), + user_visibility_cache: StdMutex::new(LruCache::new( + (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, + )), + }) + } + /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self))] diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index d380c1ab..03c04abb 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -17,101 +17,53 @@ use crate::{ type StrippedStateEventIter<'a> = Box>)>> + 'a>; type AnySyncStateEventIter<'a> = Box>)>> + 'a>; -pub trait Data: Send + Sync { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_invited( - &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, - invite_via: Option>, - ) -> Result<()>; - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; +use std::sync::Arc; - fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; +use database::KvTree; - fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result; - - /// Makes a user forget a room. - fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; - - /// Returns an iterator of all servers participating in this room. - fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result; - - /// Returns an iterator of all rooms a server participates in (as far as we - /// know). - fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a>; - - /// Returns an iterator of all joined members of a room. 
- fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - - /// Returns an iterator of all our local users - /// in the room, even if they're deactivated/guests - fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a>; - - /// Returns an iterator of all our local joined users in a room who are - /// active (not deactivated, not guest) - fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a>; - - fn room_joined_count(&self, room_id: &RoomId) -> Result>; - - fn room_invited_count(&self, room_id: &RoomId) -> Result>; - - /// Returns an iterator over all User IDs who ever joined a room. - /// - /// TODO: use this? - #[allow(dead_code)] - fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - - /// Returns an iterator over all invited members of a room. - fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - - /// Returns an iterator over all rooms this user joined. - fn rooms_joined(&self, user_id: &UserId) -> Box> + '_>; - - /// Returns an iterator over all rooms a user was invited to. - fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a>; - - fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>>; - - fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>>; - - /// Returns an iterator over all rooms a user left. - fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a>; - - fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - /// Gets the servers to either accept or decline invites via for a given - /// room. - fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - - /// Add the given servers the list to accept or decline invites via for a - /// given room. - /// - /// TODO: use this? 
- #[allow(dead_code)] - fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()>; +pub struct Data { + userroomid_joined: Arc, + roomuserid_joined: Arc, + userroomid_invitestate: Arc, + roomuserid_invitecount: Arc, + userroomid_leftstate: Arc, + roomuserid_leftcount: Arc, + roomid_inviteviaservers: Arc, + roomuseroncejoinedids: Arc, + roomid_joinedcount: Arc, + roomid_invitedcount: Arc, + roomserverids: Arc, + serverroomids: Arc, + db: Arc, } -impl Data for KeyValueDatabase { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + userroomid_joined: db.userroomid_joined.clone(), + roomuserid_joined: db.roomuserid_joined.clone(), + userroomid_invitestate: db.userroomid_invitestate.clone(), + roomuserid_invitecount: db.roomuserid_invitecount.clone(), + userroomid_leftstate: db.userroomid_leftstate.clone(), + roomuserid_leftcount: db.roomuserid_leftcount.clone(), + roomid_inviteviaservers: db.roomid_inviteviaservers.clone(), + roomuseroncejoinedids: db.roomuseroncejoinedids.clone(), + roomid_joinedcount: db.roomid_joinedcount.clone(), + roomid_invitedcount: db.roomid_invitedcount.clone(), + roomserverids: db.roomserverids.clone(), + serverroomids: db.serverroomids.clone(), + db: db.clone(), + } + } + + pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); self.roomuseroncejoinedids.insert(&userroom_id, &[]) } - fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -134,7 +86,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn mark_as_invited( + pub(super) fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, invite_via: Option>, ) -> Result<()> { @@ -179,7 +131,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let roomid = room_id.as_bytes().to_vec(); let mut roomuser_id = roomid.clone(); @@ -206,7 +158,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { let mut joinedcount = 0_u64; let mut invitedcount = 0_u64; let mut joined_servers = HashSet::new(); @@ -256,7 +208,8 @@ impl Data for KeyValueDatabase { self.serverroomids.insert(&serverroom_id, &[])?; } - self.appservice_in_room_cache + self.db + .appservice_in_room_cache .write() .unwrap() .remove(room_id); @@ -265,8 +218,9 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { + pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result { let maybe = self + .db .appservice_in_room_cache .read() .unwrap() @@ -288,7 +242,8 @@ impl Data for KeyValueDatabase { .room_members(room_id) .any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str()))); - self.appservice_in_room_cache + self.db + .appservice_in_room_cache .write() .unwrap() 
.entry(room_id.to_owned()) @@ -301,7 +256,7 @@ impl Data for KeyValueDatabase { /// Makes a user forget a room. #[tracing::instrument(skip(self))] - fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -318,7 +273,9 @@ impl Data for KeyValueDatabase { /// Returns an iterator of all servers participating in this room. #[tracing::instrument(skip(self))] - fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + pub(super) fn room_servers<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -336,7 +293,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { let mut key = server.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); @@ -347,7 +304,9 @@ impl Data for KeyValueDatabase { /// Returns an iterator of all rooms a server participates in (as far as we /// know). #[tracing::instrument(skip(self))] - fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a> { + pub(super) fn server_rooms<'a>( + &'a self, server: &ServerName, + ) -> Box> + 'a> { let mut prefix = server.as_bytes().to_vec(); prefix.push(0xFF); @@ -366,7 +325,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator of all joined members of a room. #[tracing::instrument(skip(self))] - fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + pub(super) fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -385,7 +344,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator of all our local users in the room, even if they're /// deactivated/guests - fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { + pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { Box::new( self.room_members(room_id) .filter_map(Result::ok) @@ -396,7 +355,9 @@ impl Data for KeyValueDatabase { /// Returns an iterator of all our local joined users in a room who are /// active (not deactivated, not guest) #[tracing::instrument(skip(self))] - fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box + 'a> { + pub(super) fn active_local_users_in_room<'a>( + &'a self, room_id: &RoomId, + ) -> Box + 'a> { Box::new( self.local_users_in_room(room_id) .filter(|user| !services().users.is_deactivated(user).unwrap_or(true)), @@ -405,7 +366,7 @@ impl Data for KeyValueDatabase { /// Returns the number of users which are currently in a room #[tracing::instrument(skip(self))] - fn room_joined_count(&self, room_id: &RoomId) -> Result> { + pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.roomid_joinedcount .get(room_id.as_bytes())? .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) @@ -414,7 +375,7 @@ impl Data for KeyValueDatabase { /// Returns the number of users which are currently invited to a room #[tracing::instrument(skip(self))] - fn room_invited_count(&self, room_id: &RoomId) -> Result> { + pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.roomid_invitedcount .get(room_id.as_bytes())? 
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) @@ -423,7 +384,9 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self))] - fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + pub(super) fn room_useroncejoined<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -446,7 +409,9 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all invited members of a room. #[tracing::instrument(skip(self))] - fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + pub(super) fn room_members_invited<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -468,7 +433,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); @@ -483,7 +448,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); @@ -496,7 +461,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all rooms this user joined. #[tracing::instrument(skip(self))] - fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { + pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box> + '_> { Box::new( self.userroomid_joined .scan_prefix(user_id.as_bytes().to_vec()) @@ -516,7 +481,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all rooms a user was invited to. #[tracing::instrument(skip(self))] - fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { + pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -543,7 +508,9 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + pub(super) fn invite_state( + &self, user_id: &UserId, room_id: &RoomId, + ) -> Result>>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); @@ -560,7 +527,9 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + pub(super) fn left_state( + &self, user_id: &UserId, room_id: &RoomId, + ) -> Result>>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(room_id.as_bytes()); @@ -578,7 +547,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all rooms a user left. 
#[tracing::instrument(skip(self))] - fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { + pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -605,7 +574,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -614,7 +583,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -623,7 +592,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -632,7 +601,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -641,7 +610,9 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + pub(super) fn servers_invite_via<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { let key = room_id.as_bytes().to_vec(); Box::new( @@ -665,7 +636,7 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { + pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { let mut prev_servers = self .servers_invite_via(room_id) .filter_map(Result::ok) diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index c9d1bbd7..c4eb0dff 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,6 +1,8 @@ use std::sync::Arc; +use conduit::Server; use data::Data; +use database::KeyValueDatabase; use itertools::Itertools; use ruma::{ events::{ @@ -24,10 +26,16 @@ use crate::{service::appservice::RegistrationInfo, services, user_is_local, Erro mod data; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Update current membership data. 
#[tracing::instrument(skip(self, last_state))] #[allow(clippy::too_many_arguments)] diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 4612ebc6..2bf93f38 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,21 +1,28 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; +use database::KvTree; + use super::CompressedStateEvent; use crate::{utils, Error, KeyValueDatabase, Result}; -pub struct StateDiff { - pub parent: Option, - pub added: Arc>, - pub removed: Arc>, +pub(super) struct StateDiff { + pub(super) parent: Option, + pub(super) added: Arc>, + pub(super) removed: Arc>, } -pub trait Data: Send + Sync { - fn get_statediff(&self, shortstatehash: u64) -> Result; - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; +pub struct Data { + shortstatehash_statediff: Arc, } -impl Data for KeyValueDatabase { - fn get_statediff(&self, shortstatehash: u64) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + shortstatehash_statediff: db.shortstatehash_statediff.clone(), + } + } + + pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff .get(&shortstatehash.to_be_bytes())? @@ -53,7 +60,7 @@ impl Data for KeyValueDatabase { }) } - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { + pub(super) fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); for new in diff.added.iter() { value.extend_from_slice(&new[..]); diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 3f025b47..8dde2890 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,3 +1,8 @@ +use std::sync::Mutex as StdMutex; + +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{ collections::HashSet, @@ -41,16 +46,25 @@ type ParentStatesVec = Vec<( )>; type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Service { - pub db: Arc, + pub db: Data, pub stateinfo_cache: StateInfoLruCache, } -pub type CompressedStateEvent = [u8; 2 * size_of::()]; - impl Service { + pub fn build(server: &Arc, db: &Arc) -> Result { + let config = &server.config; + Ok(Self { + db: Data::new(db), + stateinfo_cache: StdMutex::new(LruCache::new( + (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, + )), + }) + } + /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. 
#[tracing::instrument(skip(self))] diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 8ed5c770..a4044294 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,22 +1,24 @@ -use std::mem::size_of; +use std::{mem::size_of, sync::Arc}; +use database::KvTree; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; type PduEventIterResult<'a> = Result> + 'a>>; -pub trait Data: Send + Sync { - fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> PduEventIterResult<'a>; - - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; - fn get_participants(&self, root_id: &[u8]) -> Result>>; +pub struct Data { + threadid_userids: Arc, } -impl Data for KeyValueDatabase { - fn threads_until<'a>( +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + threadid_userids: db.threadid_userids.clone(), + } + } + + pub(super) fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, ) -> PduEventIterResult<'a> { let prefix = services() @@ -50,7 +52,7 @@ impl Data for KeyValueDatabase { )) } - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { let users = participants .iter() .map(|user| user.as_bytes()) @@ -62,7 +64,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn get_participants(&self, root_id: &[u8]) -> Result>> { + pub(super) fn get_participants(&self, root_id: &[u8]) -> Result>> { if let Some(users) = self.threadid_userids.get(root_id)? { Ok(Some( users diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 6c48e842..606cc001 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{collections::BTreeMap, sync::Arc}; @@ -13,10 +16,16 @@ use serde_json::json; use crate::{services, Error, PduEvent, Result}; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn threads_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, ) -> Result> + 'a> { diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 9fb1eea4..8264b6f3 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,76 +1,36 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; +use database::KvTree; use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; use tracing::error; use super::PduCount; use crate::{services, utils, Error, KeyValueDatabase, PduEvent, Result}; -pub trait Data: Send + Sync { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; - - /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result>; - - /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result>; - - /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result>; - - /// Returns the pdu's id. 
- fn get_pdu_id(&self, event_id: &EventId) -> Result>>; - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_pdu(&self, event_id: &EventId) -> Result>>; - - /// Returns the pdu. - /// - /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; - - /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; - - /// Adds a new pdu to the timeline - fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()>; - - // Adds a new pdu to the backfilled timeline - fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()>; - - /// Removes a pdu and creates a new one with the same id. - fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()>; - - /// Returns an iterator over all events and their tokens in a room that - /// happened before the event with id `until` in reverse-chronological - /// order. - #[allow(clippy::type_complexity)] - fn pdus_until<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, - ) -> Result> + 'a>>; - - /// Returns an iterator over all events in a room that happened after the - /// event with id `from` in chronological order. - #[allow(clippy::type_complexity)] - fn pdus_after<'a>( - &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, - ) -> Result> + 'a>>; - - fn increment_notification_counts( - &self, room_id: &RoomId, notifies: Vec, highlights: Vec, - ) -> Result<()>; +pub struct Data { + eventid_pduid: Arc, + pduid_pdu: Arc, + eventid_outlierpdu: Arc, + userroomid_notificationcount: Arc, + userroomid_highlightcount: Arc, + db: Arc, } -impl Data for KeyValueDatabase { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + eventid_pduid: db.eventid_pduid.clone(), + pduid_pdu: db.pduid_pdu.clone(), + eventid_outlierpdu: db.eventid_outlierpdu.clone(), + userroomid_notificationcount: db.userroomid_notificationcount.clone(), + userroomid_highlightcount: db.userroomid_highlightcount.clone(), + db: db.clone(), + } + } + + pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self + .db .lasttimelinecount_cache .lock() .unwrap() @@ -96,7 +56,7 @@ impl Data for KeyValueDatabase { } /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result> { + pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pdu_id| pdu_count(&pdu_id)) @@ -104,7 +64,7 @@ impl Data for KeyValueDatabase { } /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result> { + pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.get_non_outlier_pdu_json(event_id)?.map_or_else( || { self.eventid_outlierpdu @@ -117,7 +77,7 @@ impl Data for KeyValueDatabase { } /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? 
.map(|pduid| { @@ -131,10 +91,12 @@ impl Data for KeyValueDatabase { } /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } + pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + self.eventid_pduid.get(event_id.as_bytes()) + } /// Returns the pdu. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pduid| { @@ -150,7 +112,7 @@ impl Data for KeyValueDatabase { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_pdu(&self, event_id: &EventId) -> Result>> { + pub(super) fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(pdu) = self .get_non_outlier_pdu(event_id)? .map_or_else( @@ -173,7 +135,7 @@ impl Data for KeyValueDatabase { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, @@ -182,7 +144,7 @@ impl Data for KeyValueDatabase { } /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, @@ -190,13 +152,16 @@ impl Data for KeyValueDatabase { }) } - fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> { + pub(super) fn append_pdu( + &self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64, + ) -> Result<()> { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), )?; - self.lasttimelinecount_cache + self.db + .lasttimelinecount_cache .lock() .unwrap() .insert(pdu.room_id.clone(), PduCount::Normal(count)); @@ -207,7 +172,9 @@ impl Data for KeyValueDatabase { Ok(()) } - fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> { + pub(super) fn prepend_backfill_pdu( + &self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject, + ) -> Result<()> { self.pduid_pdu.insert( pdu_id, &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), @@ -220,7 +187,7 @@ impl Data for KeyValueDatabase { } /// Removes a pdu and creates a new one with the same id. - fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { + pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(pdu_id)?.is_some() { self.pduid_pdu.insert( pdu_id, @@ -236,7 +203,7 @@ impl Data for KeyValueDatabase { /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. 
- fn pdus_until<'a>( + pub(super) fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, ) -> Result> + 'a>> { let (prefix, current) = count_to_id(room_id, until, 1, true)?; @@ -260,7 +227,7 @@ impl Data for KeyValueDatabase { )) } - fn pdus_after<'a>( + pub(super) fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, ) -> Result> + 'a>> { let (prefix, current) = count_to_id(room_id, from, 1, false)?; @@ -284,7 +251,7 @@ impl Data for KeyValueDatabase { )) } - fn increment_notification_counts( + pub(super) fn increment_notification_counts( &self, room_id: &RoomId, notifies: Vec, highlights: Vec, ) -> Result<()> { let mut notifies_batch = Vec::new(); @@ -311,7 +278,7 @@ impl Data for KeyValueDatabase { } /// Returns the `count` of this pdu's id. -fn pdu_count(pdu_id: &[u8]) -> Result { +pub(super) fn pdu_count(pdu_id: &[u8]) -> Result { let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; let second_last_u64 = @@ -324,7 +291,9 @@ fn pdu_count(pdu_id: &[u8]) -> Result { } } -fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec, Vec)> { +pub(super) fn count_to_id( + room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, +) -> Result<(Vec, Vec)> { let prefix = services() .rooms .short diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 7712ab90..994b5cd4 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::{ @@ -74,12 +77,19 @@ struct ExtractBody { } pub struct Service { - pub db: Arc, + pub db: Data, pub lasttimelinecount_cache: Mutex>, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }) + } + #[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index 0b89dfdd..614d5af2 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -1,5 +1,7 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; +use conduit::Server; +use database::KeyValueDatabase; use ruma::{ api::federation::transactions::edu::{Edu, TypingContent}, events::SyncEphemeralRoomEvent, @@ -23,6 +25,14 @@ pub struct Service { } impl Service { + pub fn build(_server: &Arc, _db: &Arc) -> Result { + Ok(Self { + typing: RwLock::new(BTreeMap::new()), + last_typing_update: RwLock::new(BTreeMap::new()), + typing_update_sender: broadcast::channel(100).0, + }) + } + /// Sets a user as typing until the timeout timestamp is reached or /// roomtyping_remove is called. 
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 3e8587da..e9339917 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,28 +1,30 @@ +use std::sync::Arc; + +use database::KvTree; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{services, utils, Error, KeyValueDatabase, Result}; -pub trait Data: Send + Sync { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - // Returns the count at which the last reset_notification_counts was called - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result; - - fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()>; - - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; - - fn get_shared_rooms<'a>( - &'a self, users: Vec, - ) -> Result> + 'a>>; +pub struct Data { + userroomid_notificationcount: Arc, + userroomid_highlightcount: Arc, + roomuserid_lastnotificationread: Arc, + roomsynctoken_shortstatehash: Arc, + userroomid_joined: Arc, } -impl Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + userroomid_notificationcount: db.userroomid_notificationcount.clone(), + userroomid_highlightcount: db.userroomid_highlightcount.clone(), + roomuserid_lastnotificationread: db.roomuserid_lastnotificationread.clone(), + roomsynctoken_shortstatehash: db.roomsynctoken_shortstatehash.clone(), + userroomid_joined: db.userroomid_joined.clone(), + } + } + + pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -41,7 +43,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -53,7 +55,7 @@ impl Data for KeyValueDatabase { }) } - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xFF); userroom_id.extend_from_slice(room_id.as_bytes()); @@ -65,7 +67,7 @@ impl Data for KeyValueDatabase { }) } - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { let mut key = room_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(user_id.as_bytes()); @@ -81,7 +83,9 @@ impl Data for KeyValueDatabase { .unwrap_or(0)) } - fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + pub(super) fn associate_token_shortstatehash( + &self, room_id: &RoomId, token: u64, shortstatehash: u64, + ) -> Result<()> { let shortroomid = services() .rooms .short @@ -95,7 +99,7 @@ impl Data for 
KeyValueDatabase { .insert(&key, &shortstatehash.to_be_bytes()) } - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { let shortroomid = services() .rooms .short @@ -114,7 +118,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn get_shared_rooms<'a>( + pub(super) fn get_shared_rooms<'a>( &'a self, users: Vec, ) -> Result> + 'a>> { let iterators = users.into_iter().map(move |user_id| { diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index e589a444..36e78982 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::Result; pub struct Service { - pub db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { self.db.reset_notification_counts(user_id, room_id) } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 9057c603..13ad4d6e 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,31 +1,31 @@ +use std::sync::Arc; + use ruma::{ServerName, UserId}; use super::{Destination, SendingEvent}; -use crate::{services, utils, Error, KeyValueDatabase, Result}; +use crate::{services, utils, Error, KeyValueDatabase, KvTree, Result}; type OutgoingSendingIter<'a> = Box, Destination, SendingEvent)>> + 'a>; type SendingEventIter<'a> = Box, SendingEvent)>> + 'a>; -pub trait Data: Send + Sync { - fn active_requests(&self) -> OutgoingSendingIter<'_>; - fn active_requests_for(&self, destination: &Destination) -> SendingEventIter<'_>; - fn delete_active_request(&self, key: Vec) -> Result<()>; - fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()>; - - /// TODO: use this? 
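The rooms::user hunks above keep building composite keys as user id, a 0xFF separator, then room id (and the reverse order for the roomuserid_* tree). A small self-contained sketch of that layout, using plain &str in place of UserId/RoomId; 0xFF never occurs in UTF-8 text, which is what keeps the per-user prefix scans unambiguous.

// Key layout used by the per-user room counters in these hunks:
// user_id bytes ++ 0xFF ++ room_id bytes.
fn userroom_key(user_id: &str, room_id: &str) -> Vec<u8> {
    let mut key = user_id.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(room_id.as_bytes());
    key
}

fn main() {
    let key = userroom_key("@alice:example.org", "!room:example.org");
    // Everything belonging to one user shares the `user_id ++ 0xFF` prefix.
    let mut prefix = b"@alice:example.org".to_vec();
    prefix.push(0xFF);
    assert!(key.starts_with(&prefix));
}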
- #[allow(dead_code)] - fn delete_all_requests_for(&self, destination: &Destination) -> Result<()>; - fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>>; - fn queued_requests<'a>( - &'a self, destination: &Destination, - ) -> Box)>> + 'a>; - fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()>; - fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; - fn get_latest_educount(&self, server_name: &ServerName) -> Result; +pub struct Data { + servercurrentevent_data: Arc, + servernameevent_data: Arc, + servername_educount: Arc, + _db: Arc, } -impl Data for KeyValueDatabase { - fn active_requests<'a>(&'a self) -> Box, Destination, SendingEvent)>> + 'a> { +impl Data { + pub(super) fn new(db: Arc) -> Self { + Self { + servercurrentevent_data: db.servercurrentevent_data.clone(), + servernameevent_data: db.servernameevent_data.clone(), + servername_educount: db.servername_educount.clone(), + _db: db, + } + } + + pub fn active_requests(&self) -> OutgoingSendingIter<'_> { Box::new( self.servercurrentevent_data .iter() @@ -33,9 +33,7 @@ impl Data for KeyValueDatabase { ) } - fn active_requests_for<'a>( - &'a self, destination: &Destination, - ) -> Box, SendingEvent)>> + 'a> { + pub fn active_requests_for<'a>(&'a self, destination: &Destination) -> SendingEventIter<'a> { let prefix = destination.get_prefix(); Box::new( self.servercurrentevent_data @@ -44,9 +42,9 @@ impl Data for KeyValueDatabase { ) } - fn delete_active_request(&self, key: Vec) -> Result<()> { self.servercurrentevent_data.remove(&key) } + pub(super) fn delete_active_request(&self, key: &[u8]) -> Result<()> { self.servercurrentevent_data.remove(key) } - fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { + pub(super) fn delete_all_active_requests_for(&self, destination: &Destination) -> Result<()> { let prefix = destination.get_prefix(); for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { self.servercurrentevent_data.remove(&key)?; @@ -55,7 +53,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { + pub(super) fn delete_all_requests_for(&self, destination: &Destination) -> Result<()> { let prefix = destination.get_prefix(); for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { self.servercurrentevent_data.remove(&key).unwrap(); @@ -68,7 +66,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { + pub(super) fn queue_requests(&self, requests: &[(&Destination, SendingEvent)]) -> Result>> { let mut batch = Vec::new(); let mut keys = Vec::new(); for (destination, event) in requests { @@ -91,7 +89,7 @@ impl Data for KeyValueDatabase { Ok(keys) } - fn queued_requests<'a>( + pub fn queued_requests<'a>( &'a self, destination: &Destination, ) -> Box)>> + 'a> { let prefix = destination.get_prefix(); @@ -102,7 +100,7 @@ impl Data for KeyValueDatabase { ); } - fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { + pub(super) fn mark_as_active(&self, events: &[(SendingEvent, Vec)]) -> Result<()> { for (e, key) in events { if key.is_empty() { continue; @@ -120,12 +118,12 @@ impl Data for KeyValueDatabase { Ok(()) } - fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + pub(super) fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { self.servername_educount 
.insert(server_name.as_bytes(), &last_count.to_be_bytes()) } - fn get_latest_educount(&self, server_name: &ServerName) -> Result { + pub fn get_latest_educount(&self, server_name: &ServerName) -> Result { self.servername_educount .get(server_name.as_bytes())? .map_or(Ok(0), |bytes| { diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 225d37c7..b9a0144f 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,3 +1,5 @@ +use conduit::Server; + mod appservice; mod data; pub mod resolve; @@ -15,10 +17,10 @@ use ruma::{ use tokio::{sync::Mutex, task::JoinHandle}; use tracing::{error, warn}; -use crate::{server_is_ours, services, Config, Error, Result}; +use crate::{server_is_ours, services, Error, KeyValueDatabase, Result}; pub struct Service { - pub db: Arc, + pub db: Data, /// The state for a given state hash. sender: loole::Sender, @@ -51,16 +53,17 @@ pub enum SendingEvent { } impl Service { - pub fn build(db: Arc, config: &Config) -> Arc { + pub fn build(server: &Arc, db: &Arc) -> Result> { + let config = &server.config; let (sender, receiver) = loole::unbounded(); - Arc::new(Self { - db, + Ok(Arc::new(Self { + db: Data::new(db.clone()), sender, receiver: Mutex::new(receiver), handler_join: Mutex::new(None), startup_netburst: config.startup_netburst, startup_netburst_keep: config.startup_netburst_keep, - }) + })) } pub async fn close(&self) { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 8479a869..7edb4db5 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -144,7 +144,7 @@ impl Service { if self.startup_netburst_keep >= 0 && entry.len() >= keep { warn!("Dropping unsent event {:?} {:?}", dest, String::from_utf8_lossy(&key)); self.db - .delete_active_request(key) + .delete_active_request(&key) .expect("active request deleted"); } else { entry.push(event); diff --git a/src/service/services.rs b/src/service/services.rs index 7629b207..6f07c1fe 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,12 +1,7 @@ -use std::{ - collections::{BTreeMap, HashMap}, - sync::{Arc, Mutex as StdMutex}, -}; +use std::sync::Arc; use conduit::{debug_info, Result, Server}; use database::KeyValueDatabase; -use lru_cache::LruCache; -use tokio::sync::{broadcast, Mutex, RwLock}; use tracing::{debug, info, trace}; use crate::{ @@ -15,9 +10,9 @@ use crate::{ }; pub struct Services { + pub rooms: rooms::Service, pub appservice: appservice::Service, pub pusher: pusher::Service, - pub rooms: rooms::Service, pub transaction_ids: transaction_ids::Service, pub uiaa: uiaa::Service, pub users: users::Service, @@ -34,111 +29,41 @@ pub struct Services { impl Services { pub async fn build(server: Arc, db: Arc) -> Result { - let config = &server.config; Ok(Self { - appservice: appservice::Service::build(db.clone())?, - pusher: pusher::Service { - db: db.clone(), - }, rooms: rooms::Service { - alias: rooms::alias::Service { - db: db.clone(), - }, - auth_chain: rooms::auth_chain::Service { - db: db.clone(), - }, - directory: rooms::directory::Service { - db: db.clone(), - }, - event_handler: rooms::event_handler::Service, - lazy_loading: rooms::lazy_loading::Service { - db: db.clone(), - lazy_load_waiting: Mutex::new(HashMap::new()), - }, - metadata: rooms::metadata::Service { - db: db.clone(), - }, - outlier: rooms::outlier::Service { - db: db.clone(), - }, - pdu_metadata: rooms::pdu_metadata::Service { - db: db.clone(), - }, - read_receipt: rooms::read_receipt::Service { - db: db.clone(), - }, - 
search: rooms::search::Service { - db: db.clone(), - }, - short: rooms::short::Service { - db: db.clone(), - }, - state: rooms::state::Service { - db: db.clone(), - }, - state_accessor: rooms::state_accessor::Service { - db: db.clone(), - server_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.server_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), - user_visibility_cache: StdMutex::new(LruCache::new( - (f64::from(config.user_visibility_cache_capacity) * config.conduit_cache_capacity_modifier) - as usize, - )), - }, - state_cache: rooms::state_cache::Service { - db: db.clone(), - }, - state_compressor: rooms::state_compressor::Service { - db: db.clone(), - stateinfo_cache: StdMutex::new(LruCache::new( - (f64::from(config.stateinfo_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, - )), - }, - timeline: rooms::timeline::Service { - db: db.clone(), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }, - threads: rooms::threads::Service { - db: db.clone(), - }, - typing: rooms::typing::Service { - typing: RwLock::new(BTreeMap::new()), - last_typing_update: RwLock::new(BTreeMap::new()), - typing_update_sender: broadcast::channel(100).0, - }, - spaces: rooms::spaces::Service { - roomid_spacehierarchy_cache: Mutex::new(LruCache::new( - (f64::from(config.roomid_spacehierarchy_cache_capacity) - * config.conduit_cache_capacity_modifier) as usize, - )), - }, - user: rooms::user::Service { - db: db.clone(), - }, + alias: rooms::alias::Service::build(&server, &db)?, + auth_chain: rooms::auth_chain::Service::build(&server, &db)?, + directory: rooms::directory::Service::build(&server, &db)?, + event_handler: rooms::event_handler::Service::build(&server, &db)?, + lazy_loading: rooms::lazy_loading::Service::build(&server, &db)?, + metadata: rooms::metadata::Service::build(&server, &db)?, + outlier: rooms::outlier::Service::build(&server, &db)?, + pdu_metadata: rooms::pdu_metadata::Service::build(&server, &db)?, + read_receipt: rooms::read_receipt::Service::build(&server, &db)?, + search: rooms::search::Service::build(&server, &db)?, + short: rooms::short::Service::build(&server, &db)?, + state: rooms::state::Service::build(&server, &db)?, + state_accessor: rooms::state_accessor::Service::build(&server, &db)?, + state_cache: rooms::state_cache::Service::build(&server, &db)?, + state_compressor: rooms::state_compressor::Service::build(&server, &db)?, + timeline: rooms::timeline::Service::build(&server, &db)?, + threads: rooms::threads::Service::build(&server, &db)?, + typing: rooms::typing::Service::build(&server, &db)?, + spaces: rooms::spaces::Service::build(&server, &db)?, + user: rooms::user::Service::build(&server, &db)?, }, - transaction_ids: transaction_ids::Service { - db: db.clone(), - }, - uiaa: uiaa::Service { - db: db.clone(), - }, - users: users::Service { - db: db.clone(), - connections: StdMutex::new(BTreeMap::new()), - }, - account_data: account_data::Service { - db: db.clone(), - }, - presence: presence::Service::build(db.clone(), config), - admin: admin::Service::build(), - key_backups: key_backups::Service { - db: db.clone(), - }, - media: media::Service::build(&server, &db), - sending: sending::Service::build(db.clone(), config), - globals: globals::Service::load(db.clone(), config)?, + appservice: appservice::Service::build(&server, &db)?, + pusher: pusher::Service::build(&server, &db)?, + transaction_ids: transaction_ids::Service::build(&server, &db)?, + uiaa: uiaa::Service::build(&server, &db)?, + users: 
users::Service::build(&server, &db)?, + account_data: account_data::Service::build(&server, &db)?, + presence: presence::Service::build(&server, &db)?, + admin: admin::Service::build(&server, &db)?, + key_backups: key_backups::Service::build(&server, &db)?, + media: media::Service::build(&server, &db)?, + sending: sending::Service::build(&server, &db)?, + globals: globals::Service::build(&server, &db)?, server, db, }) diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 8ea3b8fd..5b29f211 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,19 +1,21 @@ +use std::sync::Arc; + +use conduit::Result; +use database::{KeyValueDatabase, KvTree}; use ruma::{DeviceId, TransactionId, UserId}; -use crate::{KeyValueDatabase, Result}; - -pub(crate) trait Data: Send + Sync { - fn add_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], - ) -> Result<()>; - - fn existing_txnid( - &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, - ) -> Result>>; +pub struct Data { + userdevicetxnid_response: Arc, } -impl Data for KeyValueDatabase { - fn add_txnid( +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + userdevicetxnid_response: db.userdevicetxnid_response.clone(), + } + } + + pub(super) fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); @@ -27,7 +29,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn existing_txnid( + pub(super) fn existing_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, ) -> Result>> { let mut key = user_id.as_bytes().to_vec(); diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index e986f0ac..5fb85d4a 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -8,10 +11,16 @@ use ruma::{DeviceId, TransactionId, UserId}; use crate::Result; pub struct Service { - pub(super) db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + pub fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], ) -> Result<()> { diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index d6d745bc..b0710b6a 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,29 +1,30 @@ +use std::sync::Arc; + +use conduit::{Error, Result}; +use database::{KeyValueDatabase, KvTree}; use ruma::{ api::client::{error::ErrorKind, uiaa::UiaaInfo}, CanonicalJsonValue, DeviceId, UserId, }; -use crate::{Error, KeyValueDatabase, Result}; - -pub(crate) trait Data: Send + Sync { - fn set_uiaa_request( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, - ) -> Result<()>; - - fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option; - - fn update_uiaa_session( - &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, - ) -> Result<()>; - - fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result; +pub struct Data { + userdevicesessionid_uiaainfo: Arc, + db: Arc, } -impl Data for KeyValueDatabase { - fn set_uiaa_request( +impl Data { + pub(super) fn 
new(db: &Arc) -> Self { + Self { + userdevicesessionid_uiaainfo: db.userdevicesessionid_uiaainfo.clone(), + db: db.clone(), + } + } + + pub(super) fn set_uiaa_request( &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, ) -> Result<()> { - self.userdevicesessionid_uiaarequest + self.db + .userdevicesessionid_uiaarequest .write() .unwrap() .insert( @@ -34,15 +35,18 @@ impl Data for KeyValueDatabase { Ok(()) } - fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option { - self.userdevicesessionid_uiaarequest + pub(super) fn get_uiaa_request( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, + ) -> Option { + self.db + .userdevicesessionid_uiaarequest .read() .unwrap() .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) .map(ToOwned::to_owned) } - fn update_uiaa_session( + pub(super) fn update_uiaa_session( &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, ) -> Result<()> { let mut userdevicesessionid = user_id.as_bytes().to_vec(); @@ -64,7 +68,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { + pub(super) fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { let mut userdevicesessionid = user_id.as_bytes().to_vec(); userdevicesessionid.push(0xFF); userdevicesessionid.extend_from_slice(device_id.as_bytes()); diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index bec6c15b..4539c762 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,3 +1,6 @@ +use conduit::Server; +use database::KeyValueDatabase; + mod data; use std::sync::Arc; @@ -18,10 +21,16 @@ use crate::services; pub const SESSION_ID_LENGTH: usize = 32; pub struct Service { - pub(super) db: Arc, + pub db: Data, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db), + }) + } + /// Creates a new Uiaa session. Make sure the session token is unique. pub fn create( &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, diff --git a/src/service/users/data.rs b/src/service/users/data.rs index 5d3eadd8..5644b3e3 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, mem::size_of}; +use std::{collections::BTreeMap, mem::size_of, sync::Arc}; use ruma::{ api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, @@ -10,149 +10,60 @@ use ruma::{ }; use tracing::warn; -use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, Result}; +use crate::{services, users::clean_signatures, utils, Error, KeyValueDatabase, KvTree, Result}; -pub trait Data: Send + Sync { - /// Check if a user has an account on this homeserver. - fn exists(&self, user_id: &UserId) -> Result; - - /// Check if account is deactivated - fn is_deactivated(&self, user_id: &UserId) -> Result; - - /// Returns the number of users registered on this server. - fn count(&self) -> Result; - - /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result>; - - /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a>; - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is - /// greater then zero. 
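In the uiaa hunks above, the new Data clones the userdevicesessionid_uiaainfo tree but also stores the whole database handle, because userdevicesessionid_uiaarequest is an in-memory RwLock map on KeyValueDatabase rather than a tree. A sketch of that split, with stand-in types (String replacing CanonicalJsonValue) and simplified keys:

use std::{
    collections::BTreeMap,
    sync::{Arc, RwLock},
};

pub trait KvTree: Send + Sync {}
type CanonicalJsonValue = String; // stand-in for ruma's CanonicalJsonValue

pub struct KeyValueDatabase {
    pub userdevicesessionid_uiaainfo: Arc<dyn KvTree>,
    // Not a tree: an in-memory map that still lives on the database struct,
    // which is why Data below keeps a full handle to reach it.
    pub userdevicesessionid_uiaarequest:
        RwLock<BTreeMap<(String, String, String), CanonicalJsonValue>>,
}

#[allow(dead_code)]
pub struct Data {
    userdevicesessionid_uiaainfo: Arc<dyn KvTree>,
    db: Arc<KeyValueDatabase>,
}

impl Data {
    pub fn new(db: &Arc<KeyValueDatabase>) -> Self {
        Self {
            userdevicesessionid_uiaainfo: db.userdevicesessionid_uiaainfo.clone(),
            db: db.clone(),
        }
    }

    // Reads go through the retained handle because the request map is not a KvTree.
    pub fn get_uiaa_request(&self, user: &str, device: &str, session: &str) -> Option<CanonicalJsonValue> {
        self.db
            .userdevicesessionid_uiaarequest
            .read()
            .unwrap()
            .get(&(user.to_owned(), device.to_owned(), session.to_owned()))
            .cloned()
    }
}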
- fn list_local_users(&self) -> Result>; - - /// Returns the password hash for the given user. - fn password_hash(&self, user_id: &UserId) -> Result>; - - /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; - - /// Returns the displayname of a user on this homeserver. - fn displayname(&self, user_id: &UserId) -> Result>; - - /// Sets a new displayname or removes it if displayname is None. You still - /// need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; - - /// Get the avatar_url of a user. - fn avatar_url(&self, user_id: &UserId) -> Result>; - - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()>; - - /// Get the blurhash of a user. - fn blurhash(&self, user_id: &UserId) -> Result>; - - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; - - /// Adds a new device to a user. - fn create_device( - &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, - ) -> Result<()>; - - /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - - /// Returns an iterator over all device ids of this user. - fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a>; - - /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - - fn add_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()>; - - fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - - fn take_one_time_key( - &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>>; - - fn count_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId) - -> Result>; - - fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()>; - - fn add_cross_signing_keys( - &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, - user_signing_key: &Option>, notify: bool, - ) -> Result<()>; - - fn sign_key(&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId) - -> Result<()>; - - fn keys_changed<'a>( - &'a self, user_or_room_id: &str, from: u64, to: Option, - ) -> Box> + 'a>; - - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; - - fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>>; - - fn parse_master_key( - &self, user_id: &UserId, master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)>; - - fn get_key( - &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; - - fn get_master_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; - - fn get_self_signing_key( - &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; - - fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - - fn add_to_device_event( - &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, - content: serde_json::Value, - ) -> Result<()>; - - fn 
get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>>; - - fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()>; - - fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()>; - - /// Get device metadata. - fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result>; - - fn get_devicelist_version(&self, user_id: &UserId) -> Result>; - - fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a>; - - /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result; - - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; +pub struct Data { + userid_password: Arc, + token_userdeviceid: Arc, + userid_displayname: Arc, + userid_avatarurl: Arc, + userid_blurhash: Arc, + userid_devicelistversion: Arc, + userdeviceid_token: Arc, + userdeviceid_metadata: Arc, + onetimekeyid_onetimekeys: Arc, + userid_lastonetimekeyupdate: Arc, + keyid_key: Arc, + userid_masterkeyid: Arc, + userid_selfsigningkeyid: Arc, + userid_usersigningkeyid: Arc, + keychangeid_userid: Arc, + todeviceid_events: Arc, + userfilterid_filter: Arc, + _db: Arc, } -impl Data for KeyValueDatabase { +impl Data { + pub(super) fn new(db: Arc) -> Self { + Self { + userid_password: db.userid_password.clone(), + token_userdeviceid: db.token_userdeviceid.clone(), + userid_displayname: db.userid_displayname.clone(), + userid_avatarurl: db.userid_avatarurl.clone(), + userid_blurhash: db.userid_blurhash.clone(), + userid_devicelistversion: db.userid_devicelistversion.clone(), + userdeviceid_token: db.userdeviceid_token.clone(), + userdeviceid_metadata: db.userdeviceid_metadata.clone(), + onetimekeyid_onetimekeys: db.onetimekeyid_onetimekeys.clone(), + userid_lastonetimekeyupdate: db.userid_lastonetimekeyupdate.clone(), + keyid_key: db.keyid_key.clone(), + userid_masterkeyid: db.userid_masterkeyid.clone(), + userid_selfsigningkeyid: db.userid_selfsigningkeyid.clone(), + userid_usersigningkeyid: db.userid_usersigningkeyid.clone(), + keychangeid_userid: db.keychangeid_userid.clone(), + todeviceid_events: db.todeviceid_events.clone(), + userfilterid_filter: db.userfilterid_filter.clone(), + _db: db, + } + } + /// Check if a user has an account on this homeserver. - fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } + pub(super) fn exists(&self, user_id: &UserId) -> Result { + Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) + } /// Check if account is deactivated - fn is_deactivated(&self, user_id: &UserId) -> Result { + pub(super) fn is_deactivated(&self, user_id: &UserId) -> Result { Ok(self .userid_password .get(user_id.as_bytes())? @@ -161,10 +72,10 @@ impl Data for KeyValueDatabase { } /// Returns the number of users registered on this server. - fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } + pub(super) fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result> { + pub(super) fn find_from_token(&self, token: &str) -> Result> { self.token_userdeviceid .get(token.as_bytes())? .map_or(Ok(None), |bytes| { @@ -189,7 +100,7 @@ impl Data for KeyValueDatabase { } /// Returns an iterator over all users on this homeserver. 
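The users Data above clones every tree it touches and parks the database handle as _db purely to keep it alive. A compilable sketch of the two simplest reads, exists and is_deactivated; the deactivation check follows this file's convention that an empty stored password means the account has no usable local login, and error handling is simplified relative to the real code.

use std::sync::Arc;

// Stand-in tree exposing only the read used here; std::io::Result
// stands in for the crate's Result type.
pub trait KvTree: Send + Sync {
    fn get(&self, key: &[u8]) -> std::io::Result<Option<Vec<u8>>>;
}

pub struct Data {
    userid_password: Arc<dyn KvTree>,
}

impl Data {
    /// A user exists if any entry is stored under the user id.
    pub fn exists(&self, user_id: &str) -> std::io::Result<bool> {
        Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
    }

    /// An empty (or missing) password hash is treated as deactivated here;
    /// the real implementation distinguishes "unknown user" as an error.
    pub fn is_deactivated(&self, user_id: &str) -> std::io::Result<bool> {
        Ok(self
            .userid_password
            .get(user_id.as_bytes())?
            .map_or(true, |bytes| bytes.is_empty()))
    }
}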
- fn iter<'a>(&'a self) -> Box> + 'a> { + pub fn iter<'a>(&'a self) -> Box> + 'a> { Box::new(self.userid_password.iter().map(|(bytes, _)| { UserId::parse( utils::string_from_bytes(&bytes) @@ -203,7 +114,7 @@ impl Data for KeyValueDatabase { /// /// A user account is considered `local` if the length of it's password is /// greater then zero. - fn list_local_users(&self) -> Result> { + pub(super) fn list_local_users(&self) -> Result> { let users: Vec = self .userid_password .iter() @@ -213,7 +124,7 @@ impl Data for KeyValueDatabase { } /// Returns the password hash for the given user. - fn password_hash(&self, user_id: &UserId) -> Result> { + pub(super) fn password_hash(&self, user_id: &UserId) -> Result> { self.userid_password .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -224,7 +135,7 @@ impl Data for KeyValueDatabase { } /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + pub(super) fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::hash::password(password) { self.userid_password @@ -243,7 +154,7 @@ impl Data for KeyValueDatabase { } /// Returns the displayname of a user on this homeserver. - fn displayname(&self, user_id: &UserId) -> Result> { + pub(super) fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -256,7 +167,7 @@ impl Data for KeyValueDatabase { /// Sets a new displayname or removes it if displayname is None. You still /// need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + pub(super) fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname .insert(user_id.as_bytes(), displayname.as_bytes())?; @@ -268,7 +179,7 @@ impl Data for KeyValueDatabase { } /// Get the `avatar_url` of a user. - fn avatar_url(&self, user_id: &UserId) -> Result> { + pub(super) fn avatar_url(&self, user_id: &UserId) -> Result> { self.userid_avatarurl .get(user_id.as_bytes())? .map(|bytes| { @@ -283,7 +194,7 @@ impl Data for KeyValueDatabase { } /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + pub(super) fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; @@ -295,7 +206,7 @@ impl Data for KeyValueDatabase { } /// Get the blurhash of a user. - fn blurhash(&self, user_id: &UserId) -> Result> { + pub(super) fn blurhash(&self, user_id: &UserId) -> Result> { self.userid_blurhash .get(user_id.as_bytes())? .map(|bytes| { @@ -307,8 +218,8 @@ impl Data for KeyValueDatabase { .transpose() } - /// Sets a new blurhash or removes it if blurhash is None. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + /// Sets a new avatar_url or removes it if avatar_url is None. + pub(super) fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash .insert(user_id.as_bytes(), blurhash.as_bytes())?; @@ -320,7 +231,7 @@ impl Data for KeyValueDatabase { } /// Adds a new device to a user. 
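The displayname, avatar_url and blurhash setters above all follow one pattern: Some(value) overwrites the entry, None removes it. A distilled sketch of that pattern; the helper name and the trimmed KvTree trait are illustrative, not part of the PR.

use std::sync::Arc;

pub trait KvTree: Send + Sync {
    fn insert(&self, key: &[u8], value: &[u8]) -> std::io::Result<()>;
    fn remove(&self, key: &[u8]) -> std::io::Result<()>;
}

/// The shape shared by set_displayname / set_avatar_url / set_blurhash:
/// Some(value) overwrites the per-user entry, None deletes it.
fn set_or_remove(tree: &Arc<dyn KvTree>, user_id: &str, value: Option<&str>) -> std::io::Result<()> {
    match value {
        Some(v) => tree.insert(user_id.as_bytes(), v.as_bytes()),
        None => tree.remove(user_id.as_bytes()),
    }
}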
- fn create_device( + pub(super) fn create_device( &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, ) -> Result<()> { // This method should never be called for nonexistent users. We shouldn't assert @@ -354,7 +265,7 @@ impl Data for KeyValueDatabase { } /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + pub(super) fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -384,7 +295,9 @@ impl Data for KeyValueDatabase { } /// Returns an iterator over all device ids of this user. - fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + pub(super) fn all_device_ids<'a>( + &'a self, user_id: &UserId, + ) -> Box> + 'a> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); // All devices have metadata @@ -405,7 +318,7 @@ impl Data for KeyValueDatabase { } /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + pub(super) fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -436,7 +349,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn add_one_time_key( + pub(super) fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, ) -> Result<()> { @@ -478,7 +391,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + pub(super) fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.userid_lastonetimekeyupdate .get(user_id.as_bytes())? 
.map_or(Ok(0), |bytes| { @@ -487,7 +400,7 @@ impl Data for KeyValueDatabase { }) } - fn take_one_time_key( + pub(super) fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, ) -> Result)>> { let mut prefix = user_id.as_bytes().to_vec(); @@ -521,7 +434,7 @@ impl Data for KeyValueDatabase { .transpose() } - fn count_one_time_keys( + pub(super) fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result> { let mut userdeviceid = user_id.as_bytes().to_vec(); @@ -551,7 +464,9 @@ impl Data for KeyValueDatabase { Ok(counts) } - fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { + pub(super) fn add_device_keys( + &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, + ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -566,7 +481,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn add_cross_signing_keys( + pub(super) fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, notify: bool, ) -> Result<()> { @@ -574,7 +489,7 @@ impl Data for KeyValueDatabase { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); - let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; + let (master_key_key, _) = Self::parse_master_key(user_id, master_key)?; self.keyid_key .insert(&master_key_key, master_key.json().get().as_bytes())?; @@ -647,7 +562,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn sign_key( + pub(super) fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, ) -> Result<()> { let mut key = target_id.as_bytes().to_vec(); @@ -685,7 +600,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn keys_changed<'a>( + pub(super) fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option, ) -> Box> + 'a> { let mut prefix = user_or_room_id.as_bytes().to_vec(); @@ -724,7 +639,7 @@ impl Data for KeyValueDatabase { ) } - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { let count = services().globals.next_count()?.to_be_bytes(); for room_id in services() .rooms @@ -757,7 +672,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + pub(super) fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(device_id.as_bytes()); @@ -769,8 +684,8 @@ impl Data for KeyValueDatabase { }) } - fn parse_master_key( - &self, user_id: &UserId, master_key: &Raw, + pub(super) fn parse_master_key( + user_id: &UserId, master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); @@ -793,7 +708,7 @@ impl Data for KeyValueDatabase { Ok((master_key_key, master_key)) } - fn get_key( + pub(super) fn get_key( &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { @@ -807,7 +722,7 @@ impl Data for KeyValueDatabase { }) } - fn get_master_key( + pub(super) fn get_master_key( &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { 
self.userid_masterkeyid @@ -815,7 +730,7 @@ impl Data for KeyValueDatabase { .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) } - fn get_self_signing_key( + pub(super) fn get_self_signing_key( &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result>> { self.userid_selfsigningkeyid @@ -823,7 +738,7 @@ impl Data for KeyValueDatabase { .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) } - fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + pub(super) fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { self.userid_usersigningkeyid .get(user_id.as_bytes())? .map_or(Ok(None), |key| { @@ -836,7 +751,7 @@ impl Data for KeyValueDatabase { }) } - fn add_to_device_event( + pub(super) fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, ) -> Result<()> { @@ -858,7 +773,9 @@ impl Data for KeyValueDatabase { Ok(()) } - fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + pub(super) fn get_to_device_events( + &self, user_id: &UserId, device_id: &DeviceId, + ) -> Result>> { let mut events = Vec::new(); let mut prefix = user_id.as_bytes().to_vec(); @@ -876,7 +793,7 @@ impl Data for KeyValueDatabase { Ok(events) } - fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { + pub(super) fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xFF); prefix.extend_from_slice(device_id.as_bytes()); @@ -905,7 +822,7 @@ impl Data for KeyValueDatabase { Ok(()) } - fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + pub(super) fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -935,7 +852,7 @@ impl Data for KeyValueDatabase { } /// Get device metadata. - fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + pub(super) fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xFF); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -949,7 +866,7 @@ impl Data for KeyValueDatabase { }) } - fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + pub(super) fn get_devicelist_version(&self, user_id: &UserId) -> Result> { self.userid_devicelistversion .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -959,7 +876,9 @@ impl Data for KeyValueDatabase { }) } - fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + pub(super) fn all_devices_metadata<'a>( + &'a self, user_id: &UserId, + ) -> Box> + 'a> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); @@ -974,7 +893,7 @@ impl Data for KeyValueDatabase { } /// Creates a new sync filter. Returns the filter id. 
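The cross-signing getters above resolve keys in two steps: a per-user pointer tree (userid_masterkeyid, userid_selfsigningkeyid, userid_usersigningkeyid) yields a key id, which is then loaded from keyid_key and filtered through allowed_signatures in get_key. A sketch of that two-step lookup with the signature filtering omitted and stand-in types:

use std::sync::Arc;

pub trait KvTree: Send + Sync {
    fn get(&self, key: &[u8]) -> std::io::Result<Option<Vec<u8>>>;
}
type RawCrossSigningKey = Vec<u8>; // stand-in for Raw<CrossSigningKey>

pub struct Data {
    userid_masterkeyid: Arc<dyn KvTree>,
    keyid_key: Arc<dyn KvTree>,
}

impl Data {
    /// Two-step lookup used by get_master_key and friends:
    /// user id -> key id, then key id -> key body. The real code passes the
    /// loaded key through get_key to strip signatures the caller may not see.
    pub fn get_master_key(&self, user_id: &str) -> std::io::Result<Option<RawCrossSigningKey>> {
        match self.userid_masterkeyid.get(user_id.as_bytes())? {
            None => Ok(None),
            Some(key_id) => self.keyid_key.get(&key_id),
        }
    }
}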
- fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + pub(super) fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { let filter_id = utils::random_string(4); let mut key = user_id.as_bytes().to_vec(); @@ -987,7 +906,7 @@ impl Data for KeyValueDatabase { Ok(filter_id) } - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + pub(super) fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { let mut key = user_id.as_bytes().to_vec(); key.push(0xFF); key.extend_from_slice(filter_id.as_bytes()); @@ -1006,7 +925,7 @@ impl Data for KeyValueDatabase { /// username could be successfully parsed. /// If `utils::string_from_bytes`(...) returns an error that username will be /// skipped and the error will be logged. -fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { +pub(super) fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { // A valid password is not empty if password.is_empty() { None diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index c3de86a0..a5349d97 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,8 +1,10 @@ +use conduit::Server; + mod data; use std::{ collections::{BTreeMap, BTreeSet}, mem, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, Mutex as StdMutex}, }; use data::Data; @@ -22,7 +24,7 @@ use ruma::{ UInt, UserId, }; -use crate::{service, services, Error, Result}; +use crate::{database::KeyValueDatabase, service, services, Error, Result}; pub struct SlidingSyncCache { lists: BTreeMap, @@ -34,11 +36,18 @@ pub struct SlidingSyncCache { type DbConnections = Mutex>>>; pub struct Service { - pub db: Arc, + pub db: Data, pub connections: DbConnections, } impl Service { + pub fn build(_server: &Arc, db: &Arc) -> Result { + Ok(Self { + db: Data::new(db.clone()), + connections: StdMutex::new(BTreeMap::new()), + }) + } + /// Check if a user has an account on this homeserver. pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } @@ -381,7 +390,7 @@ impl Service { pub fn parse_master_key( &self, user_id: &UserId, master_key: &Raw, ) -> Result<(Vec, CrossSigningKey)> { - self.db.parse_master_key(user_id, master_key) + Data::parse_master_key(user_id, master_key) } pub fn get_key(
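The users/mod.rs hunk ends by calling Data::parse_master_key as an associated function, since the refactored version no longer reads anything from self. A sketch of that call-site change with stand-in types and an elided body (the real return type pairs the keyid_key database key with the parsed CrossSigningKey):

// Before the refactor parse_master_key was a trait method and needed a
// receiver; afterwards nothing on Data is read, so it becomes associated.
struct Data;

impl Data {
    fn parse_master_key(user_id: &str, master_key: &str) -> (Vec<u8>, String) {
        // Body elided; sketched here as "derive the key prefix
        // (user_id ++ 0xFF) and hand back the raw key".
        let mut prefix = user_id.as_bytes().to_vec();
        prefix.push(0xFF);
        (prefix, master_key.to_owned())
    }
}

#[allow(dead_code)]
struct Service {
    db: Data,
}

impl Service {
    fn parse_master_key(&self, user_id: &str, master_key: &str) -> (Vec<u8>, String) {
        // Call through the type, not through self.db.
        Data::parse_master_key(user_id, master_key)
    }
}

fn main() {
    let svc = Service { db: Data };
    let (key, _raw) = svc.parse_master_key("@alice:example.org", "{}");
    assert!(key.ends_with(&[0xFF]));
}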