diff --git a/Cargo.lock b/Cargo.lock index d8d791f0..c074c760 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.56" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" +checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f" dependencies = [ "proc-macro2", "quote", @@ -408,6 +408,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" name = "conduit" version = "0.3.0-next" dependencies = [ + "async-trait", "axum", "axum-server", "base64 0.13.0", @@ -422,6 +423,7 @@ dependencies = [ "http", "image", "jsonwebtoken", + "lazy_static", "lru-cache", "num_cpus", "opentelemetry", diff --git a/Cargo.toml b/Cargo.toml index f150c4e7..b88674dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -90,6 +90,8 @@ figment = { version = "0.10.6", features = ["env", "toml"] } tikv-jemalloc-ctl = { version = "0.4.2", features = ["use_std"], optional = true } tikv-jemallocator = { version = "0.4.1", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } +lazy_static = "1.4.0" +async-trait = "0.1.57" [features] default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc"] diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index ce122dad..1f6e2c9d 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,12 +1,11 @@ -use crate::{utils, Error, Result}; +use crate::{utils, Error, Result, services}; use bytes::BytesMut; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; -#[tracing::instrument(skip(globals, request))] +#[tracing::instrument(skip(request))] pub(crate) async fn send_request( - globals: &crate::database::globals::Globals, registration: serde_yaml::Value, request: T, ) -> Result @@ -46,7 +45,7 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); let url = reqwest_request.url().clone(); - let mut response = globals.default_client().execute(reqwest_request).await?; + let mut response = services().globals.default_client().execute(reqwest_request).await?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index dc0782d1..848bfaa7 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -2,9 +2,7 @@ use std::sync::Arc; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ - database::{admin::make_user_admin, DatabaseGuard}, - pdu::PduBuilder, - utils, Database, Error, Result, Ruma, + utils, Error, Result, Ruma, services, }; use ruma::{ api::client::{ @@ -42,15 +40,14 @@ const RANDOM_USER_ID_LENGTH: usize = 10; /// /// Note: This will not reserve the username, so the username might become invalid when trying to register pub async fn get_register_available_route( - db: DatabaseGuard, body: Ruma, ) -> Result { // Validate user id let user_id = - UserId::parse_with_server_name(body.username.to_lowercase(), db.globals.server_name()) + UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) .ok() .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == db.globals.server_name() + !user_id.is_historical() && user_id.server_name() == services().globals.server_name() }) .ok_or(Error::BadRequest( ErrorKind::InvalidUsername, @@ -58,7 +55,7 @@ pub async fn get_register_available_route( ))?; // Check if username is creative enough - if db.users.exists(&user_id)? { + if services().users.exists(&user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -85,10 +82,9 @@ pub async fn get_register_available_route( /// - Creates a new account and populates it with default account data /// - If `inhibit_login` is false: Creates a device and returns device id and access_token pub async fn register_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_registration() && !body.from_appservice { + if !services().globals.allow_registration() && !body.from_appservice { return Err(Error::BadRequest( ErrorKind::Forbidden, "Registration has been disabled.", @@ -100,17 +96,17 @@ pub async fn register_route( let user_id = match (&body.username, is_guest) { (Some(username), false) => { let proposed_user_id = - UserId::parse_with_server_name(username.to_lowercase(), db.globals.server_name()) + UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) .ok() .filter(|user_id| { !user_id.is_historical() - && user_id.server_name() == db.globals.server_name() + && user_id.server_name() == services().globals.server_name() }) .ok_or(Error::BadRequest( ErrorKind::InvalidUsername, "Username is invalid.", ))?; - if db.users.exists(&proposed_user_id)? { + if services().users.exists(&proposed_user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -121,10 +117,10 @@ pub async fn register_route( _ => loop { let proposed_user_id = UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap(); - if !db.users.exists(&proposed_user_id)? { + if !services().users.exists(&proposed_user_id)? { break proposed_user_id; } }, @@ -143,14 +139,12 @@ pub async fn register_route( if !body.from_appservice { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - &UserId::parse_with_server_name("", db.globals.server_name()) + let (worked, uiaainfo) = services().uiaa.try_auth( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -158,8 +152,8 @@ pub async fn register_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa.create( - &UserId::parse_with_server_name("", db.globals.server_name()) + services().uiaa.create( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), &uiaainfo, @@ -178,15 +172,15 @@ pub async fn register_route( }; // Create user - db.users.create(&user_id, password)?; + services().users.create(&user_id, password)?; // Default to pretty displayname let displayname = format!("{} ⚡️", user_id.localpart()); - db.users + services().users .set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data - db.account_data.update( + services().account_data.update( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -195,7 +189,6 @@ pub async fn register_route( global: push::Ruleset::server_default(&user_id), }, }, - &db.globals, )?; // Inhibit login does not work for guests @@ -219,7 +212,7 @@ pub async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -227,7 +220,7 @@ pub async fn register_route( )?; info!("New user {} registered on this server.", user_id); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "New user {} registered on this server.", user_id @@ -235,14 +228,12 @@ pub async fn register_route( // If this is the first real user, grant them admin privileges // Note: the server user, @conduit:servername, is generated first - if db.users.count()? == 2 { - make_user_admin(&db, &user_id, displayname).await?; + if services().users.count()? == 2 { + services().admin.make_user_admin(&user_id, displayname).await?; warn!("Granting {} admin privileges as the first user", user_id); } - db.flush()?; - Ok(register::v3::Response { access_token: Some(token), user_id, @@ -265,7 +256,6 @@ pub async fn register_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn change_password_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -282,13 +272,11 @@ pub async fn change_password_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -296,32 +284,30 @@ pub async fn change_password_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users + services().users .set_password(sender_user, Some(&body.new_password))?; if body.logout_devices { // Logout all devices except the current one - for id in db + for id in services() .users .all_device_ids(sender_user) .filter_map(|id| id.ok()) .filter(|id| id != sender_device) { - db.users.remove_device(sender_user, &id)?; + services().users.remove_device(sender_user, &id)?; } } - db.flush()?; - info!("User {} changed their password.", sender_user); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {} changed their password.", sender_user @@ -336,7 +322,6 @@ pub async fn change_password_route( /// /// Note: Also works for Application Services pub async fn whoami_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -345,7 +330,7 @@ pub async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: db.users.is_deactivated(&sender_user)?, + is_guest: services().users.is_deactivated(&sender_user)?, }) } @@ -360,7 +345,6 @@ pub async fn whoami_route( /// - Triggers device list updates /// - Removes ability to log in again pub async fn deactivate_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -377,13 +361,11 @@ pub async fn deactivate_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -391,7 +373,7 @@ pub async fn deactivate_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -399,20 +381,18 @@ pub async fn deactivate_route( } // Make the user leave all rooms before deactivation - db.rooms.leave_all_rooms(&sender_user, &db).await?; + services().rooms.leave_all_rooms(&sender_user).await?; // Remove devices and mark account as deactivated - db.users.deactivate_account(sender_user)?; + services().users.deactivate_account(sender_user)?; info!("User {} deactivated their account.", sender_user); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {} deactivated their account.", sender_user ))); - db.flush()?; - Ok(deactivate::v3::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, }) diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 90e9d2c3..7aa5fb2c 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use regex::Regex; use ruma::{ api::{ @@ -16,24 +16,21 @@ use ruma::{ /// /// Creates a new room alias on this server. pub async fn create_alias_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - if db.rooms.id_from_alias(&body.room_alias)?.is_some() { + if services().rooms.id_from_alias(&body.room_alias)?.is_some() { return Err(Error::Conflict("Alias already exists.")); } - db.rooms - .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - - db.flush()?; + services().rooms + .set_alias(&body.room_alias, Some(&body.room_id))?; Ok(create_alias::v3::Response::new()) } @@ -45,22 +42,19 @@ pub async fn create_alias_route( /// - TODO: additional access control checks /// - TODO: Update canonical alias event pub async fn delete_alias_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - db.rooms.set_alias(&body.room_alias, None, &db.globals)?; + services().rooms.set_alias(&body.room_alias, None)?; // TODO: update alt_aliases? - db.flush()?; - Ok(delete_alias::v3::Response::new()) } @@ -70,21 +64,18 @@ pub async fn delete_alias_route( /// /// - TODO: Suggest more servers to join via pub async fn get_alias_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - get_alias_helper(&db, &body.room_alias).await + get_alias_helper(&body.room_alias).await } pub(crate) async fn get_alias_helper( - db: &Database, room_alias: &RoomAliasId, ) -> Result { - if room_alias.server_name() != db.globals.server_name() { - let response = db + if room_alias.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, room_alias.server_name(), federation::query::get_room_information::v1::Request { room_alias }, ) @@ -97,10 +88,10 @@ pub(crate) async fn get_alias_helper( } let mut room_id = None; - match db.rooms.id_from_alias(room_alias)? { + match services().rooms.id_from_alias(room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in db.appservice.all()? { + for (_id, registration) in services().appservice.all()? { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) @@ -115,17 +106,16 @@ pub(crate) async fn get_alias_helper( if aliases .iter() .any(|aliases| aliases.is_match(room_alias.as_str())) - && db + && services() .sending .send_appservice_request( - &db.globals, registration, appservice::query::query_room_alias::v1::Request { room_alias }, ) .await .is_ok() { - room_id = Some(db.rooms.id_from_alias(room_alias)?.ok_or_else(|| { + room_id = Some(services().rooms.id_from_alias(room_alias)?.ok_or_else(|| { Error::bad_config("Appservice lied to us. Room does not exist.") })?); break; @@ -146,6 +136,6 @@ pub(crate) async fn get_alias_helper( Ok(get_alias::v3::Response::new( room_id, - vec![db.globals.server_name().to_owned()], + vec![services().globals.server_name().to_owned()], )) } diff --git a/src/api/client_server/backup.rs b/src/api/client_server/backup.rs index 067f20cd..e4138938 100644 --- a/src/api/client_server/backup.rs +++ b/src/api/client_server/backup.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, @@ -14,15 +14,12 @@ use ruma::api::client::{ /// /// Creates a new backup. pub async fn create_backup_version_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = db + let version = services() .key_backups - .create_backup(sender_user, &body.algorithm, &db.globals)?; - - db.flush()?; + .create_backup(sender_user, &body.algorithm)?; Ok(create_backup_version::v3::Response { version }) } @@ -31,14 +28,11 @@ pub async fn create_backup_version_route( /// /// Update information about an existing backup. Only `auth_data` can be modified. pub async fn update_backup_version_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups - .update_backup(sender_user, &body.version, &body.algorithm, &db.globals)?; - - db.flush()?; + services().key_backups + .update_backup(sender_user, &body.version, &body.algorithm)?; Ok(update_backup_version::v3::Response {}) } @@ -47,13 +41,12 @@ pub async fn update_backup_version_route( /// /// Get information about the latest backup version. pub async fn get_latest_backup_info_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let (version, algorithm) = - db.key_backups + services().key_backups .get_latest_backup(sender_user)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -62,8 +55,8 @@ pub async fn get_latest_backup_info_route( Ok(get_latest_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &version)?, + count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &version)?, version, }) } @@ -72,11 +65,10 @@ pub async fn get_latest_backup_info_route( /// /// Get information about an existing backup. pub async fn get_backup_info_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let algorithm = db + let algorithm = services() .key_backups .get_backup(sender_user, &body.version)? .ok_or(Error::BadRequest( @@ -86,8 +78,8 @@ pub async fn get_backup_info_route( Ok(get_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, version: body.version.to_owned(), }) } @@ -98,14 +90,11 @@ pub async fn get_backup_info_route( /// /// - Deletes both information about the backup, as well as all key data related to the backup pub async fn delete_backup_version_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_backup(sender_user, &body.version)?; - - db.flush()?; + services().key_backups.delete_backup(sender_user, &body.version)?; Ok(delete_backup_version::v3::Response {}) } @@ -118,13 +107,12 @@ pub async fn delete_backup_version_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -137,22 +125,19 @@ pub async fn add_backup_keys_route( for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, room_id, session_id, key_data, - &db.globals, )? } } - db.flush()?; - Ok(add_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -164,13 +149,12 @@ pub async fn add_backup_keys_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -182,21 +166,18 @@ pub async fn add_backup_keys_for_room_route( } for (session_id, key_data) in &body.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, session_id, key_data, - &db.globals, )? } - db.flush()?; - Ok(add_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -208,13 +189,12 @@ pub async fn add_backup_keys_for_room_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -225,20 +205,17 @@ pub async fn add_backup_keys_for_session_route( )); } - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data, - &db.globals, )?; - db.flush()?; - Ok(add_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -246,12 +223,11 @@ pub async fn add_backup_keys_for_session_route( /// /// Retrieves all keys from the backup. pub async fn get_backup_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = db.key_backups.get_all(sender_user, &body.version)?; + let rooms = services().key_backups.get_all(sender_user, &body.version)?; Ok(get_backup_keys::v3::Response { rooms }) } @@ -260,12 +236,11 @@ pub async fn get_backup_keys_route( /// /// Retrieves all keys from the backup for a given room. pub async fn get_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = db + let sessions = services() .key_backups .get_room(sender_user, &body.version, &body.room_id)?; @@ -276,12 +251,11 @@ pub async fn get_backup_keys_for_room_route( /// /// Retrieves a key from the backup. pub async fn get_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = db + let key_data = services() .key_backups .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .ok_or(Error::BadRequest( @@ -296,18 +270,15 @@ pub async fn get_backup_keys_for_session_route( /// /// Delete the keys from the backup. pub async fn delete_backup_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_all_keys(sender_user, &body.version)?; - - db.flush()?; + services().key_backups.delete_all_keys(sender_user, &body.version)?; Ok(delete_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -315,19 +286,16 @@ pub async fn delete_backup_keys_route( /// /// Delete the keys from the backup for a given room. pub async fn delete_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups + services().key_backups .delete_room_keys(sender_user, &body.version, &body.room_id)?; - db.flush()?; - Ok(delete_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -335,18 +303,15 @@ pub async fn delete_backup_keys_for_room_route( /// /// Delete a key from the backup. pub async fn delete_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups + services().key_backups .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; - db.flush()?; - Ok(delete_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } diff --git a/src/api/client_server/capabilities.rs b/src/api/client_server/capabilities.rs index 417ad29d..e4283b72 100644 --- a/src/api/client_server/capabilities.rs +++ b/src/api/client_server/capabilities.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::api::client::discovery::get_capabilities::{ self, Capabilities, RoomVersionStability, RoomVersionsCapability, }; @@ -8,26 +8,25 @@ use std::collections::BTreeMap; /// /// Get information on the supported feature set and other relevent capabilities of this server. pub async fn get_capabilities_route( - db: DatabaseGuard, _body: Ruma, ) -> Result { let mut available = BTreeMap::new(); - if db.globals.allow_unstable_room_versions() { - for room_version in &db.globals.unstable_room_versions { + if services().globals.allow_unstable_room_versions() { + for room_version in &services().globals.unstable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } } else { - for room_version in &db.globals.unstable_room_versions { + for room_version in &services().globals.unstable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Unstable); } } - for room_version in &db.globals.stable_room_versions { + for room_version in &services().globals.stable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } let mut capabilities = Capabilities::new(); capabilities.room_versions = RoomVersionsCapability { - default: db.globals.default_room_version(), + default: services().globals.default_room_version(), available, }; diff --git a/src/api/client_server/config.rs b/src/api/client_server/config.rs index 6184e0bc..36f4fcb7 100644 --- a/src/api/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{ config::{ @@ -17,7 +17,6 @@ use serde_json::{json, value::RawValue as RawJsonValue}; /// /// Sets some account data for the sender user. pub async fn set_global_account_data_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -27,7 +26,7 @@ pub async fn set_global_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( None, sender_user, event_type.clone().into(), @@ -35,11 +34,8 @@ pub async fn set_global_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_global_account_data::v3::Response {}) } @@ -47,7 +43,6 @@ pub async fn set_global_account_data_route( /// /// Sets some room account data for the sender user. pub async fn set_room_account_data_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -57,7 +52,7 @@ pub async fn set_room_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, event_type.clone().into(), @@ -65,11 +60,8 @@ pub async fn set_room_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_room_account_data::v3::Response {}) } @@ -77,12 +69,11 @@ pub async fn set_room_account_data_route( /// /// Gets some account data for the sender user. pub async fn get_global_account_data_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = db + let event: Box = services() .account_data .get(None, sender_user, body.event_type.clone().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; @@ -98,12 +89,11 @@ pub async fn get_global_account_data_route( /// /// Gets some room account data for the sender user. pub async fn get_room_account_data_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = db + let event: Box = services() .account_data .get( Some(&body.room_id), diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index e93f5a5b..3551dcfd 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, @@ -13,7 +13,6 @@ use tracing::error; /// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// joined, depending on history_visibility) pub async fn get_context_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -28,7 +27,7 @@ pub async fn get_context_route( let mut lazy_loaded = HashSet::new(); - let base_pdu_id = db + let base_pdu_id = services() .rooms .get_pdu_id(&body.event_id)? .ok_or(Error::BadRequest( @@ -36,9 +35,9 @@ pub async fn get_context_route( "Base event id not found.", ))?; - let base_token = db.rooms.pdu_count(&base_pdu_id)?; + let base_token = services().rooms.pdu_count(&base_pdu_id)?; - let base_event = db + let base_event = services() .rooms .get_pdu_from_id(&base_pdu_id)? .ok_or(Error::BadRequest( @@ -48,14 +47,14 @@ pub async fn get_context_route( let room_id = base_event.room_id.clone(); - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -67,7 +66,7 @@ pub async fn get_context_route( let base_event = base_event.to_room_event(); - let events_before: Vec<_> = db + let events_before: Vec<_> = services() .rooms .pdus_until(sender_user, &room_id, base_token)? .take( @@ -80,7 +79,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -93,7 +92,7 @@ pub async fn get_context_route( let start_token = events_before .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_before: Vec<_> = events_before @@ -101,7 +100,7 @@ pub async fn get_context_route( .map(|(_, pdu)| pdu.to_room_event()) .collect(); - let events_after: Vec<_> = db + let events_after: Vec<_> = services() .rooms .pdus_after(sender_user, &room_id, base_token)? .take( @@ -114,7 +113,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -125,23 +124,23 @@ pub async fn get_context_route( } } - let shortstatehash = match db.rooms.pdu_shortstatehash( + let shortstatehash = match services().rooms.pdu_shortstatehash( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), )? { Some(s) => s, - None => db + None => services() .rooms .current_shortstatehash(&room_id)? .expect("All rooms have state"), }; - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_full_ids(shortstatehash).await?; let end_token = events_after .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_after: Vec<_> = events_after @@ -152,10 +151,10 @@ pub async fn get_context_route( let mut state = Vec::new(); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -164,7 +163,7 @@ pub async fn get_context_route( }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); diff --git a/src/api/client_server/device.rs b/src/api/client_server/device.rs index b100bf22..2f559939 100644 --- a/src/api/client_server/device.rs +++ b/src/api/client_server/device.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -11,12 +11,11 @@ use super::SESSION_ID_LENGTH; /// /// Get metadata on all devices of the sender user. pub async fn get_devices_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let devices: Vec = db + let devices: Vec = services() .users .all_devices_metadata(sender_user) .filter_map(|r| r.ok()) // Filter out buggy devices @@ -29,12 +28,11 @@ pub async fn get_devices_route( /// /// Get metadata on a single device of the sender user. pub async fn get_device_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device = db + let device = services() .users .get_device_metadata(sender_user, &body.body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; @@ -46,23 +44,20 @@ pub async fn get_device_route( /// /// Updates the metadata on a given device of the sender user. pub async fn update_device_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device = db + let mut device = services() .users .get_device_metadata(sender_user, &body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; device.display_name = body.display_name.clone(); - db.users + services().users .update_device_metadata(sender_user, &body.device_id, &device)?; - db.flush()?; - Ok(update_device::v3::Response {}) } @@ -76,7 +71,6 @@ pub async fn update_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_device_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -94,13 +88,11 @@ pub async fn delete_device_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -108,16 +100,14 @@ pub async fn delete_device_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users.remove_device(sender_user, &body.device_id)?; - - db.flush()?; + services().users.remove_device(sender_user, &body.device_id)?; Ok(delete_device::v3::Response {}) } @@ -134,7 +124,6 @@ pub async fn delete_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_devices_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -152,13 +141,11 @@ pub async fn delete_devices_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -166,7 +153,7 @@ pub async fn delete_devices_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -174,10 +161,8 @@ pub async fn delete_devices_route( } for device_id in &body.devices { - db.users.remove_device(sender_user, device_id)? + services().users.remove_device(sender_user, device_id)? } - db.flush()?; - Ok(delete_devices::v3::Response {}) } diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 4e4a3225..87493fa0 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::{ client::{ @@ -37,11 +37,9 @@ use tracing::{info, warn}; /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, body: Ruma, ) -> Result { get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -57,11 +55,9 @@ pub async fn get_public_rooms_filtered_route( /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let response = get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -84,17 +80,16 @@ pub async fn get_public_rooms_route( /// /// - TODO: Access control checks pub async fn set_room_visibility_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); match &body.visibility { room::Visibility::Public => { - db.rooms.set_public(&body.room_id, true)?; + services().rooms.set_public(&body.room_id, true)?; info!("{} made {} public", sender_user, body.room_id); } - room::Visibility::Private => db.rooms.set_public(&body.room_id, false)?, + room::Visibility::Private => services().rooms.set_public(&body.room_id, false)?, _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -103,8 +98,6 @@ pub async fn set_room_visibility_route( } } - db.flush()?; - Ok(set_room_visibility::v3::Response {}) } @@ -112,11 +105,10 @@ pub async fn set_room_visibility_route( /// /// Gets the visibility of a given room in the room directory. pub async fn get_room_visibility_route( - db: DatabaseGuard, body: Ruma, ) -> Result { Ok(get_room_visibility::v3::Response { - visibility: if db.rooms.is_public_room(&body.room_id)? { + visibility: if services().rooms.is_public_room(&body.room_id)? { room::Visibility::Public } else { room::Visibility::Private @@ -125,19 +117,17 @@ pub async fn get_room_visibility_route( } pub(crate) async fn get_public_rooms_filtered_helper( - db: &Database, server: Option<&ServerName>, limit: Option, since: Option<&str>, filter: &IncomingFilter, _network: &IncomingRoomNetwork, ) -> Result { - if let Some(other_server) = server.filter(|server| *server != db.globals.server_name().as_str()) + if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) { - let response = db + let response = services() .sending .send_federation_request( - &db.globals, other_server, federation::directory::get_public_rooms_filtered::v1::Request { limit, @@ -184,14 +174,14 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = db + let mut all_rooms: Vec<_> = services() .rooms .public_rooms() .map(|room_id| { let room_id = room_id?; let chunk = PublicRoomsChunk { - canonical_alias: db + canonical_alias: services() .rooms .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .map_or(Ok(None), |s| { @@ -201,7 +191,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid canonical alias event in database.") }) })?, - name: db + name: services() .rooms .room_state_get(&room_id, &StateEventType::RoomName, "")? .map_or(Ok(None), |s| { @@ -211,7 +201,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room name event in database.") }) })?, - num_joined_members: db + num_joined_members: services() .rooms .room_joined_count(&room_id)? .unwrap_or_else(|| { @@ -220,7 +210,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( }) .try_into() .expect("user count should not be that big"), - topic: db + topic: services() .rooms .room_state_get(&room_id, &StateEventType::RoomTopic, "")? .map_or(Ok(None), |s| { @@ -230,7 +220,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room topic event in database.") }) })?, - world_readable: db + world_readable: services() .rooms .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .map_or(Ok(false), |s| { @@ -244,7 +234,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( ) }) })?, - guest_can_join: db + guest_can_join: services() .rooms .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .map_or(Ok(false), |s| { @@ -256,7 +246,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room guest access event in database.") }) })?, - avatar_url: db + avatar_url: services() .rooms .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .map(|s| { @@ -269,7 +259,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( .transpose()? // url is now an Option so we must flatten .flatten(), - join_rule: db + join_rule: services() .rooms .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { diff --git a/src/api/client_server/filter.rs b/src/api/client_server/filter.rs index 6522c900..e0c95066 100644 --- a/src/api/client_server/filter.rs +++ b/src/api/client_server/filter.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ error::ErrorKind, filter::{create_filter, get_filter}, @@ -10,11 +10,10 @@ use ruma::api::client::{ /// /// - A user can only access their own filters pub async fn get_filter_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let filter = match db.users.get_filter(sender_user, &body.filter_id)? { + let filter = match services().users.get_filter(sender_user, &body.filter_id)? { Some(filter) => filter, None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), }; @@ -26,11 +25,10 @@ pub async fn get_filter_route( /// /// Creates a new filter to be used by other endpoints. pub async fn create_filter_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(create_filter::v3::Response::new( - db.users.create_filter(sender_user, &body.filter)?, + services().users.create_filter(sender_user, &body.filter)?, )) } diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index c4f91cb2..698bd1ec 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,5 +1,5 @@ use super::SESSION_ID_LENGTH; -use crate::{database::DatabaseGuard, utils, Database, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -26,39 +26,34 @@ use std::collections::{BTreeMap, HashMap, HashSet}; /// - Adds one time keys /// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) pub async fn upload_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); for (key_key, key_value) in &body.one_time_keys { - db.users - .add_one_time_key(sender_user, sender_device, key_key, key_value, &db.globals)?; + services().users + .add_one_time_key(sender_user, sender_device, key_key, key_value)?; } if let Some(device_keys) = &body.device_keys { // TODO: merge this and the existing event? // This check is needed to assure that signatures are kept - if db + if services() .users .get_device_keys(sender_user, sender_device)? .is_none() { - db.users.add_device_keys( + services().users.add_device_keys( sender_user, sender_device, device_keys, - &db.rooms, - &db.globals, )?; } } - db.flush()?; - Ok(upload_keys::v3::Response { - one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, + one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?, }) } @@ -70,7 +65,6 @@ pub async fn upload_keys_route( /// - Gets master keys, self-signing keys, user signing keys and device keys. /// - The master and self-signing keys contain signatures that the user is allowed to see pub async fn get_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -79,7 +73,6 @@ pub async fn get_keys_route( Some(sender_user), &body.device_keys, |u| u == sender_user, - &db, ) .await?; @@ -90,12 +83,9 @@ pub async fn get_keys_route( /// /// Claims one-time keys pub async fn claim_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - let response = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; + let response = claim_keys_helper(&body.one_time_keys).await?; Ok(response) } @@ -106,7 +96,6 @@ pub async fn claim_keys_route( /// /// - Requires UIAA to verify password pub async fn upload_signing_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -124,13 +113,11 @@ pub async fn upload_signing_keys_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -138,7 +125,7 @@ pub async fn upload_signing_keys_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -146,18 +133,14 @@ pub async fn upload_signing_keys_route( } if let Some(master_key) = &body.master_key { - db.users.add_cross_signing_keys( + services().users.add_cross_signing_keys( sender_user, master_key, &body.self_signing_key, &body.user_signing_key, - &db.rooms, - &db.globals, )?; } - db.flush()?; - Ok(upload_signing_keys::v3::Response {}) } @@ -165,7 +148,6 @@ pub async fn upload_signing_keys_route( /// /// Uploads end-to-end key signatures from the sender user. pub async fn upload_signatures_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -205,20 +187,16 @@ pub async fn upload_signatures_route( ))? .to_owned(), ); - db.users.sign_key( + services().users.sign_key( user_id, key_id, signature, sender_user, - &db.rooms, - &db.globals, )?; } } } - db.flush()?; - Ok(upload_signatures::v3::Response { failures: BTreeMap::new(), // TODO: integrate }) @@ -230,7 +208,6 @@ pub async fn upload_signatures_route( /// /// - TODO: left users pub async fn get_key_changes_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -238,7 +215,7 @@ pub async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); device_list_updates.extend( - db.users + services().users .keys_changed( sender_user.as_str(), body.from @@ -253,9 +230,9 @@ pub async fn get_key_changes_route( .filter_map(|r| r.ok()), ); - for room_id in db.rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { + for room_id in services().rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { device_list_updates.extend( - db.users + services().users .keys_changed( &room_id.to_string(), body.from.parse().map_err(|_| { @@ -278,7 +255,6 @@ pub(crate) async fn get_keys_helper bool>( sender_user: Option<&UserId>, device_keys_input: &BTreeMap, Vec>>, allowed_signatures: F, - db: &Database, ) -> Result { let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); @@ -290,7 +266,7 @@ pub(crate) async fn get_keys_helper bool>( for (user_id, device_ids) in device_keys_input { let user_id: &UserId = &**user_id; - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -300,10 +276,10 @@ pub(crate) async fn get_keys_helper bool>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in db.users.all_device_ids(user_id) { + for device_id in services().users.all_device_ids(user_id) { let device_id = device_id?; - if let Some(mut keys) = db.users.get_device_keys(user_id, &device_id)? { - let metadata = db + if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { + let metadata = services() .users .get_device_metadata(user_id, &device_id)? .ok_or_else(|| { @@ -319,8 +295,8 @@ pub(crate) async fn get_keys_helper bool>( } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = db.users.get_device_keys(user_id, device_id)? { - let metadata = db.users.get_device_metadata(user_id, device_id)?.ok_or( + if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { + let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or( Error::BadRequest( ErrorKind::InvalidParam, "Tried to get keys for nonexistent device.", @@ -335,17 +311,17 @@ pub(crate) async fn get_keys_helper bool>( } } - if let Some(master_key) = db.users.get_master_key(user_id, &allowed_signatures)? { + if let Some(master_key) = services().users.get_master_key(user_id, &allowed_signatures)? { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = db + if let Some(self_signing_key) = services() .users .get_self_signing_key(user_id, &allowed_signatures)? { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = db.users.get_user_signing_key(user_id)? { + if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -362,9 +338,8 @@ pub(crate) async fn get_keys_helper bool>( } ( server, - db.sending + services().sending .send_federation_request( - &db.globals, server, federation::keys::get_keys::v1::Request { device_keys: device_keys_input_fed, @@ -417,14 +392,13 @@ fn add_unsigned_device_display_name( pub(crate) async fn claim_keys_helper( one_time_keys_input: &BTreeMap, BTreeMap, DeviceKeyAlgorithm>>, - db: &Database, ) -> Result { let mut one_time_keys = BTreeMap::new(); let mut get_over_federation = BTreeMap::new(); for (user_id, map) in one_time_keys_input { - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -434,8 +408,8 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { if let Some(one_time_keys) = - db.users - .take_one_time_key(user_id, device_id, key_algorithm, &db.globals)? + services().users + .take_one_time_key(user_id, device_id, key_algorithm)? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); @@ -453,10 +427,9 @@ pub(crate) async fn claim_keys_helper( one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); } // Ignore failures - if let Ok(keys) = db + if let Ok(keys) = services() .sending .send_federation_request( - &db.globals, server, federation::keys::claim_keys::v1::Request { one_time_keys: one_time_keys_input_fed, diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index a9a6d6cd..f0da0849 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,6 +1,5 @@ use crate::{ - database::{media::FileMeta, DatabaseGuard}, - utils, Error, Result, Ruma, + utils, Error, Result, Ruma, services, service::media::FileMeta, }; use ruma::api::client::{ error::ErrorKind, @@ -16,11 +15,10 @@ const MXC_LENGTH: usize = 32; /// /// Returns max upload size. pub async fn get_media_config_route( - db: DatabaseGuard, _body: Ruma, ) -> Result { Ok(get_media_config::v3::Response { - upload_size: db.globals.max_request_size().into(), + upload_size: services().globals.max_request_size().into(), }) } @@ -31,19 +29,17 @@ pub async fn get_media_config_route( /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory pub async fn create_content_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let mxc = format!( "mxc://{}/{}", - db.globals.server_name(), + services().globals.server_name(), utils::random_string(MXC_LENGTH) ); - db.media + services().media .create( mxc.clone(), - &db.globals, &body .filename .as_ref() @@ -54,8 +50,6 @@ pub async fn create_content_route( ) .await?; - db.flush()?; - Ok(create_content::v3::Response { content_uri: mxc.try_into().expect("Invalid mxc:// URI"), blurhash: None, @@ -63,15 +57,13 @@ pub async fn create_content_route( } pub async fn get_remote_content( - db: &DatabaseGuard, mxc: &str, server_name: &ruma::ServerName, media_id: &str, ) -> Result { - let content_response = db + let content_response = services() .sending .send_federation_request( - &db.globals, server_name, get_content::v3::Request { allow_remote: false, @@ -81,10 +73,9 @@ pub async fn get_remote_content( ) .await?; - db.media + services().media .create( mxc.to_string(), - &db.globals, &content_response.content_disposition.as_deref(), &content_response.content_type.as_deref(), &content_response.file, @@ -100,7 +91,6 @@ pub async fn get_remote_content( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -109,16 +99,16 @@ pub async fn get_content_route( content_disposition, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(&mxc).await? { Ok(get_content::v3::Response { file, content_type, content_disposition, }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, &body.media_id).await?; Ok(remote_content_response) } else { Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) @@ -131,7 +121,6 @@ pub async fn get_content_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_as_filename_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -140,16 +129,16 @@ pub async fn get_content_as_filename_route( content_disposition: _, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(&mxc).await? { Ok(get_content_as_filename::v3::Response { file, content_type, content_disposition: Some(format!("inline; filename={}", body.filename)), }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, &body.media_id).await?; Ok(get_content_as_filename::v3::Response { content_disposition: Some(format!("inline: filename={}", body.filename)), @@ -167,18 +156,16 @@ pub async fn get_content_as_filename_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_thumbnail_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { content_type, file, .. - }) = db + }) = services() .media .get_thumbnail( &mxc, - &db.globals, body.width .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, @@ -189,11 +176,10 @@ pub async fn get_content_thumbnail_route( .await? { Ok(get_content_thumbnail::v3::Response { file, content_type }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { - let get_thumbnail_response = db + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + let get_thumbnail_response = services() .sending .send_federation_request( - &db.globals, &body.server_name, get_content_thumbnail::v3::Request { allow_remote: false, @@ -206,10 +192,9 @@ pub async fn get_content_thumbnail_route( ) .await?; - db.media + services().media .upload_thumbnail( mxc, - &db.globals, &None, &get_thumbnail_response.content_type, body.width.try_into().expect("all UInts are valid u32s"), diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index ecd26d1a..b000ec1b 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1,9 +1,3 @@ -use crate::{ - client_server, - database::DatabaseGuard, - pdu::{EventHash, PduBuilder, PduEvent}, - server_server, utils, Database, Error, Result, Ruma, -}; use ruma::{ api::{ client::{ @@ -29,13 +23,17 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, iter, sync::{Arc, RwLock}, time::{Duration, Instant}, }; use tracing::{debug, error, warn}; +use crate::{services, PduEvent, service::pdu::{gen_event_id_canonical_json, PduBuilder}, Error, api::{server_server}, utils, Ruma}; + +use super::get_alias_helper; + /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. @@ -43,14 +41,13 @@ use tracing::{debug, error, warn}; /// - If the server knowns about this room: creates the join event and does auth rules locally /// - If the server does not know about the room: asks other servers over federation pub async fn join_room_by_id_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut servers = Vec::new(); // There is no body.server_name for /roomId/join servers.extend( - db.rooms + services().rooms .invite_state(sender_user, &body.room_id)? .unwrap_or_default() .iter() @@ -64,7 +61,6 @@ pub async fn join_room_by_id_route( servers.push(body.room_id.server_name().to_owned()); let ret = join_room_by_id_helper( - &db, body.sender_user.as_deref(), &body.room_id, &servers, @@ -72,8 +68,6 @@ pub async fn join_room_by_id_route( ) .await; - db.flush()?; - ret } @@ -84,7 +78,6 @@ pub async fn join_room_by_id_route( /// - If the server knowns about this room: creates the join event and does auth rules locally /// - If the server does not know about the room: asks other servers over federation pub async fn join_room_by_id_or_alias_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_deref().expect("user is authenticated"); @@ -94,7 +87,7 @@ pub async fn join_room_by_id_or_alias_route( Ok(room_id) => { let mut servers = body.server_name.clone(); servers.extend( - db.rooms + services().rooms .invite_state(sender_user, &room_id)? .unwrap_or_default() .iter() @@ -109,14 +102,13 @@ pub async fn join_room_by_id_or_alias_route( (servers, room_id) } Err(room_alias) => { - let response = client_server::get_alias_helper(&db, &room_alias).await?; + let response = get_alias_helper(&room_alias).await?; (response.servers.into_iter().collect(), response.room_id) } }; let join_room_response = join_room_by_id_helper( - &db, Some(sender_user), &room_id, &servers, @@ -124,8 +116,6 @@ pub async fn join_room_by_id_or_alias_route( ) .await?; - db.flush()?; - Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id, }) @@ -137,14 +127,11 @@ pub async fn join_room_by_id_or_alias_route( /// /// - This should always work if the user is currently joined. pub async fn leave_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.leave_room(sender_user, &body.room_id, &db).await?; - - db.flush()?; + services().rooms.leave_room(sender_user, &body.room_id).await?; Ok(leave_room::v3::Response::new()) } @@ -153,14 +140,12 @@ pub async fn leave_room_route( /// /// Tries to send an invite event into the room. pub async fn invite_user_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if let invite_user::v3::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { - invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; - db.flush()?; + invite_helper(sender_user, user_id, &body.room_id, false).await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -171,13 +156,12 @@ pub async fn invite_user_route( /// /// Tries to send a kick event into the room. pub async fn kick_user_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -196,7 +180,7 @@ pub async fn kick_user_route( // TODO: reason let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -205,7 +189,7 @@ pub async fn kick_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -215,14 +199,11 @@ pub async fn kick_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(kick_user::v3::Response::new()) } @@ -230,14 +211,13 @@ pub async fn kick_user_route( /// /// Tries to send a ban event into the room. pub async fn ban_user_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // TODO: reason - let event = db + let event = services() .rooms .room_state_get( &body.room_id, @@ -247,11 +227,11 @@ pub async fn ban_user_route( .map_or( Ok(RoomMemberEventContent { membership: MembershipState::Ban, - displayname: db.users.displayname(&body.user_id)?, - avatar_url: db.users.avatar_url(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, reason: None, join_authorized_via_users_server: None, }), @@ -266,7 +246,7 @@ pub async fn ban_user_route( )?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -275,7 +255,7 @@ pub async fn ban_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -285,14 +265,11 @@ pub async fn ban_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(ban_user::v3::Response::new()) } @@ -300,13 +277,12 @@ pub async fn ban_user_route( /// /// Tries to send an unban event into the room. pub async fn unban_user_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -324,7 +300,7 @@ pub async fn unban_user_route( event.membership = MembershipState::Leave; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -333,7 +309,7 @@ pub async fn unban_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -343,14 +319,11 @@ pub async fn unban_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(unban_user::v3::Response::new()) } @@ -363,14 +336,11 @@ pub async fn unban_user_route( /// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to /// be called from every device pub async fn forget_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.forget(&body.room_id, sender_user)?; - - db.flush()?; + services().rooms.forget(&body.room_id, sender_user)?; Ok(forget_room::v3::Response::new()) } @@ -379,13 +349,12 @@ pub async fn forget_room_route( /// /// Lists all rooms the user has joined. pub async fn joined_rooms_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(joined_rooms::v3::Response { - joined_rooms: db + joined_rooms: services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -399,13 +368,12 @@ pub async fn joined_rooms_route( /// /// - Only works if the user is currently joined pub async fn get_member_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // TODO: check history visibility? - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -413,7 +381,7 @@ pub async fn get_member_events_route( } Ok(get_member_events::v3::Response { - chunk: db + chunk: services() .rooms .room_state_full(&body.room_id) .await? @@ -431,12 +399,11 @@ pub async fn get_member_events_route( /// - The sender user must be in the room /// - TODO: An appservice just needs a puppet joined pub async fn joined_members_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You aren't a member of the room.", @@ -444,9 +411,9 @@ pub async fn joined_members_route( } let mut joined = BTreeMap::new(); - for user_id in db.rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { - let display_name = db.users.displayname(&user_id)?; - let avatar_url = db.users.avatar_url(&user_id)?; + for user_id in services().rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { + let display_name = services().users.displayname(&user_id)?; + let avatar_url = services().users.avatar_url(&user_id)?; joined.insert( user_id, @@ -460,9 +427,7 @@ pub async fn joined_members_route( Ok(joined_members::v3::Response { joined }) } -#[tracing::instrument(skip(db))] async fn join_room_by_id_helper( - db: &Database, sender_user: Option<&UserId>, room_id: &RoomId, servers: &[Box], @@ -471,7 +436,7 @@ async fn join_room_by_id_helper( let sender_user = sender_user.expect("user is authenticated"); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -481,21 +446,20 @@ async fn join_room_by_id_helper( let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room - if !db.rooms.exists(room_id)? { + if !services().rooms.exists(room_id)? { let mut make_join_response_and_server = Err(Error::BadServerResponse( "No server available to assist in joining.", )); for remote_server in servers { - let make_join_response = db + let make_join_response = services() .sending .send_federation_request( - &db.globals, remote_server, federation::membership::prepare_join_event::v1::Request { room_id, user_id: sender_user, - ver: &db.globals.supported_room_versions(), + ver: &services().globals.supported_room_versions(), }, ) .await; @@ -510,7 +474,7 @@ async fn join_room_by_id_helper( let (make_join_response, remote_server) = make_join_response_and_server?; let room_version = match make_join_response.room_version { - Some(room_version) if db.rooms.is_supported_version(&db, &room_version) => room_version, + Some(room_version) if services().rooms.is_supported_version(&room_version) => room_version, _ => return Err(Error::BadServerResponse("Room version is not supported")), }; @@ -522,7 +486,7 @@ async fn join_room_by_id_helper( // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -536,11 +500,11 @@ async fn join_room_by_id_helper( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -552,8 +516,8 @@ async fn join_room_by_id_helper( // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut join_event_stub, &room_version, ) @@ -577,10 +541,9 @@ async fn join_room_by_id_helper( // It has enough fields to be called a proper event now let join_event = join_event_stub; - let send_join_response = db + let send_join_response = services() .sending .send_federation_request( - &db.globals, remote_server, federation::membership::create_join_event::v2::Request { room_id, @@ -590,7 +553,7 @@ async fn join_room_by_id_helper( ) .await?; - db.rooms.get_or_create_shortroomid(room_id, &db.globals)?; + services().rooms.get_or_create_shortroomid(room_id, &services().globals)?; let parsed_pdu = PduEvent::from_id_val(event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; @@ -602,7 +565,6 @@ async fn join_room_by_id_helper( &send_join_response, &room_version, &pub_key_map, - db, ) .await?; @@ -610,7 +572,7 @@ async fn join_room_by_id_helper( .room_state .state .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) { let (event_id, value) = match result { Ok(t) => t, @@ -622,29 +584,27 @@ async fn join_room_by_id_helper( Error::BadServerResponse("Invalid PDU in send_join response.") })?; - db.rooms.add_pdu_outlier(&event_id, &value)?; + services().rooms.add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = db.rooms.get_or_create_shortstatekey( + let shortstatekey = services().rooms.get_or_create_shortstatekey( &pdu.kind.to_string().into(), state_key, - &db.globals, )?; state.insert(shortstatekey, pdu.event_id.clone()); } } - let incoming_shortstatekey = db.rooms.get_or_create_shortstatekey( + let incoming_shortstatekey = services().rooms.get_or_create_shortstatekey( &parsed_pdu.kind.to_string().into(), parsed_pdu .state_key .as_ref() .expect("Pdu is a membership state event"), - &db.globals, )?; state.insert(incoming_shortstatekey, parsed_pdu.event_id.clone()); - let create_shortstatekey = db + let create_shortstatekey = services() .rooms .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); @@ -653,56 +613,54 @@ async fn join_room_by_id_helper( return Err(Error::BadServerResponse("State contained no create event.")); } - db.rooms.force_state( + services().rooms.force_state( room_id, state .into_iter() - .map(|(k, id)| db.rooms.compress_state_event(k, &id, &db.globals)) + .map(|(k, id)| services().rooms.compress_state_event(k, &id)) .collect::>()?, - db, )?; for result in send_join_response .room_state .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) { let (event_id, value) = match result { Ok(t) => t, Err(_) => continue, }; - db.rooms.add_pdu_outlier(&event_id, &value)?; + services().rooms.add_pdu_outlier(&event_id, &value)?; } // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = db.rooms.append_to_state(&parsed_pdu, &db.globals)?; + let statehashid = services().rooms.append_to_state(&parsed_pdu)?; - db.rooms.append_pdu( + services().rooms.append_pdu( &parsed_pdu, join_event, iter::once(&*parsed_pdu.event_id), - db, )?; // We set the room state after inserting the pdu, so that we never have a moment in time // where events in the current room state do not exist - db.rooms.set_room_state(room_id, statehashid)?; + services().rooms.set_room_state(room_id, statehashid)?; } else { let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -712,15 +670,13 @@ async fn join_room_by_id_helper( }, sender_user, room_id, - db, + services(), &state_lock, )?; } drop(state_lock); - db.flush()?; - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } @@ -728,7 +684,6 @@ fn validate_and_add_event_id( pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock>>, - db: &Database, ) -> Result<(Box, CanonicalJsonObject)> { let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -741,14 +696,14 @@ fn validate_and_add_event_id( )) .expect("ruma's reference hashes are valid event ids"); - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { + let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), }; - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() @@ -791,13 +746,12 @@ pub(crate) async fn invite_helper<'a>( sender_user: &UserId, user_id: &UserId, room_id: &RoomId, - db: &Database, is_direct: bool, ) -> Result<()> { - if user_id.server_name() != db.globals.server_name() { - let (room_version_id, pdu_json, invite_room_state) = { + if user_id.server_name() != services().globals.server_name() { + let (pdu_json, invite_room_state) = { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -818,36 +772,38 @@ pub(crate) async fn invite_helper<'a>( }) .expect("member event is valid value"); - let state_key = user_id.to_string(); - let kind = StateEventType::RoomMember; + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event(PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, sender_user, room_id, &state_lock); - let (pdu, pdu_json) = create_hash_and_sign_event(); - - let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.calculate_invite_state(&pdu)?; drop(state_lock); - (room_version_id, pdu_json, invite_room_state) + (pdu_json, invite_room_state) }; // Generate event id let expected_event_id = format!( "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) + ruma::signatures::reference_hash(&pdu_json, &services().rooms.state.get_room_version(&room_id)?) .expect("ruma can calculate reference hashes") ); let expected_event_id = <&EventId>::try_from(expected_event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); - let response = db + let response = services() .sending .send_federation_request( - &db.globals, user_id.server_name(), create_invite::v2::Request { room_id, event_id: expected_event_id, - room_version: &room_version_id, + room_version: &services().state.get_room_version(&room_id)?, event: &PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state: &invite_room_state, }, @@ -857,7 +813,7 @@ pub(crate) async fn invite_helper<'a>( let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&response.event, &db) + let (event_id, value) = match gen_event_id_canonical_json(&response.event) { Ok(t) => t, Err(_) => { @@ -882,13 +838,12 @@ pub(crate) async fn invite_helper<'a>( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = server_server::handle_incoming_pdu( + let pdu_id = services().rooms.event_handler.handle_incoming_pdu( &origin, &event_id, room_id, value, true, - db, &pub_key_map, ) .await @@ -903,18 +858,18 @@ pub(crate) async fn invite_helper<'a>( "Could not accept incoming PDU as timeline event.", ))?; - let servers = db + let servers = services() .rooms .room_servers(room_id) .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); + .filter(|server| &**server != services().globals.server_name()); - db.sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id)?; return Ok(()); } - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -922,7 +877,7 @@ pub(crate) async fn invite_helper<'a>( } let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -931,16 +886,16 @@ pub(crate) async fn invite_helper<'a>( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: db.users.displayname(user_id)?, - avatar_url: db.users.avatar_url(user_id)?, + displayname: services().users.displayname(user_id)?, + avatar_url: services().users.avatar_url(user_id)?, is_direct: Some(is_direct), third_party_invite: None, - blurhash: db.users.blurhash(user_id)?, + blurhash: services().users.blurhash(user_id)?, reason: None, join_authorized_via_users_server: None, }) @@ -951,7 +906,6 @@ pub(crate) async fn invite_helper<'a>( }, sender_user, room_id, - db, &state_lock, )?; @@ -960,208 +914,196 @@ pub(crate) async fn invite_helper<'a>( Ok(()) } - // Make a user leave all their joined rooms - #[tracing::instrument(skip(self, db))] - pub async fn leave_all_rooms(&self, user_id: &UserId, db: &Database) -> Result<()> { - let all_rooms = db - .rooms - .rooms_joined(user_id) - .chain(db.rooms.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) - .collect::>(); +// Make a user leave all their joined rooms +pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> { + let all_rooms = services() + .rooms + .rooms_joined(user_id) + .chain(services().rooms.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) + .collect::>(); - for room_id in all_rooms { - let room_id = match room_id { - Ok(room_id) => room_id, - Err(_) => continue, - }; - - let _ = self.leave_room(user_id, &room_id, db).await; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, db))] - pub async fn leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - // Ask a remote server if we don't have this room - if !self.exists(room_id)? && room_id.server_name() != db.globals.server_name() { - if let Err(e) = self.remote_leave_room(user_id, room_id, db).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); - // Don't tell the client about this error - } - - let last_state = self - .invite_state(user_id, room_id)? - .map_or_else(|| self.left_state(user_id, room_id), |s| Ok(Some(s)))?; - - // We always drop the invite, we can't rely on other servers - self.update_membership( - room_id, - user_id, - MembershipState::Leave, - user_id, - last_state, - db, - true, - )?; - } else { - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - let mut event: RoomMemberEventContent = serde_json::from_str( - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot leave a room you are not a member of.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; - - event.membership = MembershipState::Leave; - - self.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, - room_id, - db, - &state_lock, - )?; - } - - Ok(()) - } - - #[tracing::instrument(skip(self, db))] - async fn remote_leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - let mut make_leave_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in leaving.", - )); - - let invite_state = db - .rooms - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "User is not invited.", - ))?; - - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); - - for remote_server in servers { - let make_leave_response = db - .sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::prepare_leave_event::v1::Request { room_id, user_id }, - ) - .await; - - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); - - if make_leave_response_and_server.is_ok() { - break; - } - } - - let (make_leave_response, remote_server) = make_leave_response_and_server?; - - let room_version_id = match make_leave_response.room_version { - Some(version) if self.is_supported_version(&db, &version) => version, - _ => return Err(Error::BadServerResponse("Room version is not supported")), + for room_id in all_rooms { + let room_id = match room_id { + Ok(room_id) => room_id, + Err(_) => continue, }; - let mut leave_event_stub = - serde_json::from_str::(make_leave_response.event.get()).map_err( - |_| Error::BadServerResponse("Invalid make_leave event json received from server."), - )?; - - // TODO: Is origin needed? - leave_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), - ); - leave_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - leave_event_stub.remove("event_id"); - - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut leave_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - // Add event_id back - leave_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - - // It has enough fields to be called a proper event now - let leave_event = leave_event_stub; - - db.sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::create_leave_event::v2::Request { - room_id, - event_id: &event_id, - pdu: &PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), - }, - ) - .await?; - - Ok(()) + let _ = leave_room(user_id, &room_id).await; } + Ok(()) +} + +pub async fn leave_room( + user_id: &UserId, + room_id: &RoomId, +) -> Result<()> { + // Ask a remote server if we don't have this room + if !services().rooms.metadata.exists(room_id)? && room_id.server_name() != services().globals.server_name() { + if let Err(e) = remote_leave_room(user_id, room_id).await { + warn!("Failed to leave room {} remotely: {}", user_id, e); + // Don't tell the client about this error + } + + let last_state = services().rooms.state_cache + .invite_state(user_id, room_id)? + .map_or_else(|| services().rooms.left_state(user_id, room_id), |s| Ok(Some(s)))?; + + // We always drop the invite, we can't rely on other servers + services().rooms.state_cache.update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + last_state, + true, + )?; + } else { + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + let mut event: RoomMemberEventContent = serde_json::from_str( + services().rooms.state.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "Cannot leave a room you are not a member of.", + ))? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; + + event.membership = MembershipState::Leave; + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + room_id, + &state_lock, + )?; + } + + Ok(()) +} + +async fn remote_leave_room( + user_id: &UserId, + room_id: &RoomId, +) -> Result<()> { + let mut make_leave_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in leaving.", + )); + + let invite_state = services() + .rooms + .invite_state(user_id, room_id)? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "User is not invited.", + ))?; + + let servers: HashSet<_> = invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(); + + for remote_server in servers { + let make_leave_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::prepare_leave_event::v1::Request { room_id, user_id }, + ) + .await; + + make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); + + if make_leave_response_and_server.is_ok() { + break; + } + } + + let (make_leave_response, remote_server) = make_leave_response_and_server?; + + let room_version_id = match make_leave_response.room_version { + Some(version) if services().rooms.is_supported_version(&version) => version, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + + let mut leave_event_stub = + serde_json::from_str::(make_leave_response.event.get()).map_err( + |_| Error::BadServerResponse("Invalid make_leave event json received from server."), + )?; + + // TODO: Is origin needed? + leave_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + leave_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + leave_event_stub.remove("event_id"); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut leave_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + // Add event_id back + leave_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); + + // It has enough fields to be called a proper event now + let leave_event = leave_event_stub; + + services().sending + .send_federation_request( + &remote_server, + federation::membership::create_leave_event::v2::Request { + room_id, + event_id: &event_id, + pdu: &PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + }, + ) + .await?; + + Ok(()) +} diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 1348132f..861f9c13 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::client::{ error::ErrorKind, @@ -19,14 +19,13 @@ use std::{ /// - The only requirement for the content is that it has to be valid json /// - Tries to send the event into the room, auth rules will determine if it is allowed pub async fn send_message_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -37,7 +36,7 @@ pub async fn send_message_event_route( // Forbid m.room.encrypted if encryption is disabled if RoomEventType::RoomEncrypted == body.event_type.to_string().into() - && !db.globals.allow_encryption() + && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -47,7 +46,7 @@ pub async fn send_message_event_route( // Check if this is a new transaction id if let Some(response) = - db.transaction_ids + services().transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? { // The client might have sent a txnid of the /sendToDevice endpoint @@ -69,7 +68,7 @@ pub async fn send_message_event_route( let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: body.event_type.to_string().into(), content: serde_json::from_str(body.body.body.json().get()) @@ -80,11 +79,10 @@ pub async fn send_message_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; - db.transaction_ids.add_txnid( + services().transaction_ids.add_txnid( sender_user, sender_device, &body.txn_id, @@ -93,8 +91,6 @@ pub async fn send_message_event_route( drop(state_lock); - db.flush()?; - Ok(send_message_event::v3::Response::new( (*event_id).to_owned(), )) @@ -107,13 +103,12 @@ pub async fn send_message_event_route( /// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// joined, depending on history_visibility) pub async fn get_message_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -133,7 +128,7 @@ pub async fn get_message_events_route( let to = body.to.as_ref().map(|t| t.parse()); - db.rooms + services().rooms .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?; // Use limit or else 10 @@ -147,13 +142,13 @@ pub async fn get_message_events_route( match body.dir { get_message_events::v3::Direction::Forward => { - let events_after: Vec<_> = db + let events_after: Vec<_> = services() .rooms .pdus_after(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services().rooms .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -162,7 +157,7 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -184,13 +179,13 @@ pub async fn get_message_events_route( resp.chunk = events_after; } get_message_events::v3::Direction::Backward => { - let events_before: Vec<_> = db + let events_before: Vec<_> = services() .rooms .pdus_until(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services().rooms .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -199,7 +194,7 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -225,7 +220,7 @@ pub async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { if let Some(member_event) = - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? { resp.state.push(member_event.to_state_event()); @@ -233,7 +228,7 @@ pub async fn get_message_events_route( } if let Some(next_token) = next_token { - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( sender_user, sender_device, &body.room_id, diff --git a/src/api/client_server/presence.rs b/src/api/client_server/presence.rs index 773fef47..bc220b80 100644 --- a/src/api/client_server/presence.rs +++ b/src/api/client_server/presence.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Result, Ruma}; +use crate::{utils, Result, Ruma, services}; use ruma::api::client::presence::{get_presence, set_presence}; use std::time::Duration; @@ -6,22 +6,21 @@ use std::time::Duration; /// /// Sets the presence state of the sender user. pub async fn set_presence_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for room_id in db.rooms.rooms_joined(sender_user) { + for room_id in services().rooms.rooms_joined(sender_user) { let room_id = room_id?; - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -32,12 +31,9 @@ pub async fn set_presence_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_presence::v3::Response {}) } @@ -47,20 +43,19 @@ pub async fn set_presence_route( /// /// - Only works if you share a room with the user pub async fn get_presence_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut presence_event = None; - for room_id in db + for room_id in services() .rooms .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { let room_id = room_id?; - if let Some(presence) = db + if let Some(presence) = services() .rooms .edus .get_last_presence_event(sender_user, &room_id)? diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index acea19f0..7a87bcd1 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::{ client::{ @@ -20,16 +20,15 @@ use std::sync::Arc; /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_displayname_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services().users .set_displayname(sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms - let all_rooms_joined: Vec<_> = db + let all_rooms_joined: Vec<_> = services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -40,7 +39,7 @@ pub async fn set_displayname_route( content: to_raw_value(&RoomMemberEventContent { displayname: body.displayname.clone(), ..serde_json::from_str( - db.rooms + services().rooms .room_state_get( &room_id, &StateEventType::RoomMember, @@ -70,7 +69,7 @@ pub async fn set_displayname_route( for (pdu_builder, room_id) in all_rooms_joined { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -79,19 +78,19 @@ pub async fn set_displayname_route( ); let state_lock = mutex_state.lock().await; - let _ = db + let _ = services() .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -102,12 +101,9 @@ pub async fn set_displayname_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_display_name::v3::Response {}) } @@ -117,14 +113,12 @@ pub async fn set_displayname_route( /// /// - If user is on another server: Fetches displayname over federation pub async fn get_displayname_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -139,7 +133,7 @@ pub async fn get_displayname_route( } Ok(get_display_name::v3::Response { - displayname: db.users.displayname(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } @@ -149,18 +143,17 @@ pub async fn get_displayname_route( /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_avatar_url_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services().users .set_avatar_url(sender_user, body.avatar_url.clone())?; - db.users.set_blurhash(sender_user, body.blurhash.clone())?; + services().users.set_blurhash(sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms - let all_joined_rooms: Vec<_> = db + let all_joined_rooms: Vec<_> = services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -171,7 +164,7 @@ pub async fn set_avatar_url_route( content: to_raw_value(&RoomMemberEventContent { avatar_url: body.avatar_url.clone(), ..serde_json::from_str( - db.rooms + services().rooms .room_state_get( &room_id, &StateEventType::RoomMember, @@ -201,7 +194,7 @@ pub async fn set_avatar_url_route( for (pdu_builder, room_id) in all_joined_rooms { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -210,19 +203,19 @@ pub async fn set_avatar_url_route( ); let state_lock = mutex_state.lock().await; - let _ = db + let _ = services() .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -233,12 +226,10 @@ pub async fn set_avatar_url_route( }, sender: sender_user.clone(), }, - &db.globals, + &services().globals, )?; } - db.flush()?; - Ok(set_avatar_url::v3::Response {}) } @@ -248,14 +239,12 @@ pub async fn set_avatar_url_route( /// /// - If user is on another server: Fetches avatar_url and blurhash over federation pub async fn get_avatar_url_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -271,8 +260,8 @@ pub async fn get_avatar_url_route( } Ok(get_avatar_url::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, }) } @@ -282,14 +271,12 @@ pub async fn get_avatar_url_route( /// /// - If user is on another server: Fetches profile over federation pub async fn get_profile_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -305,7 +292,7 @@ pub async fn get_profile_route( }); } - if !db.users.exists(&body.user_id)? { + if !services().users.exists(&body.user_id)? { // Return 404 if this user doesn't exist return Err(Error::BadRequest( ErrorKind::NotFound, @@ -314,8 +301,8 @@ pub async fn get_profile_route( } Ok(get_profile::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, - displayname: db.users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index dc45ea0b..112fa002 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{ error::ErrorKind, @@ -16,12 +16,11 @@ use ruma::{ /// /// Retrieves the push rules event for this user. pub async fn get_pushrules_all_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event: PushRulesEvent = services() .account_data .get( None, @@ -42,12 +41,11 @@ pub async fn get_pushrules_all_route( /// /// Retrieves a single specified push rule for this user. pub async fn get_pushrule_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event: PushRulesEvent = services() .account_data .get( None, @@ -98,7 +96,6 @@ pub async fn get_pushrule_route( /// /// Creates a single specified push rule for this user. pub async fn set_pushrule_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -111,7 +108,7 @@ pub async fn set_pushrule_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -186,16 +183,13 @@ pub async fn set_pushrule_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule::v3::Response {}) } @@ -203,7 +197,6 @@ pub async fn set_pushrule_route( /// /// Gets the actions of a single specified push rule for this user. pub async fn get_pushrule_actions_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -215,7 +208,7 @@ pub async fn get_pushrule_actions_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -252,8 +245,6 @@ pub async fn get_pushrule_actions_route( _ => None, }; - db.flush()?; - Ok(get_pushrule_actions::v3::Response { actions: actions.unwrap_or_default(), }) @@ -263,7 +254,6 @@ pub async fn get_pushrule_actions_route( /// /// Sets the actions of a single specified push rule for this user. pub async fn set_pushrule_actions_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -275,7 +265,7 @@ pub async fn set_pushrule_actions_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -322,16 +312,13 @@ pub async fn set_pushrule_actions_route( _ => {} }; - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule_actions::v3::Response {}) } @@ -339,7 +326,6 @@ pub async fn set_pushrule_actions_route( /// /// Gets the enabled status of a single specified push rule for this user. pub async fn get_pushrule_enabled_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -351,7 +337,7 @@ pub async fn get_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -393,8 +379,6 @@ pub async fn get_pushrule_enabled_route( _ => false, }; - db.flush()?; - Ok(get_pushrule_enabled::v3::Response { enabled }) } @@ -402,7 +386,6 @@ pub async fn get_pushrule_enabled_route( /// /// Sets the enabled status of a single specified push rule for this user. pub async fn set_pushrule_enabled_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -414,7 +397,7 @@ pub async fn set_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -466,16 +449,13 @@ pub async fn set_pushrule_enabled_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule_enabled::v3::Response {}) } @@ -483,7 +463,6 @@ pub async fn set_pushrule_enabled_route( /// /// Deletes a single specified push rule for this user. pub async fn delete_pushrule_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -495,7 +474,7 @@ pub async fn delete_pushrule_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -537,16 +516,13 @@ pub async fn delete_pushrule_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(delete_pushrule::v3::Response {}) } @@ -554,13 +530,12 @@ pub async fn delete_pushrule_route( /// /// Gets all currently active pushers for the sender user. pub async fn get_pushers_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: db.pusher.get_pushers(sender_user)?, + pushers: services().pusher.get_pushers(sender_user)?, }) } @@ -570,15 +545,12 @@ pub async fn get_pushers_route( /// /// - TODO: Handle `append` pub async fn set_pushers_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let pusher = body.pusher.clone(); - db.pusher.set_pusher(sender_user, pusher)?; - - db.flush()?; + services().pusher.set_pusher(sender_user, pusher)?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index 91988a47..284ae65e 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, events::RoomAccountDataEventType, @@ -14,7 +14,6 @@ use std::collections::BTreeMap; /// - Updates fully-read account data event to `fully_read` /// - If `read_receipt` is set: Update private marker and public read receipt EDU pub async fn set_read_marker_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -24,25 +23,23 @@ pub async fn set_read_marker_route( event_id: body.fully_read.clone(), }, }; - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, &fully_read_event, - &db.globals, )?; if let Some(event) = &body.read_receipt { - db.rooms.edus.private_read_set( + services().rooms.edus.private_read_set( &body.room_id, sender_user, - db.rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( + services().rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Event does not exist.", ))?, - &db.globals, )?; - db.rooms + services().rooms .reset_notification_counts(sender_user, &body.room_id)?; let mut user_receipts = BTreeMap::new(); @@ -59,19 +56,16 @@ pub async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { content: ruma::events::receipt::ReceiptEventContent(receipt_content), room_id: body.room_id.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_read_marker::v3::Response {}) } @@ -79,23 +73,21 @@ pub async fn set_read_marker_route( /// /// Sets private read marker and public read receipt EDU. pub async fn create_receipt_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.edus.private_read_set( + services().rooms.edus.private_read_set( &body.room_id, sender_user, - db.rooms + services().rooms .get_pdu_count(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Event does not exist.", ))?, - &db.globals, )?; - db.rooms + services().rooms .reset_notification_counts(sender_user, &body.room_id)?; let mut user_receipts = BTreeMap::new(); @@ -111,17 +103,16 @@ pub async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.to_owned(), receipts); - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { content: ruma::events::receipt::ReceiptEventContent(receipt_content), room_id: body.room_id.clone(), }, - &db.globals, )?; - db.flush()?; + services().flush()?; Ok(create_receipt::v3::Response {}) } diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 059e0f52..d6699bcf 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{database::DatabaseGuard, pdu::PduBuilder, Result, Ruma}; +use crate::{Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, RoomEventType}, @@ -14,14 +14,13 @@ use serde_json::value::to_raw_value; /// /// - TODO: Handle txn id pub async fn redact_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -30,7 +29,7 @@ pub async fn redact_event_route( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomRedaction, content: to_raw_value(&RoomRedactionEventContent { @@ -43,14 +42,11 @@ pub async fn redact_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(redact_event::v3::Response { event_id }) } diff --git a/src/api/client_server/report.rs b/src/api/client_server/report.rs index 14768e1c..2c2a5493 100644 --- a/src/api/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils::HtmlEscape, Error, Result, Ruma}; +use crate::{utils::HtmlEscape, Error, Result, Ruma, services}; use ruma::{ api::client::{error::ErrorKind, room::report_content}, events::room::message, @@ -10,12 +10,11 @@ use ruma::{ /// Reports an inappropriate event to homeserver admins /// pub async fn report_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let pdu = match db.rooms.get_pdu(&body.event_id)? { + let pdu = match services().rooms.get_pdu(&body.event_id)? { Some(pdu) => pdu, _ => { return Err(Error::BadRequest( @@ -39,7 +38,7 @@ pub async fn report_event_route( )); }; - db.admin + services().admin .send_message(message::RoomMessageEventContent::text_html( format!( "Report received from: {}\n\n\ @@ -66,7 +65,5 @@ pub async fn report_event_route( ), )); - db.flush()?; - Ok(report_content::v3::Response {}) } diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 5ae7224c..14affc65 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,5 +1,5 @@ use crate::{ - client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, Error, Result, Ruma, + Error, Result, Ruma, service::pdu::PduBuilder, services, api::client_server::invite_helper, }; use ruma::{ api::client::{ @@ -46,19 +46,18 @@ use tracing::{info, warn}; /// - Send events implied by `name` and `topic` /// - Send invite events pub async fn create_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { use create_room::v3::RoomPreset; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let room_id = RoomId::new(db.globals.server_name()); + let room_id = RoomId::new(services().globals.server_name()); - db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + services().rooms.get_or_create_shortroomid(&room_id)?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -67,9 +66,9 @@ pub async fn create_room_route( ); let state_lock = mutex_state.lock().await; - if !db.globals.allow_room_creation() + if !services().globals.allow_room_creation() && !body.from_appservice - && !db.users.is_admin(sender_user, &db.rooms, &db.globals)? + && !services().users.is_admin(sender_user)? { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -83,12 +82,12 @@ pub async fn create_room_route( .map_or(Ok(None), |localpart| { // TODO: Check for invalid characters and maximum length let alias = - RoomAliasId::parse(format!("#{}:{}", localpart, db.globals.server_name())) + RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") })?; - if db.rooms.id_from_alias(&alias)?.is_some() { + if services().rooms.id_from_alias(&alias)?.is_some() { Err(Error::BadRequest( ErrorKind::RoomInUse, "Room alias already exists.", @@ -100,7 +99,7 @@ pub async fn create_room_route( let room_version = match body.room_version.clone() { Some(room_version) => { - if db.rooms.is_supported_version(&db, &room_version) { + if services().rooms.is_supported_version(&services(), &room_version) { room_version } else { return Err(Error::BadRequest( @@ -109,7 +108,7 @@ pub async fn create_room_route( )); } } - None => db.globals.default_room_version(), + None => services().globals.default_room_version(), }; let content = match &body.creation_content { @@ -163,7 +162,7 @@ pub async fn create_room_route( } // 1. The room create event - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -173,21 +172,20 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 2. Let the room creator join - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -198,7 +196,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; @@ -240,7 +237,7 @@ pub async fn create_room_route( } } - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_content) @@ -251,13 +248,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 4. Canonical room alias if let Some(room_alias_id) = &alias { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCanonicalAlias, content: to_raw_value(&RoomCanonicalAliasEventContent { @@ -271,7 +267,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -279,7 +274,7 @@ pub async fn create_room_route( // 5. Events set by preset // 5.1 Join Rules - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomJoinRules, content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { @@ -294,12 +289,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.2 History Visibility - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomHistoryVisibility, content: to_raw_value(&RoomHistoryVisibilityEventContent::new( @@ -312,12 +306,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.3 Guest Access - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomGuestAccess, content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { @@ -331,7 +324,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; @@ -346,18 +338,18 @@ pub async fn create_room_route( pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); // Silently skip encryption events if they are not allowed - if pdu_builder.event_type == RoomEventType::RoomEncryption && !db.globals.allow_encryption() + if pdu_builder.event_type == RoomEventType::RoomEncryption && !services().globals.allow_encryption() { continue; } - db.rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock)?; + services().rooms + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)?; } // 7. Events implied by name and topic if let Some(name) = &body.name { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomName, content: to_raw_value(&RoomNameEventContent::new(Some(name.clone()))) @@ -368,13 +360,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } if let Some(topic) = &body.topic { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { @@ -387,7 +378,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -395,22 +385,20 @@ pub async fn create_room_route( // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; + let _ = invite_helper(sender_user, user_id, &room_id, body.is_direct).await; } // Homeserver specific stuff if let Some(alias) = alias { - db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; + services().rooms.set_alias(&alias, Some(&room_id))?; } if body.visibility == room::Visibility::Public { - db.rooms.set_public(&room_id, true)?; + services().rooms.set_public(&room_id, true)?; } info!("{} created a room", sender_user); - db.flush()?; - Ok(create_room::v3::Response::new(room_id)) } @@ -420,12 +408,11 @@ pub async fn create_room_route( /// /// - You have to currently be joined to the room (TODO: Respect history visibility) pub async fn get_room_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -433,7 +420,7 @@ pub async fn get_room_event_route( } Ok(get_room_event::v3::Response { - event: db + event: services() .rooms .get_pdu(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))? @@ -447,12 +434,11 @@ pub async fn get_room_event_route( /// /// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable pub async fn get_room_aliases_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -460,7 +446,7 @@ pub async fn get_room_aliases_route( } Ok(aliases::v3::Response { - aliases: db + aliases: services() .rooms .room_aliases(&body.room_id) .filter_map(|a| a.ok()) @@ -479,12 +465,11 @@ pub async fn get_room_aliases_route( /// - Moves local aliases /// - Modifies old room power levels to prevent users from speaking pub async fn upgrade_room_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_supported_version(&db, &body.new_version) { + if !services().rooms.is_supported_version(&body.new_version) { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, "This server does not support that room version.", @@ -492,12 +477,12 @@ pub async fn upgrade_room_route( } // Create a replacement room - let replacement_room = RoomId::new(db.globals.server_name()); - db.rooms - .get_or_create_shortroomid(&replacement_room, &db.globals)?; + let replacement_room = RoomId::new(services().globals.server_name()); + services().rooms + .get_or_create_shortroomid(&replacement_room)?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -508,7 +493,7 @@ pub async fn upgrade_room_route( // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Fail if the sender does not have the required permissions - let tombstone_event_id = db.rooms.build_and_append_pdu( + let tombstone_event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTombstone, content: to_raw_value(&RoomTombstoneEventContent { @@ -522,14 +507,13 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; // Change lock to replacement room drop(state_lock); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -540,7 +524,7 @@ pub async fn upgrade_room_route( // Get the old room creation event let mut create_event_content = serde_json::from_str::( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -588,7 +572,7 @@ pub async fn upgrade_room_route( )); } - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&create_event_content) @@ -599,21 +583,20 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; // Join the new room - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -624,7 +607,6 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; @@ -643,12 +625,12 @@ pub async fn upgrade_room_route( // Replicate transferable state events to the new room for event_type in transferable_state_events { - let event_content = match db.rooms.room_state_get(&body.room_id, &event_type, "")? { + let event_content = match services().rooms.room_state_get(&body.room_id, &event_type, "")? { Some(v) => v.content.clone(), None => continue, // Skipping missing events. }; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: event_content, @@ -658,20 +640,19 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; } // Moves any local aliases to the new room - for alias in db.rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { - db.rooms - .set_alias(&alias, Some(&replacement_room), &db.globals)?; + for alias in services().rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { + services().rooms + .set_alias(&alias, Some(&replacement_room))?; } // Get the old room power levels let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -685,7 +666,7 @@ pub async fn upgrade_room_route( power_levels_event_content.invite = new_level; // Modify the power levels in the old room to prevent sending of events and inviting new users - let _ = db.rooms.build_and_append_pdu( + let _ = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_event_content) @@ -696,35 +677,12 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - // Return the replacement room id Ok(upgrade_room::v3::Response { replacement_room }) } - /// Returns the room's version. - #[tracing::instrument(skip(self))] - pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - let room_version = create_event_content - .map(|create_event| create_event.room_version) - .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; - Ok(room_version) - } - diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index 686e3b5e..b7eecd5a 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ error::ErrorKind, search::search_events::{ @@ -15,7 +15,6 @@ use std::collections::BTreeMap; /// /// - Only works if the user is currently joined to the room (TODO: Respect history visibility) pub async fn search_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -24,7 +23,7 @@ pub async fn search_events_route( let filter = &search_criteria.filter; let room_ids = filter.rooms.clone().unwrap_or_else(|| { - db.rooms + services().rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) .collect() @@ -35,14 +34,14 @@ pub async fn search_events_route( let mut searches = Vec::new(); for room_id in room_ids { - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if let Some(search) = db + if let Some(search) = services() .rooms .search_pdus(&room_id, &search_criteria.search_term)? { @@ -85,7 +84,7 @@ pub async fn search_events_route( start: None, }, rank: None, - result: db + result: services() .rooms .get_pdu_from_id(result)? .map(|pdu| pdu.to_room_event()), diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index c2a79ca6..7feeb66c 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,5 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::{ api::client::{ error::ErrorKind, @@ -41,7 +41,6 @@ pub async fn get_login_types_route( /// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// supported login types. pub async fn login_route( - db: DatabaseGuard, body: Ruma, ) -> Result { // Validate login method @@ -57,11 +56,11 @@ pub async fn login_route( return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); }; let user_id = - UserId::parse_with_server_name(username.to_owned(), db.globals.server_name()) + UserId::parse_with_server_name(username.to_owned(), services().globals.server_name()) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })?; - let hash = db.users.password_hash(&user_id)?.ok_or(Error::BadRequest( + let hash = services().users.password_hash(&user_id)?.ok_or(Error::BadRequest( ErrorKind::Forbidden, "Wrong username or password.", ))?; @@ -85,7 +84,7 @@ pub async fn login_route( user_id } login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => { - if let Some(jwt_decoding_key) = db.globals.jwt_decoding_key() { + if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { let token = jsonwebtoken::decode::( token, jwt_decoding_key, @@ -93,7 +92,7 @@ pub async fn login_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; let username = token.claims.sub; - UserId::parse_with_server_name(username, db.globals.server_name()).map_err( + UserId::parse_with_server_name(username, services().globals.server_name()).map_err( |_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."), )? } else { @@ -122,15 +121,15 @@ pub async fn login_route( // Determine if device_id was provided and exists in the db for this user let device_exists = body.device_id.as_ref().map_or(false, |device_id| { - db.users + services().users .all_device_ids(&user_id) .any(|x| x.as_ref().map_or(false, |v| v == device_id)) }); if device_exists { - db.users.set_token(&user_id, &device_id, &token)?; + services().users.set_token(&user_id, &device_id, &token)?; } else { - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -140,12 +139,10 @@ pub async fn login_route( info!("{} logged in", user_id); - db.flush()?; - Ok(login::v3::Response { user_id, access_token: token, - home_server: Some(db.globals.server_name().to_owned()), + home_server: Some(services().globals.server_name().to_owned()), device_id, well_known: None, }) @@ -160,15 +157,12 @@ pub async fn login_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn logout_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - db.users.remove_device(sender_user, sender_device)?; - - db.flush()?; + services().users.remove_device(sender_user, sender_device)?; Ok(logout::v3::Response::new()) } @@ -185,16 +179,13 @@ pub async fn logout_route( /// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html) /// from each device of this user. pub async fn logout_all_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in db.users.all_device_ids(sender_user).flatten() { - db.users.remove_device(sender_user, &device_id)?; + for device_id in services().users.all_device_ids(sender_user).flatten() { + services().users.remove_device(sender_user, &device_id)?; } - db.flush()?; - Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 4df953cf..4e8d594e 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - database::DatabaseGuard, pdu::PduBuilder, Database, Error, Result, Ruma, RumaResponse, + Error, Result, Ruma, RumaResponse, services, service::pdu::PduBuilder, }; use ruma::{ api::client::{ @@ -27,13 +27,11 @@ use ruma::{ /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_key_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type, @@ -42,8 +40,6 @@ pub async fn send_state_event_for_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }) } @@ -56,13 +52,12 @@ pub async fn send_state_event_for_key_route( /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_empty_key_route( - db: DatabaseGuard, body: Ruma, ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Forbid m.room.encryption if encryption is disabled - if body.event_type == StateEventType::RoomEncryption && !db.globals.allow_encryption() { + if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, "Encryption has been disabled", @@ -70,7 +65,6 @@ pub async fn send_state_event_for_empty_key_route( } let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type.to_string().into(), @@ -79,8 +73,6 @@ pub async fn send_state_event_for_empty_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }.into()) } @@ -91,7 +83,6 @@ pub async fn send_state_event_for_empty_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -99,9 +90,9 @@ pub async fn get_state_events_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -122,7 +113,7 @@ pub async fn get_state_events_route( } Ok(get_state_events::v3::Response { - room_state: db + room_state: services() .rooms .room_state_full(&body.room_id) .await? @@ -138,7 +129,6 @@ pub async fn get_state_events_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_key_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -146,9 +136,9 @@ pub async fn get_state_events_for_key_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -168,7 +158,7 @@ pub async fn get_state_events_for_key_route( )); } - let event = db + let event = services() .rooms .room_state_get(&body.room_id, &body.event_type, &body.state_key)? .ok_or(Error::BadRequest( @@ -188,7 +178,6 @@ pub async fn get_state_events_for_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_empty_key_route( - db: DatabaseGuard, body: Ruma, ) -> Result> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -196,9 +185,9 @@ pub async fn get_state_events_for_empty_key_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -218,7 +207,7 @@ pub async fn get_state_events_for_empty_key_route( )); } - let event = db + let event = services() .rooms .room_state_get(&body.room_id, &body.event_type, "")? .ok_or(Error::BadRequest( @@ -234,7 +223,6 @@ pub async fn get_state_events_for_empty_key_route( } async fn send_state_event_for_key_helper( - db: &Database, sender: &UserId, room_id: &RoomId, event_type: &StateEventType, @@ -255,8 +243,8 @@ async fn send_state_event_for_key_helper( } for alias in aliases { - if alias.server_name() != db.globals.server_name() - || db + if alias.server_name() != services().globals.server_name() + || services() .rooms .id_from_alias(&alias)? .filter(|room| room == room_id) // Make sure it's the right room @@ -272,7 +260,7 @@ async fn send_state_event_for_key_helper( } let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -281,7 +269,7 @@ async fn send_state_event_for_key_helper( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: serde_json::from_str(json.json().get()).expect("content is valid json"), @@ -291,7 +279,6 @@ async fn send_state_event_for_key_helper( }, sender_user, room_id, - db, &state_lock, )?; diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 0c294b7e..cc4ebf6e 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma, RumaResponse}; +use crate::{Error, Result, Ruma, RumaResponse, services}; use ruma::{ api::client::{ filter::{IncomingFilterDefinition, LazyLoadOptions}, @@ -55,16 +55,13 @@ use tracing::error; /// - Sync is handled in an async task, multiple requests from the same device with the same /// `since` will be cached pub async fn sync_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result> { let sender_user = body.sender_user.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let body = body.body; - let arc_db = Arc::new(db); - - let mut rx = match arc_db + let mut rx = match services() .globals .sync_receivers .write() @@ -77,7 +74,6 @@ pub async fn sync_events_route( v.insert((body.since.to_owned(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -93,7 +89,6 @@ pub async fn sync_events_route( o.insert((body.since.clone(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -127,7 +122,6 @@ pub async fn sync_events_route( } async fn sync_helper_wrapper( - db: Arc, sender_user: Box, sender_device: Box, body: sync_events::v3::IncomingRequest, @@ -136,7 +130,6 @@ async fn sync_helper_wrapper( let since = body.since.clone(); let r = sync_helper( - Arc::clone(&db), sender_user.clone(), sender_device.clone(), body, @@ -145,7 +138,7 @@ async fn sync_helper_wrapper( if let Ok((_, caching_allowed)) = r { if !caching_allowed { - match db + match services() .globals .sync_receivers .write() @@ -163,13 +156,10 @@ async fn sync_helper_wrapper( } } - drop(db); - let _ = tx.send(Some(r.map(|(r, _)| r))); } async fn sync_helper( - db: Arc, sender_user: Box, sender_device: Box, body: sync_events::v3::IncomingRequest, @@ -182,19 +172,19 @@ async fn sync_helper( }; // TODO: match body.set_presence { - db.rooms.edus.ping_presence(&sender_user)?; + services().rooms.edus.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = db.watch(&sender_user, &sender_device); + let watcher = services().watch(&sender_user, &sender_device); - let next_batch = db.globals.current_count()?; + let next_batch = services().globals.current_count()?; let next_batch_string = next_batch.to_string(); // Load filter let filter = match body.filter { None => IncomingFilterDefinition::default(), Some(IncomingFilter::FilterDefinition(filter)) => filter, - Some(IncomingFilter::FilterId(filter_id)) => db + Some(IncomingFilter::FilterId(filter_id)) => services() .users .get_filter(&sender_user, &filter_id)? .unwrap_or_default(), @@ -221,12 +211,12 @@ async fn sync_helper( // Look for device list updates of this account device_list_updates.extend( - db.users + services().users .keys_changed(&sender_user.to_string(), since, None) .filter_map(|r| r.ok()), ); - let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::>(); + let all_joined_rooms = services().rooms.rooms_joined(&sender_user).collect::>(); for room_id in all_joined_rooms { let room_id = room_id?; @@ -234,7 +224,7 @@ async fn sync_helper( // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -247,8 +237,8 @@ async fn sync_helper( let timeline_pdus; let limited; - if db.rooms.last_timeline_count(&sender_user, &room_id)? > since { - let mut non_timeline_pdus = db + if services().rooms.last_timeline_count(&sender_user, &room_id)? > since { + let mut non_timeline_pdus = services() .rooms .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { @@ -259,7 +249,7 @@ async fn sync_helper( r.ok() }) .take_while(|(pduid, _)| { - db.rooms + services().rooms .pdu_count(pduid) .map_or(false, |count| count > since) }); @@ -282,7 +272,7 @@ async fn sync_helper( } let send_notification_counts = !timeline_pdus.is_empty() - || db + || services() .rooms .edus .last_privateread_update(&sender_user, &room_id)? @@ -293,24 +283,24 @@ async fn sync_helper( timeline_users.insert(event.sender.as_str().to_owned()); } - db.rooms + services().rooms .lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?; // Database queries: - let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { + let current_shortstatehash = if let Some(s) = services().rooms.current_shortstatehash(&room_id)? { s } else { error!("Room {} has no state", room_id); continue; }; - let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = services().rooms.get_token_shortstatehash(&room_id, since)?; // Calculates joined_member_count, invited_member_count and heroes let calculate_counts = || { - let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); - let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); + let joined_member_count = services().rooms.room_joined_count(&room_id)?.unwrap_or(0); + let invited_member_count = services().rooms.room_invited_count(&room_id)?.unwrap_or(0); // Recalculate heroes (first 5 members) let mut heroes = Vec::new(); @@ -319,7 +309,7 @@ async fn sync_helper( // Go through all PDUs and for each member event, check if the user is still joined or // invited until we have 5 or we reach the end - for hero in db + for hero in services() .rooms .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus @@ -339,8 +329,8 @@ async fn sync_helper( if matches!( content.membership, MembershipState::Join | MembershipState::Invite - ) && (db.rooms.is_joined(&user_id, &room_id)? - || db.rooms.is_invited(&user_id, &room_id)?) + ) && (services().rooms.is_joined(&user_id, &room_id)? + || services().rooms.is_invited(&user_id, &room_id)?) { Ok::<_, Error>(Some(state_key.clone())) } else { @@ -381,17 +371,17 @@ async fn sync_helper( let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; + let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?; let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); let mut i = 0; for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -408,7 +398,7 @@ async fn sync_helper( || body.full_state || timeline_users.contains(&state_key) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -430,12 +420,12 @@ async fn sync_helper( } // Reset lazy loading because this is an initial sync - db.rooms + services().rooms .lazy_load_reset(&sender_user, &sender_device, &room_id)?; // The state_events above should contain all timeline_users, let's mark them as lazy // loaded. - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -457,7 +447,7 @@ async fn sync_helper( // Incremental /sync let since_shortstatehash = since_shortstatehash.unwrap(); - let since_sender_member: Option = db + let since_sender_member: Option = services() .rooms .state_get( since_shortstatehash, @@ -477,12 +467,12 @@ async fn sync_helper( let mut lazy_loaded = HashSet::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; - let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; + let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?; + let since_state_ids = services().rooms.state_full_ids(since_shortstatehash).await?; for (key, id) in current_state_ids { if body.full_state || since_state_ids.get(&key) != Some(&id) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -515,14 +505,14 @@ async fn sync_helper( continue; } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( &sender_user, &sender_device, &room_id, &event.sender, )? || lazy_load_send_redundant { - if let Some(member_event) = db.rooms.room_state_get( + if let Some(member_event) = services().rooms.room_state_get( &room_id, &StateEventType::RoomMember, event.sender.as_str(), @@ -533,7 +523,7 @@ async fn sync_helper( } } - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -541,13 +531,13 @@ async fn sync_helper( next_batch, ); - let encrypted_room = db + let encrypted_room = services() .rooms .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .is_some(); let since_encryption = - db.rooms + services().rooms .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; // Calculations: @@ -580,7 +570,7 @@ async fn sync_helper( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { + if !share_encrypted_room(&sender_user, &user_id, &room_id)? { device_list_updates.insert(user_id); } } @@ -597,7 +587,7 @@ async fn sync_helper( if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( - db.rooms + services().rooms .room_members(&room_id) .flatten() .filter(|user_id| { @@ -606,7 +596,7 @@ async fn sync_helper( }) .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&db, &sender_user, user_id, &room_id) + !share_encrypted_room(&sender_user, user_id, &room_id) .unwrap_or(false) }), ); @@ -629,14 +619,14 @@ async fn sync_helper( // Look for device list updates in this room device_list_updates.extend( - db.users + services().users .keys_changed(&room_id.to_string(), since, None) .filter_map(|r| r.ok()), ); let notification_count = if send_notification_counts { Some( - db.rooms + services().rooms .notification_count(&sender_user, &room_id)? .try_into() .expect("notification count can't go that high"), @@ -647,7 +637,7 @@ async fn sync_helper( let highlight_count = if send_notification_counts { Some( - db.rooms + services().rooms .highlight_count(&sender_user, &room_id)? .try_into() .expect("highlight count can't go that high"), @@ -659,7 +649,7 @@ async fn sync_helper( let prev_batch = timeline_pdus .first() .map_or(Ok::<_, Error>(None), |(pdu_id, _)| { - Ok(Some(db.rooms.pdu_count(pdu_id)?.to_string())) + Ok(Some(services().rooms.pdu_count(pdu_id)?.to_string())) })?; let room_events: Vec<_> = timeline_pdus @@ -667,7 +657,7 @@ async fn sync_helper( .map(|(_, pdu)| pdu.to_sync_room_event()) .collect(); - let mut edus: Vec<_> = db + let mut edus: Vec<_> = services() .rooms .edus .readreceipts_since(&room_id, since) @@ -675,10 +665,10 @@ async fn sync_helper( .map(|(_, _, v)| v) .collect(); - if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since { + if services().rooms.edus.last_typing_update(&room_id, &services().globals)? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&db.rooms.edus.typings_all(&room_id)?) + &serde_json::to_string(&services().rooms.edus.typings_all(&room_id)?) .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), @@ -686,12 +676,12 @@ async fn sync_helper( } // Save the state after this sync so we can send the correct state diff next sync - db.rooms + services().rooms .associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?; let joined_room = JoinedRoom { account_data: RoomAccountData { - events: db + events: services() .account_data .changes_since(Some(&room_id), &sender_user, since)? .into_iter() @@ -731,9 +721,9 @@ async fn sync_helper( // Take presence updates from this room for (user_id, presence) in - db.rooms + services().rooms .edus - .presence_since(&room_id, since, &db.rooms, &db.globals)? + .presence_since(&room_id, since)? { match presence_updates.entry(user_id) { Entry::Vacant(v) => { @@ -765,14 +755,14 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = db.rooms.rooms_left(&sender_user).collect(); + let all_left_rooms: Vec<_> = services().rooms.rooms_left(&sender_user).collect(); for result in all_left_rooms { let (room_id, left_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -783,7 +773,7 @@ async fn sync_helper( drop(insert_lock); } - let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; + let left_count = services().rooms.get_left_count(&room_id, &sender_user)?; // Left before last sync if Some(since) >= left_count { @@ -807,14 +797,14 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = db.rooms.rooms_invited(&sender_user).collect(); + let all_invited_rooms: Vec<_> = services().rooms.rooms_invited(&sender_user).collect(); for result in all_invited_rooms { let (room_id, invite_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -825,7 +815,7 @@ async fn sync_helper( drop(insert_lock); } - let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; + let invite_count = services().rooms.get_invite_count(&room_id, &sender_user)?; // Invited before last sync if Some(since) >= invite_count { @@ -843,13 +833,13 @@ async fn sync_helper( } for user_id in left_encrypted_users { - let still_share_encrypted_room = db + let still_share_encrypted_room = services() .rooms .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(|r| r.ok()) .filter_map(|other_room_id| { Some( - db.rooms + services().rooms .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), @@ -864,7 +854,7 @@ async fn sync_helper( } // Remove all to-device events the device received *last time* - db.users + services().users .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::v3::Response { @@ -882,7 +872,7 @@ async fn sync_helper( .collect(), }, account_data: GlobalAccountData { - events: db + events: services() .account_data .changes_since(None, &sender_user, since)? .into_iter() @@ -897,9 +887,9 @@ async fn sync_helper( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: db.users.count_one_time_keys(&sender_user, &sender_device)?, + device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, to_device: ToDevice { - events: db + events: services() .users .get_to_device_events(&sender_user, &sender_device)?, }, @@ -928,21 +918,19 @@ async fn sync_helper( } } -#[tracing::instrument(skip(db))] fn share_encrypted_room( - db: &Database, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, ) -> Result { - Ok(db + Ok(services() .rooms .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? .filter_map(|r| r.ok()) .filter(|room_id| room_id != ignore_room) .filter_map(|other_room_id| { Some( - db.rooms + services().rooms .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 98d895cd..bbea2d58 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ @@ -14,12 +14,11 @@ use std::collections::BTreeMap; /// /// - Inserts the tag into the tag event of the room account data. pub async fn update_tag_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = db + let mut tags_event = services() .account_data .get( Some(&body.room_id), @@ -36,16 +35,13 @@ pub async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, &tags_event, - &db.globals, )?; - db.flush()?; - Ok(create_tag::v3::Response {}) } @@ -55,12 +51,11 @@ pub async fn update_tag_route( /// /// - Removes the tag from the tag event of the room account data. pub async fn delete_tag_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = db + let mut tags_event = services() .account_data .get( Some(&body.room_id), @@ -74,16 +69,13 @@ pub async fn delete_tag_route( }); tags_event.content.tags.remove(&body.tag.clone().into()); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, &tags_event, - &db.globals, )?; - db.flush()?; - Ok(delete_tag::v3::Response {}) } @@ -93,13 +85,12 @@ pub async fn delete_tag_route( /// /// - Gets the tag event of the room account data. pub async fn get_tags_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_tags::v3::Response { - tags: db + tags: services() .account_data .get( Some(&body.room_id), diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index 51441dd4..3a2f6c09 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -1,7 +1,7 @@ use ruma::events::ToDeviceEventType; use std::collections::BTreeMap; -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -14,14 +14,13 @@ use ruma::{ /// /// Send a to-device event to a set of client devices. pub async fn send_event_to_device_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); // Check if this is a new transaction id - if db + if services() .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? .is_some() @@ -31,13 +30,13 @@ pub async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != db.globals.server_name() { + if target_user_id.server_name() != services().globals.server_name() { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); - db.sending.send_reliable_edu( + services().sending.send_reliable_edu( target_user_id.server_name(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( DirectDeviceContent { @@ -48,14 +47,14 @@ pub async fn send_event_to_device_route( }, )) .expect("DirectToDevice EDU can be serialized"), - db.globals.next_count()?, + services().globals.next_count()?, )?; continue; } match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event( + DeviceIdOrAllDevices::DeviceId(target_device_id) => services().users.add_to_device_event( sender_user, target_user_id, &target_device_id, @@ -63,12 +62,11 @@ pub async fn send_event_to_device_route( event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - &db.globals, )?, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( sender_user, target_user_id, &target_device_id?, @@ -76,7 +74,6 @@ pub async fn send_event_to_device_route( event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - &db.globals, )?; } } @@ -85,10 +82,8 @@ pub async fn send_event_to_device_route( } // Save transaction id with empty data - db.transaction_ids + services().transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - db.flush()?; - Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client_server/typing.rs b/src/api/client_server/typing.rs index cac5a5fd..afd5d6b3 100644 --- a/src/api/client_server/typing.rs +++ b/src/api/client_server/typing.rs @@ -1,18 +1,17 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. pub async fn create_typing_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { use create_typing_event::v3::Typing; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You are not in this room.", @@ -20,16 +19,15 @@ pub async fn create_typing_event_route( } if let Typing::Yes(duration) = body.state { - db.rooms.edus.typing_add( + services().rooms.edus.typing_add( sender_user, &body.room_id, duration.as_millis() as u64 + utils::millis_since_unix_epoch(), - &db.globals, )?; } else { - db.rooms + services().rooms .edus - .typing_remove(sender_user, &body.room_id, &db.globals)?; + .typing_remove(sender_user, &body.room_id)?; } Ok(create_typing_event::v3::Response {}) diff --git a/src/api/client_server/user_directory.rs b/src/api/client_server/user_directory.rs index 349c1399..60b4e2fa 100644 --- a/src/api/client_server/user_directory.rs +++ b/src/api/client_server/user_directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -14,20 +14,19 @@ use ruma::{ /// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// and don't share a room with the sender pub async fn search_users_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = u64::from(body.limit) as usize; - let mut users = db.users.iter().filter_map(|user_id| { + let mut users = services().users.iter().filter_map(|user_id| { // Filter out buggy users (they should not exist, but you never know...) let user_id = user_id.ok()?; let user = search_users::v3::User { user_id: user_id.clone(), - display_name: db.users.displayname(&user_id).ok()?, - avatar_url: db.users.avatar_url(&user_id).ok()?, + display_name: services().users.displayname(&user_id).ok()?, + avatar_url: services().users.avatar_url(&user_id).ok()?, }; let user_id_matches = user @@ -50,11 +49,11 @@ pub async fn search_users_route( } let user_is_in_public_rooms = - db.rooms + services().rooms .rooms_joined(&user_id) .filter_map(|r| r.ok()) .any(|room| { - db.rooms + services().rooms .room_state_get(&room, &StateEventType::RoomJoinRules, "") .map_or(false, |event| { event.map_or(false, |event| { @@ -70,7 +69,7 @@ pub async fn search_users_route( return Some(user); } - let user_is_in_shared_rooms = db + let user_is_in_shared_rooms = services() .rooms .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) .ok()? diff --git a/src/api/client_server/voip.rs b/src/api/client_server/voip.rs index 7e9de31e..2a804f97 100644 --- a/src/api/client_server/voip.rs +++ b/src/api/client_server/voip.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use hmac::{Hmac, Mac, NewMac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use sha1::Sha1; @@ -10,16 +10,15 @@ type HmacSha1 = Hmac; /// /// TODO: Returns information about the recommended turn server. pub async fn turn_server_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let turn_secret = db.globals.turn_secret(); + let turn_secret = services().globals.turn_secret(); let (username, password) = if !turn_secret.is_empty() { let expiry = SecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(db.globals.turn_ttl()), + SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), ) .expect("time is valid"); @@ -34,15 +33,15 @@ pub async fn turn_server_route( (username, password) } else { ( - db.globals.turn_username().clone(), - db.globals.turn_password().clone(), + services().globals.turn_username().clone(), + services().globals.turn_password().clone(), ) }; Ok(get_turn_server_info::v3::Response { username, password, - uris: db.globals.turn_uris().to_vec(), - ttl: Duration::from_secs(db.globals.turn_ttl()), + uris: services().globals.turn_uris().to_vec(), + ttl: Duration::from_secs(services().globals.turn_ttl()), }) } diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 00000000..68589be7 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,4 @@ +pub mod client_server; +pub mod server_server; +pub mod appservice_server; +pub mod ruma_wrapper; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 45e9d9a8..babf2a74 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -24,7 +24,7 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{database::DatabaseGuard, server_server, Error, Result}; +use crate::{Error, Result, api::server_server, services}; #[async_trait] impl FromRequest for Ruma @@ -44,7 +44,6 @@ where } let metadata = T::METADATA; - let db = DatabaseGuard::from_request(req).await?; let auth_header = Option::>>::from_request(req).await?; let path_params = Path::>::from_request(req).await?; @@ -71,7 +70,7 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = db.appservice.all().unwrap(); + let appservices = services().appservice.all().unwrap(); let appservice_registration = appservices.iter().find(|(_id, registration)| { registration .get("as_token") @@ -91,14 +90,14 @@ where .unwrap() .as_str() .unwrap(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap() }, |s| UserId::parse(s).unwrap(), ); - if !db.users.exists(&user_id).unwrap() { + if !services().users.exists(&user_id).unwrap() { return Err(Error::BadRequest( ErrorKind::Forbidden, "User does not exist.", @@ -124,7 +123,7 @@ where } }; - match db.users.find_from_token(token).unwrap() { + match services().users.find_from_token(token).unwrap() { None => { return Err(Error::BadRequest( ErrorKind::UnknownToken { soft_logout: false }, @@ -185,7 +184,7 @@ where ( "destination".to_owned(), CanonicalJsonValue::String( - db.globals.server_name().as_str().to_owned(), + services().globals.server_name().as_str().to_owned(), ), ), ( @@ -199,7 +198,6 @@ where }; let keys_result = server_server::fetch_signing_keys( - &db, &x_matrix.origin, vec![x_matrix.key.to_owned()], ) @@ -251,7 +249,7 @@ where if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", db.globals.server_name()) + UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid") }); @@ -261,7 +259,7 @@ where .and_then(|auth| auth.get("session")) .and_then(|session| session.as_str()) .and_then(|session| { - db.uiaa.get_uiaa_request( + services().uiaa.get_uiaa_request( &user_id, &sender_device.clone().unwrap_or_else(|| "".into()), session, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index f60f735a..776777d1 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1,8 +1,6 @@ use crate::{ - client_server::{self, claim_keys_helper, get_keys_helper}, - database::{rooms::CompressedStateEvent, DatabaseGuard}, - pdu::EventHash, - utils, Database, Error, PduEvent, Result, Ruma, + api::client_server::{self, claim_keys_helper, get_keys_helper}, + utils, Error, PduEvent, Result, Ruma, services, service::pdu::{gen_event_id_canonical_json, PduBuilder}, }; use axum::{response::IntoResponse, Json}; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -126,22 +124,21 @@ impl FedDest { } } -#[tracing::instrument(skip(globals, request))] +#[tracing::instrument(skip(request))] pub(crate) async fn send_request( - globals: &crate::database::globals::Globals, destination: &ServerName, request: T, ) -> Result where T: Debug, { - if !globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let mut write_destination_to_cache = false; - let cached_result = globals + let cached_result = services().globals .actual_destination_cache .read() .unwrap() @@ -153,7 +150,7 @@ where } else { write_destination_to_cache = true; - let result = find_actual_destination(globals, destination).await; + let result = find_actual_destination(destination).await; (result.0, result.1.into_uri_string()) }; @@ -194,15 +191,15 @@ where .to_string() .into(), ); - request_map.insert("origin".to_owned(), globals.server_name().as_str().into()); + request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); request_map.insert("destination".to_owned(), destination.as_str().into()); let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); ruma::signatures::sign_json( - globals.server_name().as_str(), - globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut request_json, ) .expect("our request json is what ruma expects"); @@ -227,7 +224,7 @@ where AUTHORIZATION, HeaderValue::from_str(&format!( "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - globals.server_name(), + services().globals.server_name(), s.0, s.1 )) @@ -241,7 +238,7 @@ where let url = reqwest_request.url().clone(); - let response = globals.federation_client().execute(reqwest_request).await; + let response = services().globals.federation_client().execute(reqwest_request).await; match response { Ok(mut response) => { @@ -281,7 +278,7 @@ where if status == 200 { let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && write_destination_to_cache { - globals.actual_destination_cache.write().unwrap().insert( + services().globals.actual_destination_cache.write().unwrap().insert( Box::::from(destination), (actual_destination, host), ); @@ -332,9 +329,7 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest { /// Returns: actual_destination, host header /// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names /// Numbers in comments below refer to bullet points in linked section of specification -#[tracing::instrument(skip(globals))] async fn find_actual_destination( - globals: &crate::database::globals::Globals, destination: &'_ ServerName, ) -> (FedDest, FedDest) { let destination_str = destination.as_str().to_owned(); @@ -350,7 +345,7 @@ async fn find_actual_destination( let (host, port) = destination_str.split_at(pos); FedDest::Named(host.to_owned(), port.to_owned()) } else { - match request_well_known(globals, destination.as_str()).await { + match request_well_known(destination.as_str()).await { // 3: A .well-known file is available Some(delegated_hostname) => { hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); @@ -364,17 +359,17 @@ async fn find_actual_destination( } else { // Delegated hostname has no port in this branch if let Some(hostname_override) = - query_srv_record(globals, &delegated_hostname).await + query_srv_record(&delegated_hostname).await { // 3.3: SRV lookup successful let force_port = hostname_override.port(); - if let Ok(override_ip) = globals + if let Ok(override_ip) = services().globals .dns_resolver() .lookup_ip(hostname_override.hostname()) .await { - globals.tls_name_override.write().unwrap().insert( + services().globals.tls_name_override.write().unwrap().insert( delegated_hostname.clone(), ( override_ip.iter().collect(), @@ -400,17 +395,17 @@ async fn find_actual_destination( } // 4: No .well-known or an error occured None => { - match query_srv_record(globals, &destination_str).await { + match query_srv_record(&destination_str).await { // 4: SRV record found Some(hostname_override) => { let force_port = hostname_override.port(); - if let Ok(override_ip) = globals + if let Ok(override_ip) = services().globals .dns_resolver() .lookup_ip(hostname_override.hostname()) .await { - globals.tls_name_override.write().unwrap().insert( + services().globals.tls_name_override.write().unwrap().insert( hostname.clone(), (override_ip.iter().collect(), force_port.unwrap_or(8448)), ); @@ -448,12 +443,10 @@ async fn find_actual_destination( (actual_destination, hostname) } -#[tracing::instrument(skip(globals))] async fn query_srv_record( - globals: &crate::database::globals::Globals, hostname: &'_ str, ) -> Option { - if let Ok(Some(host_port)) = globals + if let Ok(Some(host_port)) = services().globals .dns_resolver() .srv_lookup(format!("_matrix._tcp.{}", hostname)) .await @@ -472,13 +465,11 @@ async fn query_srv_record( } } -#[tracing::instrument(skip(globals))] async fn request_well_known( - globals: &crate::database::globals::Globals, destination: &str, ) -> Option { let body: serde_json::Value = serde_json::from_str( - &globals + &services().globals .default_client() .get(&format!( "https://{}/.well-known/matrix/server", @@ -499,10 +490,9 @@ async fn request_well_known( /// /// Get version information on this server. pub async fn get_server_version_route( - db: DatabaseGuard, _body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -521,24 +511,24 @@ pub async fn get_server_version_route( /// - Matrix does not support invalidating public keys, so the key returned by this will be valid /// forever. // Response type for this endpoint is Json because we need to calculate a signature for the response -pub async fn get_server_keys_route(db: DatabaseGuard) -> Result { - if !db.globals.allow_federation() { +pub async fn get_server_keys_route() -> Result { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let mut verify_keys: BTreeMap, VerifyKey> = BTreeMap::new(); verify_keys.insert( - format!("ed25519:{}", db.globals.keypair().version()) + format!("ed25519:{}", services().globals.keypair().version()) .try_into() .expect("found invalid server signing keys in DB"), VerifyKey { - key: Base64::new(db.globals.keypair().public_key().to_vec()), + key: Base64::new(services().globals.keypair().public_key().to_vec()), }, ); let mut response = serde_json::from_slice( get_server_keys::v2::Response { server_key: Raw::new(&ServerSigningKeys { - server_name: db.globals.server_name().to_owned(), + server_name: services().globals.server_name().to_owned(), verify_keys, old_verify_keys: BTreeMap::new(), signatures: BTreeMap::new(), @@ -556,8 +546,8 @@ pub async fn get_server_keys_route(db: DatabaseGuard) -> Result Result impl IntoResponse { - get_server_keys_route(db).await +pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { + get_server_keys_route().await } /// # `POST /_matrix/federation/v1/publicRooms` /// /// Lists the public rooms on this server. pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let response = client_server::get_public_rooms_filtered_helper( - &db, None, body.limit, body.since.as_deref(), @@ -608,15 +596,13 @@ pub async fn get_public_rooms_filtered_route( /// /// Lists the public rooms on this server. pub async fn get_public_rooms_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let response = client_server::get_public_rooms_filtered_helper( - &db, None, body.limit, body.since.as_deref(), @@ -637,10 +623,9 @@ pub async fn get_public_rooms_route( /// /// Push EDUs and PDUs to this server. pub async fn send_transaction_message_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -663,7 +648,7 @@ pub async fn send_transaction_message_route( for pdu in &body.pdus { // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { + let (event_id, value) = match gen_event_id_canonical_json(pdu) { Ok(t) => t, Err(_) => { // Event could not be converted to canonical json @@ -684,10 +669,10 @@ pub async fn send_transaction_message_route( } }; - acl_check(&sender_servername, &room_id, &db)?; + acl_check(&sender_servername, &room_id)?; let mutex = Arc::clone( - db.globals + services().globals .roomid_mutex_federation .write() .unwrap() @@ -698,13 +683,12 @@ pub async fn send_transaction_message_route( let start_time = Instant::now(); resolved_map.insert( event_id.clone(), - handle_incoming_pdu( + services().rooms.event_handler.handle_incoming_pdu( &sender_servername, &event_id, &room_id, value, true, - &db, &pub_key_map, ) .await @@ -743,7 +727,7 @@ pub async fn send_transaction_message_route( .event_ids .iter() .filter_map(|id| { - db.rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) + services().rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) }) .max_by_key(|(_, count)| *count) { @@ -760,11 +744,10 @@ pub async fn send_transaction_message_route( content: ReceiptEventContent(receipt_content), room_id: room_id.clone(), }; - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( &user_id, &room_id, event, - &db.globals, )?; } else { // TODO fetch missing events @@ -774,26 +757,24 @@ pub async fn send_transaction_message_route( } } Edu::Typing(typing) => { - if db.rooms.is_joined(&typing.user_id, &typing.room_id)? { + if services().rooms.is_joined(&typing.user_id, &typing.room_id)? { if typing.typing { - db.rooms.edus.typing_add( + services().rooms.edus.typing_add( &typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch(), - &db.globals, )?; } else { - db.rooms.edus.typing_remove( + services().rooms.edus.typing_remove( &typing.user_id, &typing.room_id, - &db.globals, )?; } } } Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { - db.users - .mark_device_key_update(&user_id, &db.rooms, &db.globals)?; + services().users + .mark_device_key_update(&user_id)?; } Edu::DirectToDevice(DirectDeviceContent { sender, @@ -802,7 +783,7 @@ pub async fn send_transaction_message_route( messages, }) => { // Check if this is a new transaction id - if db + if services() .transaction_ids .existing_txnid(&sender, None, &message_id)? .is_some() @@ -814,7 +795,7 @@ pub async fn send_transaction_message_route( for (target_device_id_maybe, event) in map { match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - db.users.add_to_device_event( + services().users.add_to_device_event( &sender, target_user_id, target_device_id, @@ -825,13 +806,12 @@ pub async fn send_transaction_message_route( "Event is invalid", ) })?, - &db.globals, )? } DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( &sender, target_user_id, &target_device_id?, @@ -842,7 +822,6 @@ pub async fn send_transaction_message_route( "Event is invalid", ) })?, - &db.globals, )?; } } @@ -851,7 +830,7 @@ pub async fn send_transaction_message_route( } // Save transaction id with empty data - db.transaction_ids + services().transaction_ids .add_txnid(&sender, None, &message_id, &[])?; } Edu::SigningKeyUpdate(SigningKeyUpdateContent { @@ -863,13 +842,11 @@ pub async fn send_transaction_message_route( continue; } if let Some(master_key) = master_key { - db.users.add_cross_signing_keys( + services().users.add_cross_signing_keys( &user_id, &master_key, &self_signing_key, &None, - &db.rooms, - &db.globals, )?; } } @@ -877,8 +854,6 @@ pub async fn send_transaction_message_route( } } - db.flush()?; - Ok(send_transaction_message::v1::Response { pdus: resolved_map }) } @@ -886,14 +861,13 @@ pub async fn send_transaction_message_route( /// fetch them from the server and save to our DB. #[tracing::instrument(skip_all)] pub(crate) async fn fetch_signing_keys( - db: &Database, origin: &ServerName, signature_ids: Vec, ) -> Result> { let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - let permit = db + let permit = services() .globals .servername_ratelimiter .read() @@ -904,7 +878,7 @@ pub(crate) async fn fetch_signing_keys( let permit = match permit { Some(p) => p, None => { - let mut write = db.globals.servername_ratelimiter.write().unwrap(); + let mut write = services().globals.servername_ratelimiter.write().unwrap(); let s = Arc::clone( write .entry(origin.to_owned()) @@ -916,7 +890,7 @@ pub(crate) async fn fetch_signing_keys( } .await; - let back_off = |id| match db + let back_off = |id| match services() .globals .bad_signature_ratelimiter .write() @@ -929,7 +903,7 @@ pub(crate) async fn fetch_signing_keys( hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), }; - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_signature_ratelimiter .read() @@ -950,7 +924,7 @@ pub(crate) async fn fetch_signing_keys( trace!("Loading signing keys for {}", origin); - let mut result: BTreeMap<_, _> = db + let mut result: BTreeMap<_, _> = services() .globals .signing_keys_for(origin)? .into_iter() @@ -963,14 +937,14 @@ pub(crate) async fn fetch_signing_keys( debug!("Fetching signing keys for {} over federation", origin); - if let Some(server_key) = db + if let Some(server_key) = services() .sending - .send_federation_request(&db.globals, origin, get_server_keys::v2::Request::new()) + .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - db.globals.add_signing_key(origin, server_key.clone())?; + services().globals.add_signing_key(origin, server_key.clone())?; result.extend( server_key @@ -990,12 +964,11 @@ pub(crate) async fn fetch_signing_keys( } } - for server in db.globals.trusted_servers() { + for server in services().globals.trusted_servers() { debug!("Asking {} for {}'s signing key", server, origin); - if let Some(server_keys) = db + if let Some(server_keys) = services() .sending .send_federation_request( - &db.globals, server, get_remote_server_keys::v2::Request::new( origin, @@ -1018,7 +991,7 @@ pub(crate) async fn fetch_signing_keys( { trace!("Got signing keys: {:?}", server_keys); for k in server_keys { - db.globals.add_signing_key(origin, k.clone())?; + services().globals.add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -1047,11 +1020,10 @@ pub(crate) async fn fetch_signing_keys( )) } -#[tracing::instrument(skip(starting_events, db))] +#[tracing::instrument(skip(starting_events))] pub(crate) async fn get_auth_chain<'a>( room_id: &RoomId, starting_events: Vec>, - db: &'a Database, ) -> Result> + 'a> { const NUM_BUCKETS: usize = 50; @@ -1059,7 +1031,7 @@ pub(crate) async fn get_auth_chain<'a>( let mut i = 0; for id in starting_events { - let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; + let short = services().rooms.get_or_create_shorteventid(&id)?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; @@ -1078,7 +1050,7 @@ pub(crate) async fn get_auth_chain<'a>( } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&chunk_key)? { + if let Some(cached) = services().rooms.get_auth_chain_from_cache(&chunk_key)? { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; @@ -1090,13 +1062,13 @@ pub(crate) async fn get_auth_chain<'a>( let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { + if let Some(cached) = services().rooms.get_auth_chain_from_cache(&[sevent_id])? { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?); - db.rooms + let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?); + services().rooms .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; println!( "cache missed event {} with auth chain len {}", @@ -1118,7 +1090,7 @@ pub(crate) async fn get_auth_chain<'a>( misses2 ); let chunk_cache = Arc::new(chunk_cache); - db.rooms + services().rooms .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; full_auth_chain.extend(chunk_cache.iter()); } @@ -1132,28 +1104,27 @@ pub(crate) async fn get_auth_chain<'a>( Ok(full_auth_chain .into_iter() - .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) + .filter_map(move |sid| services().rooms.get_eventid_from_short(sid).ok())) } -#[tracing::instrument(skip(event_id, db))] +#[tracing::instrument(skip(event_id))] fn get_auth_chain_inner( room_id: &RoomId, event_id: &EventId, - db: &Database, ) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { - match db.rooms.get_pdu(&event_id) { + match services().rooms.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); } for auth_event in &pdu.auth_events { - let sauthevent = db + let sauthevent = services() .rooms - .get_or_create_shorteventid(auth_event, &db.globals)?; + .get_or_create_shorteventid(auth_event)?; if !found.contains(&sauthevent) { found.insert(sauthevent); @@ -1179,10 +1150,9 @@ fn get_auth_chain_inner( /// /// - Only works if a user of this server is currently invited or joined the room pub async fn get_event_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1191,7 +1161,7 @@ pub async fn get_event_route( .as_ref() .expect("server is authenticated"); - let event = db + let event = services() .rooms .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1204,7 +1174,7 @@ pub async fn get_event_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if !db.rooms.server_in_room(sender_servername, room_id)? { + if !services().rooms.server_in_room(sender_servername, room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", @@ -1212,7 +1182,7 @@ pub async fn get_event_route( } Ok(get_event::v1::Response { - origin: db.globals.server_name().to_owned(), + origin: services().globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), pdu: PduEvent::convert_to_outgoing_federation_event(event), }) @@ -1222,10 +1192,9 @@ pub async fn get_event_route( /// /// Retrieves events that the sender is missing. pub async fn get_missing_events_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1234,21 +1203,21 @@ pub async fn get_missing_events_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; let mut queued_events = body.latest_events.clone(); let mut events = Vec::new(); let mut i = 0; while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { - if let Some(pdu) = db.rooms.get_pdu_json(&queued_events[i])? { + if let Some(pdu) = services().rooms.get_pdu_json(&queued_events[i])? { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -1295,10 +1264,9 @@ pub async fn get_missing_events_route( /// /// - This does not include the event itself pub async fn get_event_authorization_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1307,16 +1275,16 @@ pub async fn get_event_authorization_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let event = db + let event = services() .rooms .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1329,11 +1297,11 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; + let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok()?) + .filter_map(|id| services().rooms.get_pdu_json(&id).ok()?) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1343,10 +1311,9 @@ pub async fn get_event_authorization_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1355,16 +1322,16 @@ pub async fn get_room_state_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let shortstatehash = db + let shortstatehash = services() .rooms .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( @@ -1372,25 +1339,25 @@ pub async fn get_room_state_route( "Pdu state not found.", ))?; - let pdus = db + let pdus = services() .rooms .state_full_ids(shortstatehash) .await? .into_iter() .map(|(_, id)| { PduEvent::convert_to_outgoing_federation_event( - db.rooms.get_pdu_json(&id).unwrap().unwrap(), + services().rooms.get_pdu_json(&id).unwrap().unwrap(), ) }) .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids .map(|id| { - db.rooms.get_pdu_json(&id).map(|maybe_json| { + services().rooms.get_pdu_json(&id).map(|maybe_json| { PduEvent::convert_to_outgoing_federation_event(maybe_json.unwrap()) }) }) @@ -1404,10 +1371,9 @@ pub async fn get_room_state_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_ids_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1416,16 +1382,16 @@ pub async fn get_room_state_ids_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let shortstatehash = db + let shortstatehash = services() .rooms .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( @@ -1433,7 +1399,7 @@ pub async fn get_room_state_ids_route( "Pdu state not found.", ))?; - let pdu_ids = db + let pdu_ids = services() .rooms .state_full_ids(shortstatehash) .await? @@ -1442,7 +1408,7 @@ pub async fn get_room_state_ids_route( .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -1454,14 +1420,13 @@ pub async fn get_room_state_ids_route( /// /// Creates a join template. pub async fn create_join_event_template_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - if !db.rooms.exists(&body.room_id)? { + if !services().rooms.exists(&body.room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", @@ -1473,11 +1438,21 @@ pub async fn create_join_event_template_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option = join_rules_event @@ -1502,7 +1477,8 @@ pub async fn create_join_event_template_route( } } - if !body.ver.contains(&room_version_id) { + let room_version_id = services().rooms.state.get_room_version(&body.room_id); + if !body.ver.contains(room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: room_version_id, @@ -1523,10 +1499,15 @@ pub async fn create_join_event_template_route( }) .expect("member event is valid value"); - let state_key = body.user_id.to_string(); - let kind = StateEventType::RoomMember; + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event(PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, &body.user_id, &body.room_id, &state_lock); - let (pdu, pdu_json) = create_hash_and_sign_event(); + drop(state_lock); Ok(prepare_join_event::v1::Response { room_version: Some(room_version_id), @@ -1535,26 +1516,25 @@ pub async fn create_join_event_template_route( } async fn create_join_event( - db: &DatabaseGuard, sender_servername: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - if !db.rooms.exists(room_id)? { + if !services().rooms.exists(room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", )); } - acl_check(sender_servername, room_id, db)?; + acl_check(sender_servername, room_id)?; // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = db + let join_rules_event = services() .rooms .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; @@ -1581,7 +1561,7 @@ async fn create_join_event( } // We need to return the state prior to joining, let's keep a reference to that here - let shortstatehash = db + let shortstatehash = services() .rooms .current_shortstatehash(room_id)? .ok_or(Error::BadRequest( @@ -1593,7 +1573,7 @@ async fn create_join_event( // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { + let (event_id, value) = match gen_event_id_canonical_json(pdu) { Ok(t) => t, Err(_) => { // Event could not be converted to canonical json @@ -1614,7 +1594,7 @@ async fn create_join_event( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; let mutex = Arc::clone( - db.globals + services().globals .roomid_mutex_federation .write() .unwrap() @@ -1622,7 +1602,7 @@ async fn create_join_event( .or_default(), ); let mutex_lock = mutex.lock().await; - let pdu_id = handle_incoming_pdu(&origin, &event_id, room_id, value, true, db, &pub_key_map) + let pdu_id = services().rooms.event_handler.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) .await .map_err(|e| { warn!("Error while handling incoming send join PDU: {}", e); @@ -1637,32 +1617,29 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_full_ids(shortstatehash).await?; let auth_chain_ids = get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), - db, ) .await?; - let servers = db + let servers = services() .rooms .room_servers(room_id) .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); + .filter(|server| &**server != services().globals.server_name()); - db.sending.send_pdu(servers, &pdu_id)?; - - db.flush()?; + services().sending.send_pdu(servers, &pdu_id)?; Ok(RoomState { auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) + .filter_map(|id| services().rooms.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() - .filter_map(|(_, id)| db.rooms.get_pdu_json(id).ok().flatten()) + .filter_map(|(_, id)| services().rooms.get_pdu_json(id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1672,7 +1649,6 @@ async fn create_join_event( /// /// Submits a signed join event. pub async fn create_join_event_v1_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_servername = body @@ -1680,7 +1656,7 @@ pub async fn create_join_event_v1_route( .as_ref() .expect("server is authenticated"); - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state }) } @@ -1689,7 +1665,6 @@ pub async fn create_join_event_v1_route( /// /// Submits a signed join event. pub async fn create_join_event_v2_route( - db: DatabaseGuard, body: Ruma, ) -> Result { let sender_servername = body @@ -1697,7 +1672,7 @@ pub async fn create_join_event_v2_route( .as_ref() .expect("server is authenticated"); - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v2::Response { room_state }) } @@ -1706,10 +1681,9 @@ pub async fn create_join_event_v2_route( /// /// Invites a remote user to a room. pub async fn create_invite_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1718,9 +1692,9 @@ pub async fn create_invite_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - if !db.rooms.is_supported_version(&db, &body.room_version) { + if !services().rooms.is_supported_version(&body.room_version) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: body.room_version.clone(), @@ -1733,8 +1707,8 @@ pub async fn create_invite_route( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut signed_event, &body.room_version, ) @@ -1793,20 +1767,17 @@ pub async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); // If the room already exists, the remote server will notify us about the join via /send - if !db.rooms.exists(&pdu.room_id)? { - db.rooms.update_membership( + if !services().rooms.exists(&pdu.room_id)? { + services().rooms.update_membership( &body.room_id, &invited_user, MembershipState::Invite, &sender, Some(invite_state), - &db, true, )?; } - db.flush()?; - Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), }) @@ -1816,10 +1787,9 @@ pub async fn create_invite_route( /// /// Gets information on all devices of the user. pub async fn get_devices_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1830,19 +1800,19 @@ pub async fn get_devices_route( Ok(get_devices::v1::Response { user_id: body.user_id.clone(), - stream_id: db + stream_id: services() .users .get_devicelist_version(&body.user_id)? .unwrap_or(0) .try_into() .expect("version will not grow that large"), - devices: db + devices: services() .users .all_devices_metadata(&body.user_id) .filter_map(|r| r.ok()) .filter_map(|metadata| { Some(UserDevice { - keys: db + keys: services() .users .get_device_keys(&body.user_id, &metadata.device_id) .ok()??, @@ -1851,10 +1821,10 @@ pub async fn get_devices_route( }) }) .collect(), - master_key: db + master_key: services() .users .get_master_key(&body.user_id, |u| u.server_name() == sender_servername)?, - self_signing_key: db + self_signing_key: services() .users .get_self_signing_key(&body.user_id, |u| u.server_name() == sender_servername)?, }) @@ -1864,14 +1834,13 @@ pub async fn get_devices_route( /// /// Resolve a room alias to a room id. pub async fn get_room_information_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - let room_id = db + let room_id = services() .rooms .id_from_alias(&body.room_alias)? .ok_or(Error::BadRequest( @@ -1881,7 +1850,7 @@ pub async fn get_room_information_route( Ok(get_room_information::v1::Response { room_id, - servers: vec![db.globals.server_name().to_owned()], + servers: vec![services().globals.server_name().to_owned()], }) } @@ -1889,10 +1858,9 @@ pub async fn get_room_information_route( /// /// Gets information on a profile. pub async fn get_profile_information_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1901,17 +1869,17 @@ pub async fn get_profile_information_route( let mut blurhash = None; match &body.field { - Some(ProfileField::DisplayName) => displayname = db.users.displayname(&body.user_id)?, + Some(ProfileField::DisplayName) => displayname = services().users.displayname(&body.user_id)?, Some(ProfileField::AvatarUrl) => { - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)? + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)? } // TODO: what to do with custom Some(_) => {} None => { - displayname = db.users.displayname(&body.user_id)?; - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)?; + displayname = services().users.displayname(&body.user_id)?; + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)?; } } @@ -1926,10 +1894,9 @@ pub async fn get_profile_information_route( /// /// Gets devices and identity keys for the given users. pub async fn get_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1937,12 +1904,9 @@ pub async fn get_keys_route( None, &body.device_keys, |u| Some(u.server_name()) == body.sender_servername.as_deref(), - &db, ) .await?; - db.flush()?; - Ok(get_keys::v1::Response { device_keys: result.device_keys, master_keys: result.master_keys, @@ -1954,16 +1918,13 @@ pub async fn get_keys_route( /// /// Claims one-time keys. pub async fn claim_keys_route( - db: DatabaseGuard, body: Ruma, ) -> Result { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - let result = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; + let result = claim_keys_helper(&body.one_time_keys).await?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys, @@ -1974,7 +1935,6 @@ pub async fn claim_keys_route( pub(crate) async fn fetch_required_signing_keys( event: &BTreeMap, pub_key_map: &RwLock>>, - db: &Database, ) -> Result<()> { let signatures = event .get("signatures") @@ -1996,7 +1956,6 @@ pub(crate) async fn fetch_required_signing_keys( let signature_ids = signature_object.keys().cloned().collect::>(); let fetch_res = fetch_signing_keys( - db, signature_server.as_str().try_into().map_err(|_| { Error::BadServerResponse("Invalid servername in signatures of server response pdu.") })?, @@ -2028,7 +1987,6 @@ fn get_server_keys_from_cache( servers: &mut BTreeMap, BTreeMap, QueryCriteria>>, room_version: &RoomVersionId, pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, - db: &Database, ) -> Result<()> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -2043,7 +2001,7 @@ fn get_server_keys_from_cache( let event_id = <&EventId>::try_from(event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() @@ -2092,7 +2050,7 @@ fn get_server_keys_from_cache( trace!("Loading signing keys for {}", origin); - let result: BTreeMap<_, _> = db + let result: BTreeMap<_, _> = services() .globals .signing_keys_for(origin)? .into_iter() @@ -2114,7 +2072,6 @@ pub(crate) async fn fetch_join_signing_keys( event: &create_join_event::v2::Response, room_version: &RoomVersionId, pub_key_map: &RwLock>>, - db: &Database, ) -> Result<()> { let mut servers: BTreeMap, BTreeMap, QueryCriteria>> = BTreeMap::new(); @@ -2127,10 +2084,10 @@ pub(crate) async fn fetch_join_signing_keys( // Try to fetch keys, failure is okay // Servers we couldn't find in the cache will be added to `servers` for pdu in &event.room_state.state { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); + let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); } for pdu in &event.room_state.auth_chain { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); + let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); } drop(pkm); @@ -2141,12 +2098,11 @@ pub(crate) async fn fetch_join_signing_keys( return Ok(()); } - for server in db.globals.trusted_servers() { + for server in services().globals.trusted_servers() { trace!("Asking batch signing keys from trusted server {}", server); - if let Ok(keys) = db + if let Ok(keys) = services() .sending .send_federation_request( - &db.globals, server, get_remote_server_keys_batch::v2::Request { server_keys: servers.clone(), @@ -2164,7 +2120,7 @@ pub(crate) async fn fetch_join_signing_keys( // TODO: Check signature from trusted server? servers.remove(&k.server_name); - let result = db + let result = services() .globals .add_signing_key(&k.server_name, k.clone())? .into_iter() @@ -2184,9 +2140,8 @@ pub(crate) async fn fetch_join_signing_keys( .into_iter() .map(|(server, _)| async move { ( - db.sending + services().sending .send_federation_request( - &db.globals, &server, get_server_keys::v2::Request::new(), ) @@ -2198,7 +2153,7 @@ pub(crate) async fn fetch_join_signing_keys( while let Some(result) = futures.next().await { if let (Ok(get_keys_response), origin) = result { - let result: BTreeMap<_, _> = db + let result: BTreeMap<_, _> = services() .globals .add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())? .into_iter() @@ -2216,8 +2171,8 @@ pub(crate) async fn fetch_join_signing_keys( } /// Returns Ok if the acl allows the server -fn acl_check(server_name: &ServerName, room_id: &RoomId, db: &Database) -> Result<()> { - let acl_event = match db +fn acl_check(server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let acl_event = match services() .rooms .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? { diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 29325bd6..93660f9f 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -30,7 +30,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync { fn open(config: &Config) -> Result where Self: Sized; - fn open_tree(&self, name: &'static str) -> Result>; + fn open_tree(&self, name: &'static str) -> Result>; fn flush(&self) -> Result<()>; fn cleanup(&self) -> Result<()> { Ok(()) @@ -40,7 +40,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync { } } -pub trait KeyValueTree: Send + Sync { +pub trait KvTree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 2cf9d5ee..1388dc38 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,4 +1,4 @@ -use super::{super::Config, watchers::Watchers, DatabaseEngine, Tree}; +use super::{super::Config, watchers::Watchers, KvTree, KeyValueDatabaseEngine}; use crate::{utils, Result}; use std::{ future::Future, @@ -51,7 +51,7 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O db_opts } -impl DatabaseEngine for Arc { +impl KeyValueDatabaseEngine for Arc { fn open(config: &Config) -> Result { let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes).unwrap(); @@ -83,7 +83,7 @@ impl DatabaseEngine for Arc { })) } - fn open_tree(&self, name: &'static str) -> Result> { + fn open_tree(&self, name: &'static str) -> Result> { if !self.old_cfs.contains(&name.to_owned()) { // Create if it didn't exist let _ = self @@ -129,7 +129,7 @@ impl RocksDbEngineTree<'_> { } } -impl Tree for RocksDbEngineTree<'_> { +impl KvTree for RocksDbEngineTree<'_> { fn get(&self, key: &[u8]) -> Result>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 7cfa81af..02d4dbd6 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,4 +1,4 @@ -use super::{watchers::Watchers, DatabaseEngine, Tree}; +use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{database::Config, Result}; use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; @@ -80,7 +80,7 @@ impl Engine { } } -impl DatabaseEngine for Arc { +impl KeyValueDatabaseEngine for Arc { fn open(config: &Config) -> Result { let path = Path::new(&config.database_path).join("conduit.db"); @@ -105,7 +105,7 @@ impl DatabaseEngine for Arc { Ok(arc) } - fn open_tree(&self, name: &str) -> Result> { + fn open_tree(&self, name: &str) -> Result> { self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; Ok(Arc::new(SqliteTable { @@ -189,7 +189,7 @@ impl SqliteTable { } } -impl Tree for SqliteTable { +impl KvTree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) } diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 66a2a5c8..eae2cfbc 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,6 +1,8 @@ +use crate::{database::KeyValueDatabase, service, utils, Error}; + impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { + fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { // TODO: Rumaify let id = yaml.get("id").unwrap().as_str().unwrap(); self.id_appserviceregistrations.insert( diff --git a/src/database/key_value/mod.rs b/src/database/key_value/mod.rs index 0c09c17e..189571f6 100644 --- a/src/database/key_value/mod.rs +++ b/src/database/key_value/mod.rs @@ -1,13 +1,13 @@ -mod account_data; -mod admin; +//mod account_data; +//mod admin; mod appservice; -mod globals; -mod key_backups; -mod media; -mod pdu; +//mod globals; +//mod key_backups; +//mod media; +//mod pdu; mod pusher; mod rooms; -mod sending; +//mod sending; mod transaction_ids; mod uiaa; mod users; diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 94374ab2..b77170db 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,3 +1,7 @@ +use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; + +use crate::{service, database::KeyValueDatabase, Error}; + impl service::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index b00eb3b1..a9236a75 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,4 +1,8 @@ -impl service::room::alias::Data for KeyValueDatabase { +use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; + +use crate::{service, database::KeyValueDatabase, utils, Error, services}; + +impl service::rooms::alias::Data for KeyValueDatabase { fn set_alias( &self, alias: &RoomAliasId, @@ -8,7 +12,7 @@ impl service::room::alias::Data for KeyValueDatabase { .insert(alias.alias().as_bytes(), room_id.as_bytes())?; let mut aliasid = room_id.as_bytes().to_vec(); aliasid.push(0xff); - aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); + aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; Ok(()) } diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index f42de45e..44a580c3 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,10 +1,14 @@ -impl service::room::directory::Data for KeyValueDatabase { +use ruma::RoomId; + +use crate::{service, database::KeyValueDatabase, utils, Error}; + +impl service::rooms::directory::Data for KeyValueDatabase { fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.insert(room_id.as_bytes(), &[])?; + self.publicroomids.insert(room_id.as_bytes(), &[]) } fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.remove(room_id.as_bytes())?; + self.publicroomids.remove(room_id.as_bytes()) } fn is_public_room(&self, room_id: &RoomId) -> Result { diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs new file mode 100644 index 00000000..9ffd33da --- /dev/null +++ b/src/database/key_value/rooms/edus/mod.rs @@ -0,0 +1,3 @@ +mod presence; +mod typing; +mod read_receipt; diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 1978ce7b..9f3977db 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,4 +1,10 @@ -impl service::room::edus::presence::Data for KeyValueDatabase { +use std::collections::HashMap; + +use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; + +use crate::{service, database::KeyValueDatabase, utils, Error, services}; + +impl service::rooms::edus::presence::Data for KeyValueDatabase { fn update_presence( &self, user_id: &UserId, @@ -7,7 +13,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase { ) -> Result<()> { // TODO: Remove old entry? Or maybe just wipe completely from time to time? - let count = globals.next_count()?.to_be_bytes(); + let count = services().globals.next_count()?.to_be_bytes(); let mut presence_id = room_id.as_bytes().to_vec(); presence_id.push(0xff); @@ -101,6 +107,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase { Ok(hashmap) } + /* fn presence_maintain(&self, db: Arc>) { // TODO @M0dEx: move this to a timed tasks module tokio::spawn(async move { @@ -117,6 +124,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase { } }); } + */ } fn parse_presence_event(bytes: &[u8]) -> Result { diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 556e697f..68aea165 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,4 +1,10 @@ -impl service::room::edus::read_receipt::Data for KeyValueDatabase { +use std::mem; + +use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; + +use crate::{database::KeyValueDatabase, service, utils, Error, services}; + +impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { fn readreceipt_update( &self, user_id: &UserId, @@ -28,7 +34,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase { } let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); + room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); room_latest_id.push(0xff); room_latest_id.extend_from_slice(user_id.as_bytes()); @@ -40,7 +46,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase { Ok(()) } - pub fn readreceipts_since<'a>( + fn readreceipts_since<'a>( &'a self, room_id: &RoomId, since: u64, @@ -102,7 +108,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase { .insert(&key, &count.to_be_bytes())?; self.roomuserid_lastprivatereadupdate - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes()) } fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index 8cfb432d..905bffc8 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -1,15 +1,20 @@ -impl service::room::edus::typing::Data for KeyValueDatabase { +use std::collections::HashSet; + +use ruma::{UserId, RoomId}; + +use crate::{database::KeyValueDatabase, service, utils, Error, services}; + +impl service::rooms::edus::typing::Data for KeyValueDatabase { fn typing_add( &self, user_id: &UserId, room_id: &RoomId, timeout: u64, - globals: &super::super::globals::Globals, ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); - let count = globals.next_count()?.to_be_bytes(); + let count = services().globals.next_count()?.to_be_bytes(); let mut room_typing_id = prefix; room_typing_id.extend_from_slice(&timeout.to_be_bytes()); @@ -49,7 +54,7 @@ impl service::room::edus::typing::Data for KeyValueDatabase { if found_outdated { self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + .insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; } Ok(()) diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index 8abdce49..c230cbf7 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,4 +1,8 @@ -impl service::room::lazy_load::Data for KeyValueDatabase { +use ruma::{UserId, DeviceId, RoomId}; + +use crate::{service, database::KeyValueDatabase}; + +impl service::rooms::lazy_loading::Data for KeyValueDatabase { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 37dd7173..b4cba2c6 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,4 +1,8 @@ -impl service::room::metadata::Data for KeyValueDatabase { +use ruma::RoomId; + +use crate::{service, database::KeyValueDatabase}; + +impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result { let prefix = match self.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index 2a3f81d8..adb810ba 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -1,14 +1,13 @@ -mod state; mod alias; mod directory; mod edus; -mod event_handler; -mod lazy_loading; +//mod event_handler; +mod lazy_load; mod metadata; mod outlier; mod pdu_metadata; mod search; -mod short; +//mod short; mod state; mod state_accessor; mod state_cache; diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index c979d253..08299a0c 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,4 +1,8 @@ -impl service::room::outlier::Data for KeyValueDatabase { +use ruma::{EventId, signatures::CanonicalJsonObject}; + +use crate::{service, database::KeyValueDatabase, PduEvent, Error}; + +impl service::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 6b2171ca..602f3f6c 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -1,4 +1,10 @@ -impl service::room::pdu_metadata::Data for KeyValueDatabase { +use std::sync::Arc; + +use ruma::{RoomId, EventId}; + +use crate::{service, database::KeyValueDatabase}; + +impl service::rooms::pdu_metadata::Data for KeyValueDatabase { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 1ffffe56..44663ff3 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,7 +1,12 @@ -impl service::room::search::Data for KeyValueDatabase { +use std::mem::size_of; +use ruma::RoomId; + +use crate::{service, database::KeyValueDatabase, utils}; + +impl service::rooms::search::Data for KeyValueDatabase { fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()> { - let mut batch = body + let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) .filter(|word| word.len() <= 50) @@ -14,7 +19,7 @@ impl service::room::search::Data for KeyValueDatabase { (key, Vec::new()) }); - self.tokenids.insert_batch(&mut batch)?; + self.tokenids.insert_batch(&mut batch) } fn search_pdus<'a>( @@ -64,3 +69,4 @@ impl service::room::search::Data for KeyValueDatabase { ) })) } +} diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 5daf6c6a..192dbb83 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,4 +1,11 @@ -impl service::room::state::Data for KeyValueDatabase { +use ruma::{RoomId, EventId}; +use std::sync::Arc; +use std::{sync::MutexGuard, collections::HashSet}; +use std::fmt::Debug; + +use crate::{service, database::KeyValueDatabase, utils, Error}; + +impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { self.roomid_shortstatehash .get(room_id.as_bytes())? @@ -9,21 +16,21 @@ impl service::room::state::Data for KeyValueDatabase { }) } - fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64 - _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; Ok(()) } - fn set_event_state(&self) -> Result<()> { - db.shorteventid_shortstatehash + fn set_event_state(&self, shorteventid: Vec, shortstatehash: Vec) -> Result<()> { + self.shorteventid_shortstatehash .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } - fn get_pdu_leaves(&self, room_id: &RoomId) -> Result>> { + fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -38,11 +45,11 @@ impl service::room::state::Data for KeyValueDatabase { .collect() } - fn set_forward_extremities( + fn set_forward_extremities<'a>( &self, room_id: &RoomId, event_ids: impl IntoIterator + Debug, - _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index db81967d..ea15afc0 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,4 +1,11 @@ -impl service::room::state_accessor::Data for KeyValueDatabase { +use std::{collections::{BTreeMap, HashMap}, sync::Arc}; + +use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils}; +use async_trait::async_trait; +use ruma::{EventId, events::StateEventType, RoomId}; + +#[async_trait] +impl service::rooms::state_accessor::Data for KeyValueDatabase { async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { let full_state = self .load_shortstatehash_info(shortstatehash)? @@ -149,3 +156,4 @@ impl service::room::state_accessor::Data for KeyValueDatabase { Ok(None) } } +} diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 37814020..567dc809 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,8 +1,12 @@ -impl service::room::state_cache::Data for KeyValueDatabase { - fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> { +use ruma::{UserId, RoomId}; + +use crate::{service, database::KeyValueDatabase}; + +impl service::rooms::state_cache::Data for KeyValueDatabase { + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[])?; + self.roomuseroncejoinedids.insert(&userroom_id, &[]) } } diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index 71a2f3a0..09e35660 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,11 +1,20 @@ -impl service::room::state_compressor::Data for KeyValueDatabase { - fn get_statediff(shortstatehash: u64) -> Result { +use std::{collections::HashSet, mem::size_of}; + +use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils}; + +impl service::rooms::state_compressor::Data for KeyValueDatabase { + fn get_statediff(&self, shortstatehash: u64) -> Result { let value = self .shortstatehash_statediff .get(&shortstatehash.to_be_bytes())? .ok_or_else(|| Error::bad_database("State hash does not exist"))?; let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + let parent = if parent != 0 { + Some(parent) + } else { + None + }; let mut add_mode = true; let mut added = HashSet::new(); @@ -26,10 +35,10 @@ impl service::room::state_compressor::Data for KeyValueDatabase { i += 2 * size_of::(); } - StateDiff { parent, added, removed } + Ok(StateDiff { parent, added, removed }) } - fn save_statediff(shortstatehash: u64, diff: StateDiff) -> Result<()> { + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { let mut value = diff.parent.to_be_bytes().to_vec(); for new in &diff.new { value.extend_from_slice(&new[..]); @@ -43,6 +52,6 @@ impl service::room::state_compressor::Data for KeyValueDatabase { } self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value)?; + .insert(&shortstatehash.to_be_bytes(), &value) } } diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 58884ec3..cf93df12 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -1,4 +1,11 @@ -impl service::room::timeline::Data for KeyValueDatabase { +use std::{collections::hash_map, mem::size_of, sync::Arc}; + +use ruma::{UserId, RoomId, api::client::error::ErrorKind, EventId, signatures::CanonicalJsonObject}; +use tracing::error; + +use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent}; + +impl service::rooms::timeline::Data for KeyValueDatabase { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { match self .lasttimelinecount_cache @@ -37,7 +44,7 @@ impl service::room::timeline::Data for KeyValueDatabase { } /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map_or_else( @@ -55,7 +62,7 @@ impl service::room::timeline::Data for KeyValueDatabase { } /// Returns the json of a pdu. - pub fn get_non_outlier_pdu_json( + fn get_non_outlier_pdu_json( &self, event_id: &EventId, ) -> Result> { @@ -74,14 +81,14 @@ impl service::room::timeline::Data for KeyValueDatabase { } /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { + fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map(|pduid| { @@ -99,7 +106,7 @@ impl service::room::timeline::Data for KeyValueDatabase { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { + fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { return Ok(Some(Arc::clone(p))); } @@ -135,7 +142,7 @@ impl service::room::timeline::Data for KeyValueDatabase { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) @@ -145,7 +152,7 @@ impl service::room::timeline::Data for KeyValueDatabase { } /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( serde_json::from_slice(&pdu) @@ -155,7 +162,7 @@ impl service::room::timeline::Data for KeyValueDatabase { } /// Returns the `count` of this pdu's id. - pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { + fn pdu_count(&self, pdu_id: &[u8]) -> Result { utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes.")) } @@ -178,7 +185,7 @@ impl service::room::timeline::Data for KeyValueDatabase { /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. - pub fn pdus_since<'a>( + fn pdus_since<'a>( &'a self, user_id: &UserId, room_id: &RoomId, @@ -212,7 +219,7 @@ impl service::room::timeline::Data for KeyValueDatabase { /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. - pub fn pdus_until<'a>( + fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, @@ -246,7 +253,7 @@ impl service::room::timeline::Data for KeyValueDatabase { })) } - pub fn pdus_after<'a>( + fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 52145ced..2fc3b9f4 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,4 +1,8 @@ -impl service::room::user::Data for KeyValueDatabase { +use ruma::{UserId, RoomId}; + +use crate::{service, database::KeyValueDatabase, utils, Error}; + +impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index 81c1197d..6652a627 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,5 +1,9 @@ -impl service::pusher::Data for KeyValueDatabase { - pub fn add_txnid( +use ruma::{UserId, DeviceId, TransactionId}; + +use crate::{service, database::KeyValueDatabase}; + +impl service::transaction_ids::Data for KeyValueDatabase { + fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, @@ -17,7 +21,7 @@ impl service::pusher::Data for KeyValueDatabase { Ok(()) } - pub fn existing_txnid( + fn existing_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 4d1dac57..b1960bd5 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,3 +1,9 @@ +use std::io::ErrorKind; + +use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::uiaa::UiaaInfo}; + +use crate::{database::KeyValueDatabase, service, Error}; + impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 5ef058f3..ea844903 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,11 +1,18 @@ +use std::{mem::size_of, collections::BTreeMap}; + +use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; +use tracing::warn; + +use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services}; + impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. - pub fn exists(&self, user_id: &UserId) -> Result { + fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result { + fn is_deactivated(&self, user_id: &UserId) -> Result { Ok(self .userid_password .get(user_id.as_bytes())? @@ -16,33 +23,13 @@ impl service::users::Data for KeyValueDatabase { .is_empty()) } - /// Check if a user is an admin - pub fn is_admin( - &self, - user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, - ) -> Result { - let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); - - rooms.is_joined(user_id, &admin_room_id) - } - - /// Create a new user account on this homeserver. - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.set_password(user_id, password)?; - Ok(()) - } - /// Returns the number of users registered on this server. - pub fn count(&self) -> Result { + fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result, String)>> { + fn find_from_token(&self, token: &str) -> Result, String)>> { self.token_userdeviceid .get(token.as_bytes())? .map_or(Ok(None), |bytes| { @@ -69,7 +56,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator>> + '_ { + fn iter(&self) -> impl Iterator>> + '_ { self.userid_password.iter().map(|(bytes, _)| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("User ID in userid_password is invalid unicode.") @@ -81,7 +68,7 @@ impl service::users::Data for KeyValueDatabase { /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - pub fn list_local_users(&self) -> Result> { + fn list_local_users(&self) -> Result> { let users: Vec = self .userid_password .iter() @@ -113,7 +100,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns the password hash for the given user. - pub fn password_hash(&self, user_id: &UserId) -> Result> { + fn password_hash(&self, user_id: &UserId) -> Result> { self.userid_password .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -124,7 +111,7 @@ impl service::users::Data for KeyValueDatabase { } /// Hash and set the user's password to the Argon2 hash - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { if let Ok(hash) = utils::calculate_hash(password) { self.userid_password @@ -143,7 +130,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result> { + fn displayname(&self, user_id: &UserId) -> Result> { self.userid_displayname .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -154,7 +141,7 @@ impl service::users::Data for KeyValueDatabase { } /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { if let Some(displayname) = displayname { self.userid_displayname .insert(user_id.as_bytes(), displayname.as_bytes())?; @@ -166,7 +153,7 @@ impl service::users::Data for KeyValueDatabase { } /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result>> { + fn avatar_url(&self, user_id: &UserId) -> Result>> { self.userid_avatarurl .get(user_id.as_bytes())? .map(|bytes| { @@ -179,7 +166,7 @@ impl service::users::Data for KeyValueDatabase { } /// Sets a new avatar_url or removes it if avatar_url is None. - pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()> { if let Some(avatar_url) = avatar_url { self.userid_avatarurl .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; @@ -191,7 +178,7 @@ impl service::users::Data for KeyValueDatabase { } /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result> { + fn blurhash(&self, user_id: &UserId) -> Result> { self.userid_blurhash .get(user_id.as_bytes())? .map(|bytes| { @@ -204,7 +191,7 @@ impl service::users::Data for KeyValueDatabase { } /// Sets a new avatar_url or removes it if avatar_url is None. - pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { if let Some(blurhash) = blurhash { self.userid_blurhash .insert(user_id.as_bytes(), blurhash.as_bytes())?; @@ -216,7 +203,7 @@ impl service::users::Data for KeyValueDatabase { } /// Adds a new device to a user. - pub fn create_device( + fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -250,7 +237,7 @@ impl service::users::Data for KeyValueDatabase { } /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -280,7 +267,7 @@ impl service::users::Data for KeyValueDatabase { } /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>( + fn all_device_ids<'a>( &'a self, user_id: &UserId, ) -> impl Iterator>> + 'a { @@ -302,7 +289,7 @@ impl service::users::Data for KeyValueDatabase { } /// Replaces the access token of one device. - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); userdeviceid.extend_from_slice(device_id.as_bytes()); @@ -325,13 +312,12 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - pub fn add_one_time_key( + fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -356,12 +342,12 @@ impl service::users::Data for KeyValueDatabase { )?; self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; Ok(()) } - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { self.userid_lastonetimekeyupdate .get(user_id.as_bytes())? .map(|bytes| { @@ -372,12 +358,11 @@ impl service::users::Data for KeyValueDatabase { .unwrap_or(Ok(0)) } - pub fn take_one_time_key( + fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - globals: &super::globals::Globals, ) -> Result, Raw)>> { let mut prefix = user_id.as_bytes().to_vec(); prefix.push(0xff); @@ -388,7 +373,7 @@ impl service::users::Data for KeyValueDatabase { prefix.push(b':'); self.userid_lastonetimekeyupdate - .insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; + .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; self.onetimekeyid_onetimekeys .scan_prefix(prefix) @@ -411,7 +396,7 @@ impl service::users::Data for KeyValueDatabase { .transpose() } - pub fn count_one_time_keys( + fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -443,13 +428,11 @@ impl service::users::Data for KeyValueDatabase { Ok(counts) } - pub fn add_device_keys( + fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { let mut userdeviceid = user_id.as_bytes().to_vec(); userdeviceid.push(0xff); @@ -460,19 +443,17 @@ impl service::users::Data for KeyValueDatabase { &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), )?; - self.mark_device_key_update(user_id, rooms, globals)?; + self.mark_device_key_update(user_id)?; Ok(()) } - pub fn add_cross_signing_keys( + fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { // TODO: Check signatures @@ -575,19 +556,17 @@ impl service::users::Data for KeyValueDatabase { .insert(user_id.as_bytes(), &user_signing_key_key)?; } - self.mark_device_key_update(user_id, rooms, globals)?; + self.mark_device_key_update(user_id)?; Ok(()) } - pub fn sign_key( + fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = target_id.as_bytes().to_vec(); key.push(0xff); @@ -619,12 +598,12 @@ impl service::users::Data for KeyValueDatabase { )?; // TODO: Should we notify about this change? - self.mark_device_key_update(target_id, rooms, globals)?; + self.mark_device_key_update(target_id)?; Ok(()) } - pub fn keys_changed<'a>( + fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, @@ -662,16 +641,14 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn mark_device_key_update( + fn mark_device_key_update( &self, user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()> { - let count = globals.next_count()?.to_be_bytes(); - for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { + let count = services().globals.next_count()?.to_be_bytes(); + for room_id in services().rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { // Don't send key updates to unencrypted rooms - if rooms + if services().rooms .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .is_none() { @@ -693,7 +670,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - pub fn get_device_keys( + fn get_device_keys( &self, user_id: &UserId, device_id: &DeviceId, @@ -709,7 +686,7 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn get_master_key bool>( + fn get_master_key bool>( &self, user_id: &UserId, allowed_signatures: F, @@ -730,7 +707,7 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn get_self_signing_key bool>( + fn get_self_signing_key bool>( &self, user_id: &UserId, allowed_signatures: F, @@ -751,7 +728,7 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { self.userid_usersigningkeyid .get(user_id.as_bytes())? .map_or(Ok(None), |key| { @@ -763,20 +740,19 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn add_to_device_event( + fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = target_user_id.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(target_device_id.as_bytes()); key.push(0xff); - key.extend_from_slice(&globals.next_count()?.to_be_bytes()); + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); let mut json = serde_json::Map::new(); json.insert("type".to_owned(), event_type.to_owned().into()); @@ -790,7 +766,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - pub fn get_to_device_events( + fn get_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, @@ -812,7 +788,7 @@ impl service::users::Data for KeyValueDatabase { Ok(events) } - pub fn remove_to_device_events( + fn remove_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, @@ -833,7 +809,7 @@ impl service::users::Data for KeyValueDatabase { .map(|(key, _)| { Ok::<_, Error>(( key.clone(), - utils::u64_from_bytes(&key[key.len() - mem::size_of::()..key.len()]) + utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, )) }) @@ -846,7 +822,7 @@ impl service::users::Data for KeyValueDatabase { Ok(()) } - pub fn update_device_metadata( + fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -871,7 +847,7 @@ impl service::users::Data for KeyValueDatabase { } /// Get device metadata. - pub fn get_device_metadata( + fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -889,7 +865,7 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + fn get_devicelist_version(&self, user_id: &UserId) -> Result> { self.userid_devicelistversion .get(user_id.as_bytes())? .map_or(Ok(None), |bytes| { @@ -899,7 +875,7 @@ impl service::users::Data for KeyValueDatabase { }) } - pub fn all_devices_metadata<'a>( + fn all_devices_metadata<'a>( &'a self, user_id: &UserId, ) -> impl Iterator> + 'a { @@ -915,7 +891,7 @@ impl service::users::Data for KeyValueDatabase { } /// Creates a new sync filter. Returns the filter id. - pub fn create_filter( + fn create_filter( &self, user_id: &UserId, filter: &IncomingFilterDefinition, @@ -934,7 +910,7 @@ impl service::users::Data for KeyValueDatabase { Ok(filter_id) } - pub fn get_filter( + fn get_filter( &self, user_id: &UserId, filter_id: &str, diff --git a/src/database/mod.rs b/src/database/mod.rs index a35228aa..12758af2 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,20 +1,7 @@ pub mod abstraction; +pub mod key_value; -pub mod account_data; -pub mod admin; -pub mod appservice; -pub mod globals; -pub mod key_backups; -pub mod media; -pub mod pusher; -pub mod rooms; -pub mod sending; -pub mod transaction_ids; -pub mod uiaa; -pub mod users; - -use self::admin::create_admin_room; -use crate::{utils, Config, Error, Result}; +use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, admin::{self, create_admin_room}, appservice, pusher}}; use abstraction::KeyValueDatabaseEngine; use directories::ProjectDirs; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -25,7 +12,7 @@ use ruma::{ GlobalAccountDataEvent, GlobalAccountDataEventType, }, push::Ruleset, - DeviceId, EventId, RoomId, UserId, + DeviceId, EventId, RoomId, UserId, signatures::CanonicalJsonValue, }; use std::{ collections::{BTreeMap, HashMap, HashSet}, @@ -38,21 +25,132 @@ use std::{ }; use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tracing::{debug, error, info, warn}; +use abstraction::KvTree; pub struct KeyValueDatabase { _db: Arc, - pub globals: globals::Globals, - pub users: users::Users, - pub uiaa: uiaa::Uiaa, - pub rooms: rooms::Rooms, - pub account_data: account_data::AccountData, - pub media: media::Media, - pub key_backups: key_backups::KeyBackups, - pub transaction_ids: transaction_ids::TransactionIds, - pub sending: sending::Sending, - pub admin: admin::Admin, - pub appservice: appservice::Appservice, - pub pusher: pusher::PushData, + + //pub globals: globals::Globals, + pub(super) global: Arc, + pub(super) server_signingkeys: Arc, + + //pub users: users::Users, + pub(super) userid_password: Arc, + pub(super) userid_displayname: Arc, + pub(super) userid_avatarurl: Arc, + pub(super) userid_blurhash: Arc, + pub(super) userdeviceid_token: Arc, + pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc, + + pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc, + pub(super) userid_selfsigningkeyid: Arc, + pub(super) userid_usersigningkeyid: Arc, + + pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + + pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + + //pub uiaa: uiaa::Uiaa, + pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: + RwLock, Box, String), CanonicalJsonValue>>, + + //pub edus: RoomEdus, + pub(super) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId + pub(super) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count + pub(super) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count + pub(super) typingid_userid: Arc, // TypingId = RoomId + TimeoutTime + Count + pub(super) roomid_lasttypingupdate: Arc, // LastRoomTypingUpdate = Count + pub(super) presenceid_presence: Arc, // PresenceId = RoomId + Count + UserId + pub(super) userid_lastpresenceupdate: Arc, // LastPresenceUpdate = Count + + //pub rooms: rooms::Rooms, + pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count + pub(super) eventid_pduid: Arc, + pub(super) roomid_pduleaves: Arc, + pub(super) alias_roomid: Arc, + pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count + pub(super) publicroomids: Arc, + + pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount + + /// Participating servers in a room. + pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName + pub(super) serverroomids: Arc, // ServerRoomId = ServerName + RoomId + + pub(super) userroomid_joined: Arc, + pub(super) roomuserid_joined: Arc, + pub(super) roomid_joinedcount: Arc, + pub(super) roomid_invitedcount: Arc, + pub(super) roomuseroncejoinedids: Arc, + pub(super) userroomid_invitestate: Arc, // InviteState = Vec> + pub(super) roomuserid_invitecount: Arc, // InviteCount = Count + pub(super) userroomid_leftstate: Arc, + pub(super) roomuserid_leftcount: Arc, + + pub(super) disabledroomids: Arc, // Rooms where incoming federation handling is disabled + + pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId + + pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 + pub(super) userroomid_highlightcount: Arc, // HightlightCount = u64 + + /// Remember the current state hash of a room. + pub(super) roomid_shortstatehash: Arc, + pub(super) roomsynctoken_shortstatehash: Arc, + /// Remember the state hash at events in the past. + pub(super) shorteventid_shortstatehash: Arc, + /// StateKey = EventType + StateKey, ShortStateKey = Count + pub(super) statekey_shortstatekey: Arc, + pub(super) shortstatekey_statekey: Arc, + + pub(super) roomid_shortroomid: Arc, + + pub(super) shorteventid_eventid: Arc, + pub(super) eventid_shorteventid: Arc, + + pub(super) statehash_shortstatehash: Arc, + pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) + + pub(super) shorteventid_authchain: Arc, + + /// RoomId + EventId -> outlier PDU. + /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. + pub(super) eventid_outlierpdu: Arc, + pub(super) softfailedeventids: Arc, + + /// RoomId + EventId -> Parent PDU EventId. + pub(super) referencedevents: Arc, + + //pub account_data: account_data::AccountData, + pub(super) roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type + pub(super) roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type + + //pub media: media::Media, + pub(super) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType + //pub key_backups: key_backups::KeyBackups, + pub(super) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) + pub(super) backupid_etag: Arc, // BackupId = UserId + Version(Count) + pub(super) backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId + + //pub transaction_ids: transaction_ids::TransactionIds, + pub(super) userdevicetxnid_response: Arc, // Response can be empty (/sendToDevice) or the event id (/send) + //pub sending: sending::Sending, + pub(super) servername_educount: Arc, // EduCount: Count of last EDU sync + pub(super) servernameevent_data: Arc, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content + pub(super) servercurrentevent_data: Arc, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content + + //pub appservice: appservice::Appservice, + pub(super) id_appserviceregistrations: Arc, + + //pub pusher: pusher::PushData, + pub(super) senderkey_pusher: Arc, } impl KeyValueDatabase { @@ -157,7 +255,6 @@ impl KeyValueDatabase { let db = Arc::new(TokioRwLock::from(Self { _db: builder.clone(), - users: users::Users { userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?, @@ -175,13 +272,9 @@ impl KeyValueDatabase { userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?, todeviceid_events: builder.open_tree("todeviceid_events")?, - }, - uiaa: uiaa::Uiaa { + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - }, - rooms: rooms::Rooms { - edus: rooms::RoomEdus { readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt roomuserid_lastprivatereadupdate: builder @@ -190,7 +283,6 @@ impl KeyValueDatabase { roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, presenceid_presence: builder.open_tree("presenceid_presence")?, userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, - }, pduid_pdu: builder.open_tree("pduid_pdu")?, eventid_pduid: builder.open_tree("eventid_pduid")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, @@ -239,74 +331,23 @@ impl KeyValueDatabase { softfailedeventids: builder.open_tree("softfailedeventids")?, referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new( - config - .pdu_cache_capacity - .try_into() - .expect("pdu cache capacity fits into usize"), - )), - auth_chain_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shorteventid_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - eventidshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shortstatekey_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - statekeyshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - our_real_users_cache: RwLock::new(HashMap::new()), - appservice_in_room_cache: RwLock::new(HashMap::new()), - lazy_load_waiting: Mutex::new(HashMap::new()), - stateinfo_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }, - account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - }, - media: media::Media { mediaid_file: builder.open_tree("mediaid_file")?, - }, - key_backups: key_backups::KeyBackups { backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_etag: builder.open_tree("backupid_etag")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - }, - transaction_ids: transaction_ids::TransactionIds { userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - }, - sending: sending::Sending { servername_educount: builder.open_tree("servername_educount")?, servernameevent_data: builder.open_tree("servernameevent_data")?, servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), - sender: sending_sender, - }, - admin: admin::Admin { - sender: admin_sender, - }, - appservice: appservice::Appservice { - cached_registrations: Arc::new(RwLock::new(HashMap::new())), id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - }, - pusher: pusher::PushData { senderkey_pusher: builder.open_tree("senderkey_pusher")?, - }, - globals: globals::Globals::load( - builder.open_tree("global")?, - builder.open_tree("server_signingkeys")?, - config.clone(), - )?, + global: builder.open_tree("global")?, + server_signingkeys: builder.open_tree("server_signingkeys")?, })); + // TODO: do this after constructing the db let guard = db.read().await; // Matrix resource ownership is based on the server name; changing it @@ -744,7 +785,7 @@ impl KeyValueDatabase { .bump_database_version(latest_database_version)?; // Create the admin room and server user on first run - create_admin_room(&guard).await?; + create_admin_room().await?; warn!( "Created new {} database with version {}", diff --git a/src/lib.rs b/src/lib.rs index c35a1293..0d058df3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,17 +9,26 @@ mod config; mod database; -mod error; -mod pdu; -mod ruma_wrapper; +mod service; +pub mod api; mod utils; -pub mod appservice_server; -pub mod client_server; -pub mod server_server; +use std::cell::Cell; pub use config::Config; -pub use database::Database; -pub use error::{Error, Result}; -pub use pdu::PduEvent; -pub use ruma_wrapper::{Ruma, RumaResponse}; +pub use utils::error::{Error, Result}; +pub use service::{Services, pdu::PduEvent}; +pub use api::ruma_wrapper::{Ruma, RumaResponse}; + +use crate::database::KeyValueDatabase; + +pub static SERVICES: Cell> = Cell::new(None); + +enum ServicesEnum { + Rocksdb(Services) +} + +pub fn services() -> Services { + SERVICES.get().unwrap() +} + diff --git a/src/main.rs b/src/main.rs index a1af9761..543b953e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,47 +46,44 @@ use tikv_jemallocator::Jemalloc; #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; -lazy_static! { - static ref DB: Database = { - let raw_config = - Figment::new() - .merge( - Toml::file(Env::var("CONDUIT_CONFIG").expect( - "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", - )) - .nested(), - ) - .merge(Env::prefixed("CONDUIT_").global()); - - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); - std::process::exit(1); - } - }; - - config.warn_deprecated(); - - let db = match Database::load_or_create(&config).await { - Ok(db) => db, - Err(e) => { - eprintln!( - "The database couldn't be loaded or created. The following error occured: {}", - e - ); - std::process::exit(1); - } - }; - }; -} - #[tokio::main] async fn main() { - lazy_static::initialize(&DB); + // Initialize DB + let raw_config = + Figment::new() + .merge( + Toml::file(Env::var("CONDUIT_CONFIG").expect( + "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", + )) + .nested(), + ) + .merge(Env::prefixed("CONDUIT_").global()); + + let config = match raw_config.extract::() { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e); + std::process::exit(1); + } + }; + + config.warn_deprecated(); + + let db = match KeyValueDatabase::load_or_create(&config).await { + Ok(db) => db, + Err(e) => { + eprintln!( + "The database couldn't be loaded or created. The following error occured: {}", + e + ); + std::process::exit(1); + } + }; + + SERVICES.set(db).expect("this is the first and only time we initialize the SERVICE static"); let start = async { - run_server(&config).await.unwrap(); + run_server().await.unwrap(); }; if config.allow_jaeger { diff --git a/src/service/account_data.rs b/src/service/account_data.rs index d85918f6..70ad9f2a 100644 --- a/src/service/account_data.rs +++ b/src/service/account_data.rs @@ -8,23 +8,15 @@ use ruma::{ use serde::{de::DeserializeOwned, Serialize}; use std::{collections::HashMap, sync::Arc}; -use super::abstraction::Tree; - -pub struct AccountData { - pub(super) roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type - pub(super) roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type -} - impl AccountData { /// Places one event in the account data of the user and removes the previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data, globals))] + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] pub fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, data: &T, - globals: &super::globals::Globals, ) -> Result<()> { let mut prefix = room_id .map(|r| r.to_string()) @@ -36,7 +28,7 @@ impl AccountData { prefix.push(0xff); let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&globals.next_count()?.to_be_bytes()); + roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); roomuserdataid.push(0xff); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); diff --git a/src/service/admin.rs b/src/service/admin.rs index 6f418ea8..ded0adb9 100644 --- a/src/service/admin.rs +++ b/src/service/admin.rs @@ -5,14 +5,6 @@ use std::{ time::Instant, }; -use crate::{ - client_server::AUTO_GEN_PASSWORD_LENGTH, - error::{Error, Result}, - pdu::PduBuilder, - server_server, utils, - utils::HtmlEscape, - Database, PduEvent, -}; use clap::Parser; use regex::Regex; use ruma::{ @@ -36,6 +28,10 @@ use ruma::{ use serde_json::value::to_raw_value; use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; +use crate::{services, Error, api::{server_server, client_server::AUTO_GEN_PASSWORD_LENGTH}, PduEvent, utils::{HtmlEscape, self}}; + +use super::pdu::PduBuilder; + #[derive(Debug)] pub enum AdminRoomEvent { ProcessMessage(String), @@ -50,22 +46,19 @@ pub struct Admin { impl Admin { pub fn start_handler( &self, - db: Arc>, mut receiver: mpsc::UnboundedReceiver, ) { tokio::spawn(async move { // TODO: Use futures when we have long admin commands //let mut futures = FuturesUnordered::new(); - let guard = db.read().await; - - let conduit_user = UserId::parse(format!("@conduit:{}", guard.globals.server_name())) + let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) .expect("@conduit:server_name is valid"); - let conduit_room = guard + let conduit_room = services() .rooms .id_from_alias( - format!("#admins:{}", guard.globals.server_name()) + format!("#admins:{}", services().globals.server_name()) .as_str() .try_into() .expect("#admins:server_name is a valid room alias"), @@ -73,12 +66,9 @@ impl Admin { .expect("Database data for admin room alias must be valid") .expect("Admin room must exist"); - drop(guard); - let send_message = |message: RoomMessageEventContent, - guard: RwLockReadGuard<'_, Database>, mutex_lock: &MutexGuard<'_, ()>| { - guard + services() .rooms .build_and_append_pdu( PduBuilder { @@ -91,7 +81,6 @@ impl Admin { }, &conduit_user, &conduit_room, - &guard, mutex_lock, ) .unwrap(); @@ -100,15 +89,13 @@ impl Admin { loop { tokio::select! { Some(event) = receiver.recv() => { - let guard = db.read().await; - let message_content = match event { AdminRoomEvent::SendMessage(content) => content, - AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(&*guard, room_message).await + AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(room_message).await }; let mutex_state = Arc::clone( - guard.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -118,7 +105,7 @@ impl Admin { let state_lock = mutex_state.lock().await; - send_message(message_content, guard, &state_lock); + send_message(message_content, &state_lock); drop(state_lock); } @@ -141,7 +128,7 @@ impl Admin { } // Parse and process a message from the admin room -async fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { +async fn process_admin_message(room_message: String) -> RoomMessageEventContent { let mut lines = room_message.lines(); let command_line = lines.next().expect("each string has at least one line"); let body: Vec<_> = lines.collect(); @@ -149,7 +136,7 @@ async fn process_admin_message(db: &Database, room_message: String) -> RoomMessa let admin_command = match parse_admin_command(&command_line) { Ok(command) => command, Err(error) => { - let server_name = db.globals.server_name(); + let server_name = services().globals.server_name(); let message = error .to_string() .replace("server.name", server_name.as_str()); @@ -159,7 +146,7 @@ async fn process_admin_message(db: &Database, room_message: String) -> RoomMessa } }; - match process_admin_command(db, admin_command, body).await { + match process_admin_command(admin_command, body).await { Ok(reply_message) => reply_message, Err(error) => { let markdown_message = format!( @@ -322,7 +309,6 @@ enum AdminCommand { } async fn process_admin_command( - db: &Database, command: AdminCommand, body: Vec<&str>, ) -> Result { @@ -332,7 +318,7 @@ async fn process_admin_command( let appservice_config = body[1..body.len() - 1].join("\n"); let parsed_config = serde_yaml::from_str::(&appservice_config); match parsed_config { - Ok(yaml) => match db.appservice.register_appservice(yaml) { + Ok(yaml) => match services().appservice.register_appservice(yaml) { Ok(id) => RoomMessageEventContent::text_plain(format!( "Appservice registered with ID: {}.", id @@ -355,7 +341,7 @@ async fn process_admin_command( } AdminCommand::UnregisterAppservice { appservice_identifier, - } => match db.appservice.unregister_appservice(&appservice_identifier) { + } => match services().appservice.unregister_appservice(&appservice_identifier) { Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), Err(e) => RoomMessageEventContent::text_plain(format!( "Failed to unregister appservice: {}", @@ -363,7 +349,7 @@ async fn process_admin_command( )), }, AdminCommand::ListAppservices => { - if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::>()) { + if let Ok(appservices) = services().appservice.iter_ids().map(|ids| ids.collect::>()) { let count = appservices.len(); let output = format!( "Appservices ({}): {}", @@ -380,14 +366,14 @@ async fn process_admin_command( } } AdminCommand::ListRooms => { - let room_ids = db.rooms.iter_ids(); + let room_ids = services().rooms.iter_ids(); let output = format!( "Rooms:\n{}", room_ids .filter_map(|r| r.ok()) .map(|id| id.to_string() + "\tMembers: " - + &db + + &services() .rooms .room_joined_count(&id) .ok() @@ -399,7 +385,7 @@ async fn process_admin_command( ); RoomMessageEventContent::text_plain(output) } - AdminCommand::ListLocalUsers => match db.users.list_local_users() { + AdminCommand::ListLocalUsers => match services().users.list_local_users() { Ok(users) => { let mut msg: String = format!("Found {} local user account(s):\n", users.len()); msg += &users.join("\n"); @@ -408,7 +394,7 @@ async fn process_admin_command( Err(e) => RoomMessageEventContent::text_plain(e.to_string()), }, AdminCommand::IncomingFederation => { - let map = db.globals.roomid_federationhandletime.read().unwrap(); + let map = services().globals.roomid_federationhandletime.read().unwrap(); let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); for (r, (e, i)) in map.iter() { @@ -425,7 +411,7 @@ async fn process_admin_command( } AdminCommand::GetAuthChain { event_id } => { let event_id = Arc::::from(event_id); - if let Some(event) = db.rooms.get_pdu_json(&event_id)? { + if let Some(event) = services().rooms.get_pdu_json(&event_id)? { let room_id_str = event .get("room_id") .and_then(|val| val.as_str()) @@ -435,7 +421,7 @@ async fn process_admin_command( Error::bad_database("Invalid room id field in event in database") })?; let start = Instant::now(); - let count = server_server::get_auth_chain(room_id, vec![event_id], db) + let count = server_server::get_auth_chain(room_id, vec![event_id]) .await? .count(); let elapsed = start.elapsed(); @@ -486,10 +472,10 @@ async fn process_admin_command( } AdminCommand::GetPdu { event_id } => { let mut outlier = false; - let mut pdu_json = db.rooms.get_non_outlier_pdu_json(&event_id)?; + let mut pdu_json = services().rooms.get_non_outlier_pdu_json(&event_id)?; if pdu_json.is_none() { outlier = true; - pdu_json = db.rooms.get_pdu_json(&event_id)?; + pdu_json = services().rooms.get_pdu_json(&event_id)?; } match pdu_json { Some(json) => { @@ -519,7 +505,7 @@ async fn process_admin_command( None => RoomMessageEventContent::text_plain("PDU not found."), } } - AdminCommand::DatabaseMemoryUsage => match db._db.memory_usage() { + AdminCommand::DatabaseMemoryUsage => match services()._db.memory_usage() { Ok(response) => RoomMessageEventContent::text_plain(response), Err(e) => RoomMessageEventContent::text_plain(format!( "Failed to get database memory usage: {}", @@ -528,12 +514,12 @@ async fn process_admin_command( }, AdminCommand::ShowConfig => { // Construct and send the response - RoomMessageEventContent::text_plain(format!("{}", db.globals.config)) + RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) } AdminCommand::ResetPassword { username } => { let user_id = match UserId::parse_with_server_name( username.as_str().to_lowercase(), - db.globals.server_name(), + services().globals.server_name(), ) { Ok(id) => id, Err(e) => { @@ -545,10 +531,10 @@ async fn process_admin_command( }; // Check if the specified user is valid - if !db.users.exists(&user_id)? - || db.users.is_deactivated(&user_id)? + if !services().users.exists(&user_id)? + || services().users.is_deactivated(&user_id)? || user_id - == UserId::parse_with_server_name("conduit", db.globals.server_name()) + == UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("conduit user exists") { return Ok(RoomMessageEventContent::text_plain( @@ -558,7 +544,7 @@ async fn process_admin_command( let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); - match db.users.set_password(&user_id, Some(new_password.as_str())) { + match services().users.set_password(&user_id, Some(new_password.as_str())) { Ok(()) => RoomMessageEventContent::text_plain(format!( "Successfully reset the password for user {}: {}", user_id, new_password @@ -574,7 +560,7 @@ async fn process_admin_command( // Validate user id let user_id = match UserId::parse_with_server_name( username.as_str().to_lowercase(), - db.globals.server_name(), + services().globals.server_name(), ) { Ok(id) => id, Err(e) => { @@ -589,21 +575,21 @@ async fn process_admin_command( "userid {user_id} is not allowed due to historical" ))); } - if db.users.exists(&user_id)? { + if services().users.exists(&user_id)? { return Ok(RoomMessageEventContent::text_plain(format!( "userid {user_id} already exists" ))); } // Create user - db.users.create(&user_id, Some(password.as_str()))?; + services().users.create(&user_id, Some(password.as_str()))?; // Default to pretty displayname let displayname = format!("{} ⚡️", user_id.localpart()); - db.users + services().users .set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data - db.account_data.update( + services().account_data.update( None, &user_id, ruma::events::GlobalAccountDataEventType::PushRules @@ -614,24 +600,21 @@ async fn process_admin_command( global: ruma::push::Ruleset::server_default(&user_id), }, }, - &db.globals, )?; // we dont add a device since we're not the user, just the creator - db.flush()?; - // Inhibit login does not work for guests RoomMessageEventContent::text_plain(format!( "Created user with user_id: {user_id} and password: {password}" )) } AdminCommand::DisableRoom { room_id } => { - db.rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; + services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; RoomMessageEventContent::text_plain("Room disabled.") } AdminCommand::EnableRoom { room_id } => { - db.rooms.disabledroomids.remove(room_id.as_bytes())?; + services().rooms.disabledroomids.remove(room_id.as_bytes())?; RoomMessageEventContent::text_plain("Room enabled.") } AdminCommand::DeactivateUser { @@ -639,16 +622,16 @@ async fn process_admin_command( user_id, } => { let user_id = Arc::::from(user_id); - if db.users.exists(&user_id)? { + if services().users.exists(&user_id)? { RoomMessageEventContent::text_plain(format!( "Making {} leave all rooms before deactivation...", user_id )); - db.users.deactivate_account(&user_id)?; + services().users.deactivate_account(&user_id)?; if leave_rooms { - db.rooms.leave_all_rooms(&user_id, &db).await?; + services().rooms.leave_all_rooms(&user_id).await?; } RoomMessageEventContent::text_plain(format!( @@ -685,7 +668,7 @@ async fn process_admin_command( if !force { user_ids.retain(|&user_id| { - match db.users.is_admin(user_id, &db.rooms, &db.globals) { + match services().users.is_admin(user_id) { Ok(is_admin) => match is_admin { true => { admins.push(user_id.localpart()); @@ -699,7 +682,7 @@ async fn process_admin_command( } for &user_id in &user_ids { - match db.users.deactivate_account(user_id) { + match services().users.deactivate_account(user_id) { Ok(_) => deactivation_count += 1, Err(_) => {} } @@ -707,7 +690,7 @@ async fn process_admin_command( if leave_rooms { for &user_id in &user_ids { - let _ = db.rooms.leave_all_rooms(user_id, &db).await; + let _ = services().rooms.leave_all_rooms(user_id).await; } } @@ -814,13 +797,13 @@ fn usage_to_html(text: &str, server_name: &ServerName) -> String { /// /// Users in this room are considered admins by conduit, and the room can be /// used to issue admin commands by talking to the server user inside it. -pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { - let room_id = RoomId::new(db.globals.server_name()); +pub(crate) async fn create_admin_room() -> Result<()> { + let room_id = RoomId::new(services().globals.server_name()); - db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + services().rooms.get_or_create_shortroomid(&room_id)?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -830,10 +813,10 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { let state_lock = mutex_state.lock().await; // Create a user for the server - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("@conduit:server_name is valid"); - db.users.create(&conduit_user, None)?; + services().users.create(&conduit_user, None)?; let mut content = RoomCreateEventContent::new(conduit_user.clone()); content.federate = true; @@ -841,7 +824,7 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { content.room_version = RoomVersionId::V6; // 1. The room create event - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -851,12 +834,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 2. Make conduit bot join - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -876,7 +858,6 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; @@ -884,7 +865,7 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { let mut users = BTreeMap::new(); users.insert(conduit_user.clone(), 100.into()); - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&RoomPowerLevelsEventContent { @@ -898,12 +879,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 4.1 Join Rules - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomJoinRules, content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) @@ -914,12 +894,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 4.2 History Visibility - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomHistoryVisibility, content: to_raw_value(&RoomHistoryVisibilityEventContent::new( @@ -932,12 +911,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 4.3 Guest Access - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomGuestAccess, content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) @@ -948,14 +926,13 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 5. Events implied by name and topic - let room_name = RoomName::parse(format!("{} Admin Room", db.globals.server_name())) + let room_name = RoomName::parse(format!("{} Admin Room", services().globals.server_name())) .expect("Room name is valid"); - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomName, content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) @@ -966,15 +943,14 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", db.globals.server_name()), + topic: format!("Manage {}", services().globals.server_name()), }) .expect("event is valid, we just created it"), unsigned: None, @@ -983,16 +959,15 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; // 6. Room alias - let alias: Box = format!("#admins:{}", db.globals.server_name()) + let alias: Box = format!("#admins:{}", services().globals.server_name()) .try_into() .expect("#admins:server_name is a valid alias name"); - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCanonicalAlias, content: to_raw_value(&RoomCanonicalAliasEventContent { @@ -1006,11 +981,10 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { }, &conduit_user, &room_id, - &db, &state_lock, )?; - db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; + services().rooms.set_alias(&alias, Some(&room_id))?; Ok(()) } @@ -1019,20 +993,19 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { /// /// In conduit, this is equivalent to granting admin privileges. pub(crate) async fn make_user_admin( - db: &Database, user_id: &UserId, displayname: String, ) -> Result<()> { - let admin_room_alias: Box = format!("#admins:{}", db.globals.server_name()) + let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) .try_into() .expect("#admins:server_name is a valid alias name"); - let room_id = db + let room_id = services() .rooms .id_from_alias(&admin_room_alias)? .expect("Admin room must exist"); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -1042,11 +1015,11 @@ pub(crate) async fn make_user_admin( let state_lock = mutex_state.lock().await; // Use the server user to grant the new admin's power level - let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) .expect("@conduit:server_name is valid"); // Invite and join the real user - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -1066,10 +1039,9 @@ pub(crate) async fn make_user_admin( }, &conduit_user, &room_id, - &db, &state_lock, )?; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { @@ -1089,7 +1061,6 @@ pub(crate) async fn make_user_admin( }, &user_id, &room_id, - &db, &state_lock, )?; @@ -1098,7 +1069,7 @@ pub(crate) async fn make_user_admin( users.insert(conduit_user.to_owned(), 100.into()); users.insert(user_id.to_owned(), 100.into()); - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&RoomPowerLevelsEventContent { @@ -1112,17 +1083,16 @@ pub(crate) async fn make_user_admin( }, &conduit_user, &room_id, - &db, &state_lock, )?; // Send welcome message - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( - format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", db.globals.server_name()).to_owned(), - format!("

Thank you for trying out Conduit!

\n

Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Website: https://conduit.rs
Git and Documentation: https://gitlab.com/famedly/conduit
Report issues: https://gitlab.com/famedly/conduit/-/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

Conduit room (Ask questions and get notified on updates):
/join #conduit:fachschaften.org

\n

Conduit lounge (Off-topic, only Conduit users are allowed to join)
/join #conduit-lounge:conduit.rs

\n", db.globals.server_name()).to_owned(), + format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", services().globals.server_name()).to_owned(), + format!("

Thank you for trying out Conduit!

\n

Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.

\n

Helpful links:

\n
\n

Website: https://conduit.rs
Git and Documentation: https://gitlab.com/famedly/conduit
Report issues: https://gitlab.com/famedly/conduit/-/issues

\n
\n

For a list of available commands, send the following message in this room: @conduit:{}: --help

\n

Here are some rooms you can join (by typing the command):

\n

Conduit room (Ask questions and get notified on updates):
/join #conduit:fachschaften.org

\n

Conduit lounge (Off-topic, only Conduit users are allowed to join)
/join #conduit-lounge:conduit.rs

\n", services().globals.server_name()).to_owned(), )) .expect("event is valid, we just created it"), unsigned: None, @@ -1131,7 +1101,6 @@ pub(crate) async fn make_user_admin( }, &conduit_user, &room_id, - &db, &state_lock, )?; diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index fe57451f..eed84d59 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,17 +1,18 @@ pub trait Data { + type Iter: Iterator; /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; + fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; /// Remove an appservice registration /// /// # Arguments /// /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()>; + fn unregister_appservice(&self, service_name: &str) -> Result<()>; - pub fn get_registration(&self, id: &str) -> Result>; + fn get_registration(&self, id: &str) -> Result>; - pub fn iter_ids(&self) -> Result> + '_>; + fn iter_ids(&self) -> Result>>; - pub fn all(&self) -> Result>; + fn all(&self) -> Result>; } diff --git a/src/service/key_backups.rs b/src/service/key_backups.rs index 10443f6b..be1d6b18 100644 --- a/src/service/key_backups.rs +++ b/src/service/key_backups.rs @@ -1,4 +1,4 @@ -use crate::{utils, Error, Result}; +use crate::{utils, Error, Result, services}; use ruma::{ api::client::{ backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, @@ -9,22 +9,13 @@ use ruma::{ }; use std::{collections::BTreeMap, sync::Arc}; -use super::abstraction::Tree; - -pub struct KeyBackups { - pub(super) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) - pub(super) backupid_etag: Arc, // BackupId = UserId + Version(Count) - pub(super) backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId -} - impl KeyBackups { pub fn create_backup( &self, user_id: &UserId, backup_metadata: &Raw, - globals: &super::globals::Globals, ) -> Result { - let version = globals.next_count()?.to_string(); + let version = services().globals.next_count()?.to_string(); let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -35,7 +26,7 @@ impl KeyBackups { &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), )?; self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version) } @@ -61,7 +52,6 @@ impl KeyBackups { user_id: &UserId, version: &str, backup_metadata: &Raw, - globals: &super::globals::Globals, ) -> Result { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -77,7 +67,7 @@ impl KeyBackups { self.backupid_algorithm .insert(&key, backup_metadata.json().get().as_bytes())?; self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; Ok(version.to_owned()) } @@ -157,7 +147,6 @@ impl KeyBackups { room_id: &RoomId, session_id: &str, key_data: &Raw, - globals: &super::globals::Globals, ) -> Result<()> { let mut key = user_id.as_bytes().to_vec(); key.push(0xff); @@ -171,7 +160,7 @@ impl KeyBackups { } self.backupid_etag - .insert(&key, &globals.next_count()?.to_be_bytes())?; + .insert(&key, &services().globals.next_count()?.to_be_bytes())?; key.push(0xff); key.extend_from_slice(room_id.as_bytes()); diff --git a/src/service/media.rs b/src/service/media.rs index a4bb4025..1bdf6d47 100644 --- a/src/service/media.rs +++ b/src/service/media.rs @@ -1,4 +1,3 @@ -use crate::database::globals::Globals; use image::{imageops::FilterType, GenericImageView}; use super::abstraction::Tree; diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 00000000..80239cbf --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,28 @@ +pub mod pdu; +pub mod appservice; +pub mod pusher; +pub mod rooms; +pub mod transaction_ids; +pub mod uiaa; +pub mod users; +pub mod account_data; +pub mod admin; +pub mod globals; +pub mod key_backups; +pub mod media; +pub mod sending; + +pub struct Services { + pub appservice: appservice::Service, + pub pusher: pusher::Service, + pub rooms: rooms::Service, + pub transaction_ids: transaction_ids::Service, + pub uiaa: uiaa::Service, + pub users: users::Service, + //pub account_data: account_data::Service, + //pub admin: admin::Service, + pub globals: globals::Service, + //pub key_backups: key_backups::Service, + //pub media: media::Service, + //pub sending: sending::Service, +} diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 20ec01ea..47e21a60 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -1,4 +1,4 @@ -use crate::{Database, Error}; +use crate::{Database, Error, services}; use ruma::{ events::{ room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent, @@ -332,7 +332,6 @@ impl Ord for PduEvent { /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. pub(crate) fn gen_event_id_canonical_json( pdu: &RawJsonValue, - db: &Database, ) -> crate::Result<(Box, CanonicalJsonObject)> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { warn!("Error parsing incoming event {:?}: {:?}", pdu, e); @@ -344,7 +343,7 @@ pub(crate) fn gen_event_id_canonical_json( .and_then(|id| RoomId::parse(id.as_str()?).ok()) .ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?; - let room_version_id = db.rooms.get_room_version(&room_id); + let room_version_id = services().rooms.get_room_version(&room_id); let event_id = format!( "${}", diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 468ad8b4..ef2b8193 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,11 +1,13 @@ +use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; + pub trait Data { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; - pub fn get_pusher(&self, senderkey: &[u8]) -> Result>; + fn get_pusher(&self, senderkey: &[u8]) -> Result>; - pub fn get_pushers(&self, sender: &UserId) -> Result>; + fn get_pushers(&self, sender: &UserId) -> Result>; - pub fn get_pusher_senderkeys<'a>( + fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, ) -> impl Iterator> + 'a; diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 342763e8..87e91a14 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,7 +1,27 @@ mod data; pub use data::Data; -use crate::service::*; +use crate::{services, Error, PduEvent}; +use bytes::BytesMut; +use ruma::{ + api::{ + client::push::{get_pushers, set_pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent}, + AnySyncRoomEvent, RoomEventType, StateEventType, + }, + push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; +use std::{fmt::Debug, mem}; +use tracing::{error, info, warn}; pub struct Service { db: D, @@ -27,9 +47,8 @@ impl Service<_> { self.db.get_pusher_senderkeys(sender) } - #[tracing::instrument(skip(globals, destination, request))] + #[tracing::instrument(skip(destination, request))] pub async fn send_request( - globals: &crate::database::globals::Globals, destination: &str, request: T, ) -> Result @@ -57,7 +76,7 @@ impl Service<_> { //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); let url = reqwest_request.url().clone(); - let response = globals.default_client().execute(reqwest_request).await; + let response = services().globals.default_client().execute(reqwest_request).await; match response { Ok(mut response) => { @@ -105,19 +124,19 @@ impl Service<_> { } } - #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] + #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu))] pub async fn send_push_notice( + &self, user: &UserId, unread: UInt, pusher: &get_pushers::v3::Pusher, ruleset: Ruleset, pdu: &PduEvent, - db: &Database, ) -> Result<()> { let mut notify = None; let mut tweaks = Vec::new(); - let power_levels: RoomPowerLevelsEventContent = db + let power_levels: RoomPowerLevelsEventContent = services() .rooms .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .map(|ev| { @@ -127,13 +146,12 @@ impl Service<_> { .transpose()? .unwrap_or_default(); - for action in get_actions( + for action in self.get_actions( user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id, - db, )? { let n = match action { Action::DontNotify => false, @@ -155,27 +173,26 @@ impl Service<_> { } if notify == Some(true) { - send_notice(unread, pusher, tweaks, pdu, db).await?; + self.send_notice(unread, pusher, tweaks, pdu).await?; } // Else the event triggered no actions Ok(()) } - #[tracing::instrument(skip(user, ruleset, pdu, db))] + #[tracing::instrument(skip(user, ruleset, pdu))] pub fn get_actions<'a>( + &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, pdu: &Raw, room_id: &RoomId, - db: &Database, ) -> Result<&'a [Action]> { let ctx = PushConditionRoomCtx { room_id: room_id.to_owned(), member_count: 10_u32.into(), // TODO: get member count efficiently - user_display_name: db - .users + user_display_name: services().users .displayname(user)? .unwrap_or_else(|| user.localpart().to_owned()), users_power_levels: power_levels.users.clone(), @@ -186,13 +203,13 @@ impl Service<_> { Ok(ruleset.get_actions(pdu, &ctx)) } - #[tracing::instrument(skip(unread, pusher, tweaks, event, db))] + #[tracing::instrument(skip(unread, pusher, tweaks, event))] async fn send_notice( + &self, unread: UInt, pusher: &get_pushers::v3::Pusher, tweaks: Vec, event: &PduEvent, - db: &Database, ) -> Result<()> { // TODO: email if pusher.kind == PusherKind::Email { @@ -240,12 +257,8 @@ impl Service<_> { } if event_id_only { - send_request( - &db.globals, - url, - send_event_notification::v1::Request::new(notifi), - ) - .await?; + self.send_request(url, send_event_notification::v1::Request::new(notifi)) + .await?; } else { notifi.sender = Some(&event.sender); notifi.event_type = Some(&event.kind); @@ -256,11 +269,11 @@ impl Service<_> { notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); } - let user_name = db.users.displayname(&event.sender)?; + let user_name = services().users.displayname(&event.sender)?; notifi.sender_display_name = user_name.as_deref(); let room_name = if let Some(room_name_pdu) = - db.rooms + services().rooms .room_state_get(&event.room_id, &StateEventType::RoomName, "")? { serde_json::from_str::(room_name_pdu.content.get()) @@ -272,8 +285,7 @@ impl Service<_> { notifi.room_name = room_name.as_deref(); - send_request( - &db.globals, + self.send_request( url, send_event_notification::v1::Request::new(notifi), ) diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index 9dbfc7b5..655f32aa 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,22 +1,24 @@ +use ruma::{RoomId, RoomAliasId}; + pub trait Data { /// Creates or updates the alias to the given room id. - pub fn set_alias( + fn set_alias( alias: &RoomAliasId, room_id: &RoomId ) -> Result<()>; /// Forgets about an alias. Returns an error if the alias did not exist. - pub fn remove_alias( + fn remove_alias( alias: &RoomAliasId, ) -> Result<()>; /// Looks up the roomid for the given alias. - pub fn resolve_local_alias( + fn resolve_local_alias( alias: &RoomAliasId, ) -> Result<()>; /// Returns all local aliases that point to the given room - pub fn local_aliases_for_room( + fn local_aliases_for_room( alias: &RoomAliasId, ) -> Result<()>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index cfe05396..f46609aa 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,14 +1,13 @@ mod data; pub use data::Data; - -use crate::service::*; +use ruma::{RoomAliasId, RoomId}; pub struct Service { db: D, } impl Service<_> { - #[tracing::instrument(skip(self, globals))] + #[tracing::instrument(skip(self))] pub fn set_alias( &self, alias: &RoomAliasId, @@ -17,7 +16,7 @@ impl Service<_> { self.db.set_alias(alias, room_id) } - #[tracing::instrument(skip(self, globals))] + #[tracing::instrument(skip(self))] pub fn remove_alias( &self, alias: &RoomAliasId, diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index d8fde958..88c86fad 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + pub trait Data { fn get_cached_eventid_authchain<'a>() -> Result>; fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet) -> Result>; diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index dfc289f3..e17c10a1 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,4 +1,6 @@ mod data; +use std::{sync::Arc, collections::HashSet}; + pub use data::Data; use crate::service::*; diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index 83d78853..e28cdd12 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,3 +1,5 @@ +use ruma::RoomId; + pub trait Data { /// Adds the room to the public room directory fn set_public(room_id: &RoomId) -> Result<()>; diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index b92933f4..cb9cda86 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -1,5 +1,6 @@ mod data; pub use data::Data; +use ruma::RoomId; use crate::service::*; @@ -10,21 +11,21 @@ pub struct Service { impl Service<_> { #[tracing::instrument(skip(self))] pub fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.db.set_public(&self, room_id) + self.db.set_public(room_id) } #[tracing::instrument(skip(self))] pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.db.set_not_public(&self, room_id) + self.db.set_not_public(room_id) } #[tracing::instrument(skip(self))] pub fn is_public_room(&self, room_id: &RoomId) -> Result { - self.db.is_public_room(&self, room_id) + self.db.is_public_room(room_id) } #[tracing::instrument(skip(self))] pub fn public_rooms(&self) -> impl Iterator>> + '_ { - self.db.public_rooms(&self, room_id) + self.db.public_rooms() } } diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index d8ce5300..5566fb2c 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -1,3 +1,9 @@ pub mod presence; pub mod read_receipt; pub mod typing; + +pub struct Service { + presence: presence::Service, + read_receipt: read_receipt::Service, + typing: typing::Service, +} diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index de72e219..8e3c672f 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; + +use ruma::{UserId, RoomId, events::presence::PresenceEvent}; + pub trait Data { /// Adds a presence event which will be saved until a new event replaces it. /// diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index 5793a799..5a988d4f 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -1,5 +1,8 @@ mod data; +use std::collections::HashMap; + pub use data::Data; +use ruma::{RoomId, UserId, events::presence::PresenceEvent}; use crate::service::*; @@ -108,7 +111,7 @@ impl Service<_> { }*/ /// Returns the most recent presence updates that happened after the event with id `since`. - #[tracing::instrument(skip(self, since, _rooms, _globals))] + #[tracing::instrument(skip(self, since, room_id))] pub fn presence_since( &self, room_id: &RoomId, diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 4befcf2c..32b091f2 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,3 +1,5 @@ +use ruma::{RoomId, events::receipt::ReceiptEvent, UserId, serde::Raw}; + pub trait Data { /// Replaces the previous read receipt. fn readreceipt_update( diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index 9cd474fb..744fece1 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -1,7 +1,6 @@ mod data; pub use data::Data; - -use crate::service::*; +use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw}; pub struct Service { db: D, @@ -15,7 +14,7 @@ impl Service<_> { room_id: &RoomId, event: ReceiptEvent, ) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event); + self.db.readreceipt_update(user_id, room_id, event) } /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. @@ -35,7 +34,7 @@ impl Service<_> { } /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self, globals))] + #[tracing::instrument(skip(self))] pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { self.db.private_read_set(room_id, user_id, count) } diff --git a/src/service/rooms/edus/typing/data.rs b/src/service/rooms/edus/typing/data.rs index 83ff90ea..0c773135 100644 --- a/src/service/rooms/edus/typing/data.rs +++ b/src/service/rooms/edus/typing/data.rs @@ -1,3 +1,7 @@ +use std::collections::HashSet; + +use ruma::{UserId, RoomId}; + pub trait Data { /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is /// called. diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index b29c7888..68b9fd83 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -1,5 +1,6 @@ mod data; pub use data::Data; +use ruma::{UserId, RoomId}; use crate::service::*; @@ -66,7 +67,6 @@ impl Service<_> { */ /// Returns the count of the last typing update in this room. - #[tracing::instrument(skip(self, globals))] pub fn last_typing_update(&self, room_id: &RoomId) -> Result { self.db.last_typing_update(room_id) } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 5b77586a..71529570 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,8 +1,29 @@ - /// An async function that can recursively call itself. type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -use crate::service::*; +use std::{ + collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet}, + pin::Pin, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; + +use futures_util::Future; +use ruma::{ + api::{ + client::error::ErrorKind, + federation::event::{get_event, get_room_state_ids}, + }, + events::{room::create::RoomCreateEventContent, StateEventType}, + int, + serde::Base64, + signatures::CanonicalJsonValue, + state_res::{self, RoomVersion, StateMap}, + uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, +}; +use tracing::{error, info, trace, warn}; + +use crate::{service::*, services, Error, PduEvent}; pub struct Service; @@ -31,45 +52,47 @@ impl Service { /// it /// 14. Use state resolution to find new room state // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively - #[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] + #[tracing::instrument(skip(value, is_timeline_event, pub_key_map))] pub(crate) async fn handle_incoming_pdu<'a>( + &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, value: BTreeMap, is_timeline_event: bool, - db: &'a Database, pub_key_map: &'a RwLock>>, ) -> Result>> { - db.rooms.exists(room_id)?.ok_or(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"))?; + services().rooms.exists(room_id)?.ok_or(Error::BadRequest( + ErrorKind::NotFound, + "Room is unknown to this server", + ))?; + + services() + .rooms + .is_disabled(room_id)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of this room is currently disabled on this server.", + ))?; - db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of this room is currently disabled on this server."))?; - // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = db.rooms.get_pdu_id(event_id)? { - return Some(pdu_id.to_vec()); + if let Some(pdu_id) = services().rooms.get_pdu_id(event_id)? { + return Ok(Some(pdu_id.to_vec())); } - let create_event = db + let create_event = services() .rooms .room_state_get(room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; - let first_pdu_in_room = db + let first_pdu_in_room = services() .rooms .first_pdu_in_room(room_id)? .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - let (incoming_pdu, val) = handle_outlier_pdu( - origin, - &create_event, - event_id, - room_id, - value, - db, - pub_key_map, - ) - .await?; + let (incoming_pdu, val) = self + .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, pub_key_map) + .await?; // 8. if not timeline event: stop if !is_timeline_event { @@ -82,15 +105,27 @@ impl Service { } // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let sorted_prev_events = fetch_unknown_prev_events(incoming_pdu.prev_events.clone()); + let (sorted_prev_events, eventid_info) = self.fetch_unknown_prev_events( + origin, + &create_event, + room_id, + pub_key_map, + incoming_pdu.prev_events.clone(), + ); let mut errors = 0; - for prev_id in dbg!(sorted) { + for prev_id in dbg!(sorted_prev_events) { // Check for disabled again because it might have changed - db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of - this room is currently disabled on this server."))?; + services() + .rooms + .is_disabled(room_id)? + .ok_or(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of + this room is currently disabled on this server.", + ))?; - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() @@ -120,26 +155,27 @@ impl Service { } let start_time = Instant::now(); - db.globals + services() + .globals .roomid_federationhandletime .write() .unwrap() .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - if let Err(e) = upgrade_outlier_to_timeline_pdu( - pdu, - json, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await + if let Err(e) = self + .upgrade_outlier_to_timeline_pdu( + pdu, + json, + &create_event, + origin, + room_id, + pub_key_map, + ) + .await { errors += 1; warn!("Prev event {} failed: {}", prev_id, e); - match db + match services() .globals .bad_event_ratelimiter .write() @@ -155,7 +191,8 @@ impl Service { } } let elapsed = start_time.elapsed(); - db.globals + services() + .globals .roomid_federationhandletime .write() .unwrap() @@ -172,22 +209,23 @@ impl Service { // Done with prev events, now handling the incoming event let start_time = Instant::now(); - db.globals + services() + .globals .roomid_federationhandletime .write() .unwrap() .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - let r = upgrade_outlier_to_timeline_pdu( + let r = services().rooms.event_handler.upgrade_outlier_to_timeline_pdu( incoming_pdu, val, &create_event, origin, - db, room_id, pub_key_map, ) .await; - db.globals + services() + .globals .roomid_federationhandletime .write() .unwrap() @@ -196,22 +234,23 @@ impl Service { r } - #[tracing::instrument(skip(create_event, value, db, pub_key_map))] + #[tracing::instrument(skip(create_event, value, pub_key_map))] fn handle_outlier_pdu<'a>( + &self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, value: BTreeMap, - db: &'a Database, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap), String>> { + ) -> AsyncRecursiveType<'a, Result<(Arc, BTreeMap), String>> + { Box::pin(async move { // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json // We go through all the signatures we see on the value and fetch the corresponding signing // keys - fetch_required_signing_keys(&value, pub_key_map, db) + self.fetch_required_signing_keys(&value, pub_key_map, db) .await?; // 2. Check signatures, otherwise drop @@ -223,7 +262,8 @@ impl Service { })?; let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + let room_version = + RoomVersion::new(room_version_id).expect("room version is supported"); let mut val = match ruma::signatures::verify_event( &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, @@ -261,8 +301,7 @@ impl Service { // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // NOTE: Step 5 is not applied anymore because it failed too often warn!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_outliers( - db, + self.fetch_and_handle_outliers( origin, &incoming_pdu .auth_events @@ -284,7 +323,7 @@ impl Service { // Build map of auth events let mut auth_events = HashMap::new(); for id in &incoming_pdu.auth_events { - let auth_event = match db.rooms.get_pdu(id)? { + let auth_event = match services().rooms.get_pdu(id)? { Some(e) => e, None => { warn!("Could not find auth event {}", id); @@ -303,8 +342,9 @@ impl Service { v.insert(auth_event); } hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest(ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times." + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", )); } } @@ -316,7 +356,10 @@ impl Service { .map(|a| a.as_ref()) != Some(create_event) { - return Err(Error::BadRequest(ErrorKind::InvalidParam("Incoming event refers to wrong create event."))); + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); } if !state_res::event_auth::auth_check( @@ -325,15 +368,21 @@ impl Service { None::, // TODO: third party invite |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), ) - .map_err(|e| {error!(e); Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")})? - { - return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); + .map_err(|e| { + error!(e); + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })? { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth check failed", + )); } info!("Validation successful."); // 7. Persist the event as an outlier. - db.rooms + services() + .rooms .add_pdu_outlier(&incoming_pdu.event_id, &val)?; info!("Added pdu as outlier."); @@ -342,22 +391,22 @@ impl Service { }) } - #[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] - async fn upgrade_outlier_to_timeline_pdu( + #[tracing::instrument(skip(incoming_pdu, val, create_event, pub_key_map))] + pub async fn upgrade_outlier_to_timeline_pdu( + &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, origin: &ServerName, - db: &Database, room_id: &RoomId, pub_key_map: &RwLock>>, ) -> Result>, String> { // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { + if let Ok(Some(pduid)) = services().rooms.get_pdu_id(&incoming_pdu.event_id) { return Ok(Some(pduid)); } - if db + if services() .rooms .is_event_soft_failed(&incoming_pdu.event_id) .map_err(|_| "Failed to ask db for soft fail".to_owned())? @@ -387,32 +436,32 @@ impl Service { if incoming_pdu.prev_events.len() == 1 { let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = db + let prev_event_sstatehash = services() .rooms .pdu_shortstatehash(prev_event) .map_err(|_| "Failed talking to db".to_owned())?; let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some(db.rooms.state_full_ids(shortstatehash).await) + Some(services().rooms.state_full_ids(shortstatehash).await) } else { None }; if let Some(Ok(mut state)) = state { info!("Using cached state"); - let prev_pdu = - db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { + let prev_pdu = services() + .rooms + .get_pdu(prev_event) + .ok() + .flatten() + .ok_or_else(|| { "Could not find prev event, but we know the state.".to_owned() })?; if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = db + let shortstatekey = services() .rooms - .get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key) .map_err(|_| "Failed to create shortstatekey.".to_owned())?; state.insert(shortstatekey, Arc::from(prev_event)); @@ -427,19 +476,20 @@ impl Service { let mut okay = true; for prev_eventid in &incoming_pdu.prev_events { - let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { + let prev_event = if let Ok(Some(pdu)) = services().rooms.get_pdu(prev_eventid) { pdu } else { okay = false; break; }; - let sstatehash = if let Ok(Some(s)) = db.rooms.pdu_shortstatehash(prev_eventid) { - s - } else { - okay = false; - break; - }; + let sstatehash = + if let Ok(Some(s)) = services().rooms.pdu_shortstatehash(prev_eventid) { + s + } else { + okay = false; + break; + }; extremity_sstatehashes.insert(sstatehash, prev_event); } @@ -449,19 +499,18 @@ impl Service { let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: BTreeMap<_, _> = db + let mut leaf_state: BTreeMap<_, _> = services() .rooms .state_full_ids(sstatehash) .await .map_err(|_| "Failed to ask db for room state.".to_owned())?; if let Some(state_key) = &prev_event.state_key { - let shortstatekey = db + let shortstatekey = services() .rooms .get_or_create_shortstatekey( &prev_event.kind.to_string().into(), state_key, - &db.globals, ) .map_err(|_| "Failed to create shortstatekey.".to_owned())?; leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); @@ -472,7 +521,7 @@ impl Service { let mut starting_events = Vec::with_capacity(leaf_state.len()); for (k, id) in leaf_state { - if let Ok((ty, st_key)) = db.rooms.get_statekey_from_short(k) { + if let Ok((ty, st_key)) = services().rooms.get_statekey_from_short(k) { // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType state.insert((ty.to_string().into(), st_key), id.clone()); @@ -483,7 +532,10 @@ impl Service { } auth_chain_sets.push( - get_auth_chain(room_id, starting_events, db) + services() + .rooms + .auth_chain + .get_auth_chain(room_id, starting_events, services()) .await .map_err(|_| "Failed to load auth chain.".to_owned())? .collect(), @@ -492,15 +544,16 @@ impl Service { fork_states.push(state); } - let lock = db.globals.stateres_mutex.lock(); + let lock = services().globals.stateres_mutex.lock(); - let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }); + let result = + state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = services().rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }); drop(lock); state_at_incoming_event = match result { @@ -508,14 +561,15 @@ impl Service { new_state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = db + let shortstatekey = services() .rooms .get_or_create_shortstatekey( &event_type.to_string().into(), &state_key, - &db.globals, ) - .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; + .map_err(|_| { + "Failed to get_or_create_shortstatekey".to_owned() + })?; Ok((shortstatekey, event_id)) }) .collect::>()?, @@ -532,10 +586,9 @@ impl Service { info!("Calling /state_ids"); // Call /state_ids to find out what the state at this pdu is. We trust the server's // response to some extend, but we still do a lot of checks on the events - match db + match services() .sending .send_federation_request( - &db.globals, origin, get_room_state_ids::v1::Request { room_id, @@ -546,18 +599,18 @@ impl Service { { Ok(res) => { info!("Fetching state events at event."); - let state_vec = fetch_and_handle_outliers( - db, - origin, - &res.pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(), - create_event, - room_id, - pub_key_map, - ) - .await; + let state_vec = self + .fetch_and_handle_outliers( + origin, + &res.pdu_ids + .iter() + .map(|x| Arc::from(&**x)) + .collect::>(), + create_event, + room_id, + pub_key_map, + ) + .await; let mut state: BTreeMap<_, Arc> = BTreeMap::new(); for (pdu, _) in state_vec { @@ -566,13 +619,9 @@ impl Service { .clone() .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?; - let shortstatekey = db + let shortstatekey = services() .rooms - .get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - &db.globals, - ) + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key) .map_err(|_| "Failed to create shortstatekey.".to_owned())?; match state.entry(shortstatekey) { @@ -587,7 +636,7 @@ impl Service { } // The original create event must still be in the state - let create_shortstatekey = db + let create_shortstatekey = services() .rooms .get_shortstatekey(&StateEventType::RoomCreate, "") .map_err(|_| "Failed to talk to db.")? @@ -618,12 +667,13 @@ impl Service { &incoming_pdu, None::, // TODO: third party invite |k, s| { - db.rooms + services() + .rooms .get_shortstatekey(&k.to_string().into(), s) .ok() .flatten() .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) + .and_then(|event_id| services().rooms.get_pdu(event_id).ok().flatten()) }, ) .map_err(|_e| "Auth check failed.".to_owned())?; @@ -636,7 +686,8 @@ impl Service { // We start looking at current room state now, so lets lock the room let mutex_state = Arc::clone( - db.globals + services() + .globals .roomid_mutex_state .write() .unwrap() @@ -648,7 +699,7 @@ impl Service { // Now we calculate the set of extremities this room has after the incoming event has been // applied. We start with the previous extremities (aka leaves) info!("Calculating extremities"); - let mut extremities = db + let mut extremities = services() .rooms .get_pdu_leaves(room_id) .map_err(|_| "Failed to load room leaves".to_owned())?; @@ -661,14 +712,16 @@ impl Service { } // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); + extremities + .retain(|id| !matches!(services().rooms.is_event_referenced(room_id, id), Ok(true))); info!("Compressing state at event"); let state_ids_compressed = state_at_incoming_event .iter() .map(|(shortstatekey, id)| { - db.rooms - .compress_state_event(*shortstatekey, id, &db.globals) + services() + .rooms + .compress_state_event(*shortstatekey, id) .map_err(|_| "Failed to compress_state_event".to_owned()) }) .collect::>()?; @@ -676,7 +729,7 @@ impl Service { // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it info!("Starting soft fail auth check"); - let auth_events = db + let auth_events = services() .rooms .get_auth_events( room_id, @@ -696,11 +749,10 @@ impl Service { .map_err(|_e| "Auth check failed.".to_owned())?; if soft_fail { - append_incoming_pdu( - db, + self.append_incoming_pdu( &incoming_pdu, val, - extremities.iter().map(Deref::deref), + extremities.iter().map(std::ops::Deref::deref), state_ids_compressed, soft_fail, &state_lock, @@ -712,7 +764,8 @@ impl Service { // Soft fail, we keep the event as an outlier but don't add it to the timeline warn!("Event was soft failed: {:?}", incoming_pdu); - db.rooms + services() + .rooms .mark_event_soft_failed(&incoming_pdu.event_id) .map_err(|_| "Failed to set soft failed flag".to_owned())?; return Err("Event has been soft failed".into()); @@ -720,13 +773,13 @@ impl Service { if incoming_pdu.state_key.is_some() { info!("Loading current room state ids"); - let current_sstatehash = db + let current_sstatehash = services() .rooms .current_shortstatehash(room_id) .map_err(|_| "Failed to load current state hash.".to_owned())? .expect("every room has state"); - let current_state_ids = db + let current_state_ids = services() .rooms .state_full_ids(current_sstatehash) .await @@ -737,14 +790,14 @@ impl Service { info!("Loading extremities"); for id in dbg!(&extremities) { - match db + match services() .rooms .get_pdu(id) .map_err(|_| "Failed to ask db for pdu.".to_owned())? { Some(leaf_pdu) => { extremity_sstatehashes.insert( - db.rooms + services() .pdu_shortstatehash(&leaf_pdu.event_id) .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? .ok_or_else(|| { @@ -777,13 +830,9 @@ impl Service { // We also add state after incoming event to the fork states let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = db + let shortstatekey = services() .rooms - .get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - &db.globals, - ) + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key) .map_err(|_| "Failed to create shortstatekey.".to_owned())?; state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); @@ -801,8 +850,9 @@ impl Service { fork_states[0] .iter() .map(|(k, id)| { - db.rooms - .compress_state_event(*k, id, &db.globals) + services() + .rooms + .compress_state_event(*k, id) .map_err(|_| "Failed to compress_state_event.".to_owned()) }) .collect::>()? @@ -814,14 +864,16 @@ impl Service { let mut auth_chain_sets = Vec::new(); for state in &fork_states { auth_chain_sets.push( - get_auth_chain( - room_id, - state.iter().map(|(_, id)| id.clone()).collect(), - db, - ) - .await - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), + services() + .rooms + .auth_chain + .get_auth_chain( + room_id, + state.iter().map(|(_, id)| id.clone()).collect(), + ) + .await + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), ); } @@ -832,7 +884,8 @@ impl Service { .map(|map| { map.into_iter() .filter_map(|(k, id)| { - db.rooms + services() + .rooms .get_statekey_from_short(k) // FIXME: Undo .to_string().into() when StateMap // is updated to use StateEventType @@ -846,13 +899,13 @@ impl Service { info!("Resolving state"); - let lock = db.globals.stateres_mutex.lock(); + let lock = services().globals.stateres_mutex.lock(); let state = match state_res::resolve( room_version_id, &fork_states, auth_chain_sets, |id| { - let res = db.rooms.get_pdu(id); + let res = services().rooms.get_pdu(id); if let Err(e) = &res { error!("LOOK AT ME Failed to fetch event: {}", e); } @@ -872,16 +925,13 @@ impl Service { state .into_iter() .map(|((event_type, state_key), event_id)| { - let shortstatekey = db + let shortstatekey = services() .rooms - .get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - &db.globals, - ) + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key) .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; - db.rooms - .compress_state_event(shortstatekey, &event_id, &db.globals) + services() + .rooms + .compress_state_event(shortstatekey, &event_id) .map_err(|_| "Failed to compress state event".to_owned()) }) .collect::>()? @@ -890,8 +940,9 @@ impl Service { // Set the new room state to the resolved state if update_state { info!("Forcing new room state"); - db.rooms - .force_state(room_id, new_room_state, db) + services() + .rooms + .force_state(room_id, new_room_state) .map_err(|_| "Failed to set new room state.".to_owned())?; } } @@ -903,19 +954,19 @@ impl Service { // We use the `state_at_event` instead of `state_after` so we accurately // represent the state for this event. - let pdu_id = append_incoming_pdu( - db, - &incoming_pdu, - val, - extremities.iter().map(Deref::deref), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .map_err(|e| { - warn!("Failed to add pdu to db: {}", e); - "Failed to add pdu to db.".to_owned() - })?; + let pdu_id = self + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(std::ops::Deref::deref), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .map_err(|e| { + warn!("Failed to add pdu to db: {}", e); + "Failed to add pdu to db.".to_owned() + })?; info!("Appended incoming pdu"); @@ -935,15 +986,22 @@ impl Service { /// d. TODO: Ask other servers over federation? #[tracing::instrument(skip_all)] pub(crate) fn fetch_and_handle_outliers<'a>( - db: &'a Database, + &self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> { + ) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> + { Box::pin(async move { - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { + let back_off = |id| match services() + .globals + .bad_event_ratelimiter + .write() + .unwrap() + .entry(id) + { hash_map::Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } @@ -952,10 +1010,16 @@ impl Service { let mut pdus = vec![]; for id in events { - if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&**id) + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .unwrap() + .get(&**id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + let mut min_elapsed_duration = + Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } @@ -969,7 +1033,7 @@ impl Service { // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { + if let Ok(Some(local_pdu)) = services().rooms.get_pdu(id) { trace!("Found {} in db", id); pdus.push((local_pdu, None)); continue; @@ -992,16 +1056,15 @@ impl Service { tokio::task::yield_now().await; } - if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { + if let Ok(Some(_)) = services().rooms.get_pdu(&next_id) { trace!("Found {} in db", id); continue; } info!("Fetching {} over federation.", next_id); - match db + match services() .sending .send_federation_request( - &db.globals, origin, get_event::v1::Request { event_id: &next_id }, ) @@ -1010,7 +1073,7 @@ impl Service { Ok(res) => { info!("Got {} over federation", next_id); let (calculated_event_id, value) = - match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { + match pdu::gen_event_id_canonical_json(&res.pdu) { Ok(t) => t, Err(_) => { back_off((*next_id).to_owned()); @@ -1051,16 +1114,16 @@ impl Service { } for (next_id, value) in events_in_reverse_order.iter().rev() { - match handle_outlier_pdu( - origin, - create_event, - next_id, - room_id, - value.clone(), - db, - pub_key_map, - ) - .await + match self + .handle_outlier_pdu( + origin, + create_event, + next_id, + room_id, + value.clone(), + pub_key_map, + ) + .await { Ok((pdu, json)) => { if next_id == id { @@ -1078,9 +1141,14 @@ impl Service { }) } - - - fn fetch_unknown_prev_events(initial_set: Vec>) -> Vec> { + async fn fetch_unknown_prev_events( + &self, + origin: &ServerName, + create_event: &PduEvent, + room_id: &RoomId, + pub_key_map: &RwLock>>, + initial_set: Vec>, + ) -> Vec<(Arc, HashMap, (Arc, BTreeMap)>)> { let mut graph: HashMap, _> = HashMap::new(); let mut eventid_info = HashMap::new(); let mut todo_outlier_stack: Vec> = initial_set; @@ -1088,16 +1156,16 @@ impl Service { let mut amount = 0; while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = fetch_and_handle_outliers( - db, - origin, - &[prev_event_id.clone()], - &create_event, - room_id, - pub_key_map, - ) - .await - .pop() + if let Some((pdu, json_opt)) = self + .fetch_and_handle_outliers( + origin, + &[prev_event_id.clone()], + &create_event, + room_id, + pub_key_map, + ) + .await + .pop() { if amount > 100 { // Max limit reached @@ -1106,9 +1174,13 @@ impl Service { continue; } - if let Some(json) = - json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) - { + if let Some(json) = json_opt.or_else(|| { + services() + .rooms + .get_outlier_pdu_json(&prev_event_id) + .ok() + .flatten() + }) { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { amount += 1; for prev_prev in &pdu.prev_events { @@ -1153,6 +1225,6 @@ impl Service { }) .map_err(|_| "Error sorting prev events".to_owned())?; - sorted + (sorted, eventid_info) } } diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 9cf2d8bc..52a683d3 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,3 +1,5 @@ +use ruma::{RoomId, DeviceId, UserId}; + pub trait Data { fn lazy_load_was_sent_before( &self, diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs index cf00174b..bdc083a0 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,5 +1,8 @@ mod data; +use std::collections::HashSet; + pub use data::Data; +use ruma::{DeviceId, UserId, RoomId}; use crate::service::*; @@ -47,7 +50,7 @@ impl Service<_> { room_id: &RoomId, since: u64, ) -> Result<()> { - self.db.lazy_load_confirm_delivery(user_d, device_id, room_id, since) + self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, since) } #[tracing::instrument(skip(self))] @@ -57,6 +60,6 @@ impl Service<_> { device_id: &DeviceId, room_id: &RoomId, ) -> Result<()> { - self.db.lazy_load_reset(user_id, device_id, room_id); + self.db.lazy_load_reset(user_id, device_id, room_id) } } diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 58bd3510..2d718b2d 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,3 +1,5 @@ +use ruma::RoomId; + pub trait Data { fn exists(&self, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 644cd18f..8417e28e 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -1,5 +1,6 @@ mod data; pub use data::Data; +use ruma::RoomId; use crate::service::*; diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 89598afe..47250340 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -1,216 +1,37 @@ -mod edus; - -pub use edus::RoomEdus; - -use crate::{ - pdu::{EventHash, PduBuilder}, - utils, Database, Error, PduEvent, Result, -}; -use lru_cache::LruCache; -use regex::Regex; -use ring::digest; -use ruma::{ - api::{client::error::ErrorKind, federation}, - events::{ - direct::DirectEvent, - ignored_user_list::IgnoredUserListEvent, - push_rules::PushRulesEvent, - room::{ - create::RoomCreateEventContent, - member::{MembershipState, RoomMemberEventContent}, - power_levels::RoomPowerLevelsEventContent, - }, - tag::TagEvent, - AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, - RoomAccountDataEventType, RoomEventType, StateEventType, - }, - push::{Action, Ruleset, Tweak}, - serde::{CanonicalJsonObject, CanonicalJsonValue, Raw}, - state_res::{self, RoomVersion, StateMap}, - uint, DeviceId, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, -}; -use serde::Deserialize; -use serde_json::value::to_raw_value; -use std::{ - borrow::Cow, - collections::{hash_map, BTreeMap, HashMap, HashSet}, - fmt::Debug, - iter, - mem::size_of, - sync::{Arc, Mutex, RwLock}, -}; -use tokio::sync::MutexGuard; -use tracing::{error, warn}; - -use super::{abstraction::Tree, pusher}; - -/// The unique identifier of each state group. -/// -/// This is created when a state group is added to the database by -/// hashing the entire state. -pub type StateHashId = Vec; -pub type CompressedStateEvent = [u8; 2 * size_of::()]; - -pub struct Rooms { - pub edus: RoomEdus, - pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count - pub(super) eventid_pduid: Arc, - pub(super) roomid_pduleaves: Arc, - pub(super) alias_roomid: Arc, - pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count - pub(super) publicroomids: Arc, - - pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount - - /// Participating servers in a room. - pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName - pub(super) serverroomids: Arc, // ServerRoomId = ServerName + RoomId - - pub(super) userroomid_joined: Arc, - pub(super) roomuserid_joined: Arc, - pub(super) roomid_joinedcount: Arc, - pub(super) roomid_invitedcount: Arc, - pub(super) roomuseroncejoinedids: Arc, - pub(super) userroomid_invitestate: Arc, // InviteState = Vec> - pub(super) roomuserid_invitecount: Arc, // InviteCount = Count - pub(super) userroomid_leftstate: Arc, - pub(super) roomuserid_leftcount: Arc, - - pub(super) disabledroomids: Arc, // Rooms where incoming federation handling is disabled - - pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId - - pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 - pub(super) userroomid_highlightcount: Arc, // HightlightCount = u64 - - /// Remember the current state hash of a room. - pub(super) roomid_shortstatehash: Arc, - pub(super) roomsynctoken_shortstatehash: Arc, - /// Remember the state hash at events in the past. - pub(super) shorteventid_shortstatehash: Arc, - /// StateKey = EventType + StateKey, ShortStateKey = Count - pub(super) statekey_shortstatekey: Arc, - pub(super) shortstatekey_statekey: Arc, - - pub(super) roomid_shortroomid: Arc, - - pub(super) shorteventid_eventid: Arc, - pub(super) eventid_shorteventid: Arc, - - pub(super) statehash_shortstatehash: Arc, - pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) - - pub(super) shorteventid_authchain: Arc, - - /// RoomId + EventId -> outlier PDU. - /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. - pub(super) eventid_outlierpdu: Arc, - pub(super) softfailedeventids: Arc, - - /// RoomId + EventId -> Parent PDU EventId. - pub(super) referencedevents: Arc, - - pub(super) pdu_cache: Mutex, Arc>>, - pub(super) shorteventid_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, Arc>>>, - pub(super) eventidshort_cache: Mutex, u64>>, - pub(super) statekeyshort_cache: Mutex>, - pub(super) shortstatekey_cache: Mutex>, - pub(super) our_real_users_cache: RwLock, Arc>>>>, - pub(super) appservice_in_room_cache: RwLock, HashMap>>, - pub(super) lazy_load_waiting: - Mutex, Box, Box, u64), HashSet>>>, - pub(super) stateinfo_cache: Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - HashSet, // full state - HashSet, // added - HashSet, // removed - )>, - >, - >, - pub(super) lasttimelinecount_cache: Mutex, u64>>, -} - -impl Rooms { - /// Returns true if a given room version is supported - #[tracing::instrument(skip(self, db))] - pub fn is_supported_version(&self, db: &Database, room_version: &RoomVersionId) -> bool { - db.globals.supported_room_versions().contains(room_version) - } - - /// This fetches auth events from the current state. - #[tracing::instrument(skip(self))] - pub fn get_auth_events( - &self, - room_id: &RoomId, - kind: &RoomEventType, - sender: &UserId, - state_key: Option<&str>, - content: &serde_json::value::RawValue, - ) -> Result>> { - let shortstatehash = - if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? { - current_shortstatehash - } else { - return Ok(HashMap::new()); - }; - - let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) - .expect("content is a valid JSON object"); - - let mut sauthevents = auth_events - .into_iter() - .filter_map(|(event_type, state_key)| { - self.get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) - }) - .collect::>(); - - let full_state = self - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - - Ok(full_state - .into_iter() - .filter_map(|compressed| self.parse_compressed_state_event(compressed).ok()) - .filter_map(|(shortstatekey, event_id)| { - sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) - }) - .filter_map(|(k, event_id)| self.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) - .collect()) - } - - /// Generate a new StateHash. - /// - /// A unique hash made from hashing all PDU ids of the state joined with 0xff. - fn calculate_hash(&self, bytes_list: &[&[u8]]) -> StateHashId { - // We only hash the pdu's event ids, not the whole pdu - let bytes = bytes_list.join(&0xff); - let hash = digest::digest(&digest::SHA256, &bytes); - hash.as_ref().into() - } - - #[tracing::instrument(skip(self))] - pub fn iter_ids(&self) -> impl Iterator>> + '_ { - self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - }) - } - - pub fn is_disabled(&self, room_id: &RoomId) -> Result { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } +pub mod alias; +pub mod auth_chain; +pub mod directory; +pub mod edus; +pub mod event_handler; +pub mod lazy_loading; +pub mod metadata; +pub mod outlier; +pub mod pdu_metadata; +pub mod search; +pub mod short; +pub mod state; +pub mod state_accessor; +pub mod state_cache; +pub mod state_compressor; +pub mod timeline; +pub mod user; +pub struct Service { + pub alias: alias::Service, + pub auth_chain: auth_chain::Service, + pub directory: directory::Service, + pub edus: edus::Service, + pub event_handler: event_handler::Service, + pub lazy_loading: lazy_loading::Service, + pub metadata: metadata::Service, + pub outlier: outlier::Service, + pub pdu_metadata: pdu_metadata::Service, + pub search: search::Service, + pub short: short::Service, + pub state: state::Service, + pub state_accessor: state_accessor::Service, + pub state_cache: state_cache::Service, + pub state_compressor: state_compressor::Service, + pub timeline: timeline::Service, + pub user: user::Service, } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index 6b534b95..d579515e 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -1,3 +1,7 @@ +use ruma::{EventId, signatures::CanonicalJsonObject}; + +use crate::PduEvent; + pub trait Data { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index c82cb628..ee8b940f 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -1,7 +1,8 @@ mod data; pub use data::Data; +use ruma::{EventId, signatures::CanonicalJsonObject}; -use crate::service::*; +use crate::{service::*, PduEvent}; pub struct Service { db: D, diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 67787958..531823fe 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use ruma::{EventId, RoomId}; + pub trait Data { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index 6d6df223..3442b830 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,5 +1,8 @@ mod data; +use std::sync::Arc; + pub use data::Data; +use ruma::{RoomId, EventId}; use crate::service::*; diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 1601e0de..16287eba 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,7 +1,9 @@ -pub trait Data { - pub fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()>; +use ruma::RoomId; - pub fn search_pdus<'a>( +pub trait Data { + fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()>; + + fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 5478273c..9087deff 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,7 +1,6 @@ mod data; pub use data::Data; - -use crate::service::*; +use ruma::RoomId; pub struct Service { db: D, diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index a8e87b91..afde14e2 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,7 +1,10 @@ mod data; -pub use data::Data; +use std::sync::Arc; -use crate::service::*; +pub use data::Data; +use ruma::{EventId, events::StateEventType}; + +use crate::{service::*, Error, utils}; pub struct Service { db: D, @@ -188,7 +191,6 @@ impl Service<_> { fn get_or_create_shortstatehash( &self, state_hash: &StateHashId, - globals: &super::globals::Globals, ) -> Result<(u64, bool)> { Ok(match self.statehash_shortstatehash.get(state_hash)? { Some(shortstatehash) => ( diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 8aa76380..ac8fac21 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,14 +1,20 @@ +use std::sync::Arc; +use std::{sync::MutexGuard, collections::HashSet}; +use std::fmt::Debug; + +use ruma::{EventId, RoomId}; + pub trait Data { /// Returns the last state hash key added to the db for the given room. fn get_room_shortstatehash(room_id: &RoomId); /// Update the current state of the room. - fn set_room_state(room_id: &RoomId, new_shortstatehash: u64 + fn set_room_state(room_id: &RoomId, new_shortstatehash: u64, _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex ); /// Associates a state with an event. - fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()> { + fn set_event_state(shorteventid: u64, shortstatehash: u64) -> Result<()>; /// Returns all events we would send as the prev_events of the next event. fn get_forward_extremities(room_id: &RoomId) -> Result>>; @@ -18,7 +24,7 @@ pub trait Data { room_id: &RoomId, event_ids: impl IntoIterator + Debug, _mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { + ) -> Result<()>; } pub struct StateLock; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index b513ab53..6c33d521 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,7 +1,12 @@ mod data; -pub use data::Data; +use std::collections::HashSet; -use crate::service::*; +pub use data::Data; +use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType}, UserId, EventId, serde::Raw, RoomVersionId}; +use serde::Deserialize; +use tracing::warn; + +use crate::{service::*, SERVICE, PduEvent, Error, utils::calculate_hash}; pub struct Service { db: D, @@ -9,22 +14,20 @@ pub struct Service { impl Service<_> { /// Set the room to the given statehash and update caches. - #[tracing::instrument(skip(self, new_state_ids_compressed, db))] pub fn force_state( &self, room_id: &RoomId, shortstatehash: u64, statediffnew: HashSet, statediffremoved: HashSet, - db: &Database, ) -> Result<()> { for event_id in statediffnew.into_iter().filter_map(|new| { - state_compressor::parse_compressed_state_event(new) + SERVICE.rooms.state_compressor.parse_compressed_state_event(new) .ok() .map(|(_, id)| id) }) { - let pdu = match timeline::get_pdu_json(&event_id)? { + let pdu = match SERVICE.rooms.timeline.get_pdu_json(&event_id)? { Some(pdu) => pdu, None => continue, }; @@ -60,12 +63,12 @@ impl Service<_> { Err(_) => continue, }; - room::state_cache::update_membership(room_id, &user_id, membership, &pdu.sender, None, db, false)?; + SERVICE.room.state_cache.update_membership(room_id, &user_id, membership, &pdu.sender, None, false)?; } - room::state_cache::update_joined_count(room_id, db)?; + SERVICE.room.state_cache.update_joined_count(room_id)?; - db.set_room_state(room_id, new_shortstatehash); + self.db.set_room_state(room_id, shortstatehash); Ok(()) } @@ -74,19 +77,18 @@ impl Service<_> { /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed, globals))] + #[tracing::instrument(skip(self, state_ids_compressed))] pub fn set_event_state( &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: HashSet, - globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = short::get_or_create_shorteventid(event_id, globals)?; + let shorteventid = SERVICE.short.get_or_create_shorteventid(event_id)?; - let previous_shortstatehash = db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; - let state_hash = super::calculate_hash( + let state_hash = calculate_hash( &state_ids_compressed .iter() .map(|s| &s[..]) @@ -94,11 +96,11 @@ impl Service<_> { ); let (shortstatehash, already_existed) = - short::get_or_create_shortstatehash(&state_hash, globals)?; + SERVICE.short.get_or_create_shortstatehash(&state_hash)?; if !already_existed { let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| room::state_compressor.load_shortstatehash_info(p))?; + .map_or_else(|| Ok(Vec::new()), |p| SERVICE.room.state_compressor.load_shortstatehash_info(p))?; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -117,7 +119,7 @@ impl Service<_> { } else { (state_ids_compressed, HashSet::new()) }; - state_compressor::save_state_from_diff( + SERVICE.room.state_compressor.save_state_from_diff( shortstatehash, statediffnew, statediffremoved, @@ -126,7 +128,7 @@ impl Service<_> { )?; } - db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + self.db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -135,13 +137,12 @@ impl Service<_> { /// /// This adds all current state events (not including the incoming event) /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu, globals))] + #[tracing::instrument(skip(self, new_pdu))] pub fn append_to_state( &self, new_pdu: &PduEvent, - globals: &super::globals::Globals, ) -> Result { - let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; + let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id)?; let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; @@ -157,10 +158,9 @@ impl Service<_> { let shortstatekey = self.get_or_create_shortstatekey( &new_pdu.kind.to_string().into(), state_key, - globals, )?; - let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; + let new = self.compress_state_event(shortstatekey, &new_pdu.event_id)?; let replaces = states_parents .last() @@ -176,7 +176,7 @@ impl Service<_> { } // TODO: statehash with deterministic inputs - let shortstatehash = globals.next_count()?; + let shortstatehash = SERVICE.globals.next_count()?; let mut statediffnew = HashSet::new(); statediffnew.insert(new); @@ -254,7 +254,23 @@ impl Service<_> { Ok(()) } - pub fn db(&self) -> D { - &self.db + /// Returns the room's version. + #[tracing::instrument(skip(self))] + pub fn get_room_version(&self, room_id: &RoomId) -> Result { + let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; + + let create_event_content: Option = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + let room_version = create_event_content + .map(|create_event| create_event.room_version) + .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; + Ok(room_version) } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index a2b76e46..bf2972f9 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,3 +1,9 @@ +use std::{sync::Arc, collections::HashMap}; + +use ruma::{EventId, events::StateEventType, RoomId}; + +use crate::PduEvent; + pub trait Data { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 28a49a98..92e5c8e1 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,7 +1,10 @@ mod data; -pub use data::Data; +use std::{sync::Arc, collections::{HashMap, BTreeMap}}; -use crate::service::*; +pub use data::Data; +use ruma::{events::StateEventType, RoomId, EventId}; + +use crate::{service::*, PduEvent}; pub struct Service { db: D, @@ -42,7 +45,7 @@ impl Service<_> { event_type: &StateEventType, state_key: &str, ) -> Result>> { - self.db.pdu_state_get(event_id) + self.db.pdu_state_get(shortstatehash, event_type, state_key) } /// Returns the state hash for this pdu. diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 166d4f6b..f6519196 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,3 +1,5 @@ +use ruma::{UserId, RoomId}; + pub trait Data { fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()>; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 778679de..d29501a6 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,7 +1,11 @@ mod data; -pub use data::Data; +use std::{collections::HashSet, sync::Arc}; -use crate::service::*; +pub use data::Data; +use regex::Regex; +use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, tag::TagEvent, RoomAccountDataEventType, GlobalAccountDataEventType, direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, AnySyncStateEvent}, serde::Raw, ServerName}; + +use crate::{service::*, SERVICE, utils, Error}; pub struct Service { db: D, @@ -9,7 +13,7 @@ pub struct Service { impl Service<_> { /// Update current membership data. - #[tracing::instrument(skip(self, last_state, db))] + #[tracing::instrument(skip(self, last_state))] pub fn update_membership( &self, room_id: &RoomId, @@ -17,12 +21,11 @@ impl Service<_> { membership: MembershipState, sender: &UserId, last_state: Option>>, - db: &Database, update_joined_count: bool, ) -> Result<()> { // Keep track what remote users exist by adding them as "deactivated" users - if user_id.server_name() != db.globals.server_name() { - db.users.create(user_id, None)?; + if user_id.server_name() != SERVICE.globals.server_name() { + SERVICE.users.create(user_id, None)?; // TODO: displayname, avatar url } @@ -82,7 +85,7 @@ impl Service<_> { user_id, RoomAccountDataEventType::Tag, )? { - db.account_data + SERVICE.account_data .update( Some(room_id), user_id, @@ -94,7 +97,7 @@ impl Service<_> { }; // Copy direct chat flag - if let Some(mut direct_event) = db.account_data.get::( + if let Some(mut direct_event) = SERVICE.account_data.get::( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), @@ -109,12 +112,11 @@ impl Service<_> { } if room_ids_updated { - db.account_data.update( + SERVICE.account_data.update( None, user_id, GlobalAccountDataEventType::Direct.to_string().into(), &direct_event, - &db.globals, )?; } }; @@ -130,7 +132,7 @@ impl Service<_> { } MembershipState::Invite => { // We want to know if the sender is ignored by the receiver - let is_ignored = db + let is_ignored = SERVICE .account_data .get::( None, // Ignored users are in global account data @@ -186,7 +188,7 @@ impl Service<_> { } #[tracing::instrument(skip(self, room_id, db))] - pub fn update_joined_count(&self, room_id: &RoomId, db: &Database) -> Result<()> { + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { let mut joinedcount = 0_u64; let mut invitedcount = 0_u64; let mut joined_servers = HashSet::new(); @@ -226,11 +228,10 @@ impl Service<_> { Ok(()) } - #[tracing::instrument(skip(self, room_id, db))] + #[tracing::instrument(skip(self, room_id))] pub fn get_our_real_users( &self, room_id: &RoomId, - db: &Database, ) -> Result>>> { let maybe = self .our_real_users_cache @@ -241,7 +242,7 @@ impl Service<_> { if let Some(users) = maybe { Ok(users) } else { - self.update_joined_count(room_id, db)?; + self.update_joined_count(room_id)?; Ok(Arc::clone( self.our_real_users_cache .read() @@ -252,12 +253,11 @@ impl Service<_> { } } - #[tracing::instrument(skip(self, room_id, appservice, db))] + #[tracing::instrument(skip(self, room_id, appservice))] pub fn appservice_in_room( &self, room_id: &RoomId, appservice: &(String, serde_yaml::Value), - db: &Database, ) -> Result { let maybe = self .appservice_in_room_cache @@ -285,7 +285,7 @@ impl Service<_> { .get("sender_localpart") .and_then(|string| string.as_str()) .and_then(|string| { - UserId::parse_with_server_name(string, db.globals.server_name()).ok() + UserId::parse_with_server_name(string, SERVICE.globals.server_name()).ok() }); let in_room = bridge_user_id diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index 8b855cd2..74a28e7b 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -1,4 +1,6 @@ -struct StateDiff { +use crate::service::rooms::CompressedStateEvent; + +pub struct StateDiff { parent: Option, added: Vec, removed: Vec, diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index d6d88e25..3aea4fe6 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,7 +1,12 @@ -mod data; -pub use data::Data; +pub mod data; +use std::{mem::size_of, sync::Arc, collections::HashSet}; -use crate::service::*; +pub use data::Data; +use ruma::{EventId, RoomId}; + +use crate::{service::*, utils}; + +use self::data::StateDiff; pub struct Service { db: D, @@ -30,9 +35,9 @@ impl Service<_> { return Ok(r.clone()); } - self.db.get_statediff(shortstatehash)?; + let StateDiff { parent, added, removed } = self.db.get_statediff(shortstatehash)?; - if parent != 0_u64 { + if let Some(parent) = parent { let mut response = self.load_shortstatehash_info(parent)?; let mut state = response.last().unwrap().1.clone(); state.extend(added.iter().copied()); @@ -155,7 +160,7 @@ impl Service<_> { if parent_states.is_empty() { // There is no parent layer, create a new state - self.db.save_statediff(shortstatehash, StateDiff { parent: 0, new: statediffnew, removed: statediffremoved })?; + self.db.save_statediff(shortstatehash, StateDiff { parent: None, added: statediffnew, removed: statediffremoved })?; return Ok(()); }; @@ -197,7 +202,7 @@ impl Service<_> { )?; } else { // Diff small enough, we add diff as layer on top of parent - self.db.save_statediff(shortstatehash, StateDiff { parent: parent.0, new: statediffnew, removed: statediffremoved })?; + self.db.save_statediff(shortstatehash, StateDiff { parent: Some(parent.0), added: statediffnew, removed: statediffremoved })?; } Ok(()) diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 4e5c3796..bf6d8c5e 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,3 +1,9 @@ +use std::sync::Arc; + +use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId}; + +use crate::PduEvent; + pub trait Data { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; @@ -5,34 +11,37 @@ pub trait Data { fn get_pdu_count(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_pdu_json(&self, event_id: &EventId) -> Result>; /// Returns the json of a pdu. - pub fn get_non_outlier_pdu_json( + fn get_non_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result>; /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>>; + fn get_pdu_id(&self, event_id: &EventId) -> Result>>; /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>>; + fn get_pdu(&self, event_id: &EventId) -> Result>>; /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; /// Returns the `count` of this pdu's id. - pub fn pdu_count(&self, pdu_id: &[u8]) -> Result; + fn pdu_count(&self, pdu_id: &[u8]) -> Result; /// Removes a pdu and creates a new one with the same id. fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()>; @@ -40,7 +49,7 @@ pub trait Data { /// Returns an iterator over all events in a room that happened after the event with id `since` /// in chronological order. #[tracing::instrument(skip(self))] - pub fn pdus_since<'a>( + fn pdus_since<'a>( &'a self, user_id: &UserId, room_id: &RoomId, @@ -50,14 +59,14 @@ pub trait Data { /// Returns an iterator over all events and their tokens in a room that happened before the /// event with id `until` in reverse-chronological order. #[tracing::instrument(skip(self))] - pub fn pdus_until<'a>( + fn pdus_until<'a>( &'a self, user_id: &UserId, room_id: &RoomId, until: u64, ) -> Result, PduEvent)>> + 'a>; - pub fn pdus_after<'a>( + fn pdus_after<'a>( &'a self, user_id: &UserId, room_id: &RoomId, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index c6393c68..7b60fe5d 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,7 +1,17 @@ mod data; -pub use data::Data; +use std::{sync::MutexGuard, iter, collections::HashSet}; +use std::fmt::Debug; -use crate::service::*; +pub use data::Data; +use regex::Regex; +use ruma::signatures::CanonicalJsonValue; +use ruma::{EventId, signatures::CanonicalJsonObject, push::{Action, Tweak}, events::{push_rules::PushRulesEvent, GlobalAccountDataEventType, RoomEventType, room::{member::MembershipState, create::RoomCreateEventContent}, StateEventType}, UserId, RoomAliasId, RoomId, uint, state_res, api::client::error::ErrorKind, serde::to_canonical_value, ServerName}; +use serde::Deserialize; +use serde_json::value::to_raw_value; +use tracing::{warn, error}; + +use crate::SERVICE; +use crate::{service::{*, pdu::{PduBuilder, EventHash}}, Error, PduEvent, utils}; pub struct Service { db: D, @@ -126,13 +136,12 @@ impl Service<_> { /// in `append_pdu`. /// /// Returns pdu id - #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] + #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] pub fn append_pdu<'a>( &self, pdu: &PduEvent, mut pdu_json: CanonicalJsonObject, leaves: impl IntoIterator + Debug, - db: &Database, ) -> Result> { let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); @@ -249,7 +258,6 @@ impl Service<_> { &power_levels, &sync_pdu, &pdu.room_id, - db, )? { match action { Action::DontNotify => notify = false, @@ -446,9 +454,8 @@ impl Service<_> { pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - db: &Database, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> (PduEvent, CanonicalJsonObj) { + ) -> (PduEvent, CanonicalJsonObject) { let PduBuilder { event_type, content, @@ -457,14 +464,14 @@ impl Service<_> { redacts, } = pdu_builder; - let prev_events: Vec<_> = db + let prev_events: Vec<_> = SERVICE .rooms .get_pdu_leaves(room_id)? .into_iter() .take(20) .collect(); - let create_event = db + let create_event = SERVICE .rooms .room_state_get(room_id, &StateEventType::RoomCreate, "")?; @@ -481,7 +488,7 @@ impl Service<_> { // If there was no create event yet, assume we are creating a room with the default // version right now let room_version_id = create_event_content - .map_or(db.globals.default_room_version(), |create_event| { + .map_or(SERVICE.globals.default_room_version(), |create_event| { create_event.room_version }); let room_version = @@ -575,8 +582,8 @@ impl Service<_> { ); match ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), + SERVICE.globals.server_name().as_str(), + SERVICE.globals.keypair(), &mut pdu_json, &room_version_id, ) { @@ -614,22 +621,21 @@ impl Service<_> { /// Creates a new persisted data unit and adds it to a room. This function takes a /// roomid_mutex_state, meaning that only this function is able to mutate the room state. - #[tracing::instrument(skip(self, db, _mutex_lock))] + #[tracing::instrument(skip(self, _mutex_lock))] pub fn build_and_append_pdu( &self, pdu_builder: PduBuilder, sender: &UserId, room_id: &RoomId, - db: &Database, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result> { - let (pdu, pdu_json) = create_hash_and_sign_event()?; + let (pdu, pdu_json) = self.create_hash_and_sign_event()?; // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = self.append_to_state(&pdu, &db.globals)?; + let statehashid = self.append_to_state(&pdu)?; let pdu_id = self.append_pdu( &pdu, @@ -637,7 +643,6 @@ impl Service<_> { // Since this PDU references all pdu_leaves we can update the leaves // of the room iter::once(&*pdu.event_id), - db, )?; // We set the room state after inserting the pdu, so that we never have a moment in time @@ -659,9 +664,9 @@ impl Service<_> { } // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(db.globals.server_name()); + servers.remove(SERVICE.globals.server_name()); - db.sending.send_pdu(servers.into_iter(), &pdu_id)?; + SERVICE.sending.send_pdu(servers.into_iter(), &pdu_id)?; Ok(pdu.event_id) } @@ -670,7 +675,6 @@ impl Service<_> { /// server that sent the event. #[tracing::instrument(skip_all)] fn append_incoming_pdu<'a>( - db: &Database, pdu: &PduEvent, pdu_json: CanonicalJsonObject, new_room_leaves: impl IntoIterator + Clone + Debug, @@ -680,21 +684,20 @@ impl Service<_> { ) -> Result>> { // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - db.rooms.set_event_state( + SERVICE.rooms.set_event_state( &pdu.event_id, &pdu.room_id, state_ids_compressed, - &db.globals, )?; if soft_fail { - db.rooms + SERVICE.rooms .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - db.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; + SERVICE.rooms.replace_pdu_leaves(&pdu.room_id, new_room_leaves)?; return Ok(None); } - let pdu_id = db.rooms.append_pdu(pdu, pdu_json, new_room_leaves, db)?; + let pdu_id = SERVICE.rooms.append_pdu(pdu, pdu_json, new_room_leaves)?; Ok(Some(pdu_id)) } @@ -756,4 +759,4 @@ impl Service<_> { // If event does not exist, just noop Ok(()) } - +} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 45fb3551..664f8a0a 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -1,5 +1,6 @@ mod data; pub use data::Data; +use ruma::{RoomId, UserId}; use crate::service::*; diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index f1ff5f88..c1b47154 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,5 +1,5 @@ pub trait Data { - pub fn add_txnid( + fn add_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, @@ -7,7 +7,7 @@ pub trait Data { data: &[u8], ) -> Result<()>; - pub fn existing_txnid( + fn existing_txnid( &self, user_id: &UserId, device_id: Option<&DeviceId>, diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index d944847e..9b76e13b 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,5 +1,6 @@ mod data; pub use data::Data; +use ruma::{UserId, DeviceId, TransactionId}; use crate::service::*; diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index 40e69bda..cc943bff 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,3 +1,5 @@ +use ruma::{api::client::uiaa::UiaaInfo, DeviceId, UserId, signatures::CanonicalJsonValue}; + pub trait Data { fn set_uiaa_request( &self, diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 593ea5f2..5e1df8f3 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -1,7 +1,9 @@ mod data; pub use data::Data; +use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue}; +use tracing::error; -use crate::service::*; +use crate::{service::*, utils, Error, SERVICE}; pub struct Service { db: D, @@ -36,8 +38,6 @@ impl Service<_> { device_id: &DeviceId, auth: &IncomingAuthData, uiaainfo: &UiaaInfo, - users: &super::users::Users, - globals: &super::globals::Globals, ) -> Result<(bool, UiaaInfo)> { let mut uiaainfo = auth .session() @@ -66,13 +66,13 @@ impl Service<_> { }; let user_id = - UserId::parse_with_server_name(username.clone(), globals.server_name()) + UserId::parse_with_server_name(username.clone(), SERVICE.globals.server_name()) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") })?; // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { + if let Some(hash) = SERVICE.users.password_hash(&user_id)? { let hash_matches = argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); diff --git a/src/service/users/data.rs b/src/service/users/data.rs index d99d0328..327e0c69 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,34 +1,27 @@ -pub trait Data { +use std::collections::BTreeMap; + +use ruma::{UserId, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, DeviceKeys, CrossSigningKey}, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}, MxcUri}; + +trait Data { /// Check if a user has an account on this homeserver. - pub fn exists(&self, user_id: &UserId) -> Result; + fn exists(&self, user_id: &UserId) -> Result; /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result; - - /// Check if a user is an admin - pub fn is_admin( - &self, - user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, - ) -> Result; - - /// Create a new user account on this homeserver. - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; + fn is_deactivated(&self, user_id: &UserId) -> Result; /// Returns the number of users registered on this server. - pub fn count(&self) -> Result; + fn count(&self) -> Result; /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result, String)>>; + fn find_from_token(&self, token: &str) -> Result, String)>>; /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator>> + '_; + fn iter(&self) -> impl Iterator>> + '_; /// Returns a list of local users as list of usernames. /// /// A user account is considered `local` if the length of it's password is greater then zero. - pub fn list_local_users(&self) -> Result>; + fn list_local_users(&self) -> Result>; /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. @@ -37,31 +30,31 @@ pub trait Data { fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option; /// Returns the password hash for the given user. - pub fn password_hash(&self, user_id: &UserId) -> Result>; + fn password_hash(&self, user_id: &UserId) -> Result>; /// Hash and set the user's password to the Argon2 hash - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result>; + fn displayname(&self, user_id: &UserId) -> Result>; /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - pub fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; + fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result>>; + fn avatar_url(&self, user_id: &UserId) -> Result>>; /// Sets a new avatar_url or removes it if avatar_url is None. - pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()>; + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option>) -> Result<()>; /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result>; + fn blurhash(&self, user_id: &UserId) -> Result>; /// Sets a new avatar_url or removes it if avatar_url is None. - pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; + fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; /// Adds a new device to a user. - pub fn create_device( + fn create_device( &self, user_id: &UserId, device_id: &DeviceId, @@ -70,129 +63,118 @@ pub trait Data { ) -> Result<()>; /// Removes a device from a user. - pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>( + fn all_device_ids<'a>( &'a self, user_id: &UserId, ) -> impl Iterator>> + 'a; /// Replaces the access token of one device. - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - pub fn add_one_time_key( + fn add_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - pub fn take_one_time_key( + fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, - globals: &super::globals::Globals, ) -> Result, Raw)>>; - pub fn count_one_time_keys( + fn count_one_time_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>; - pub fn add_device_keys( + fn add_device_keys( &self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn add_cross_signing_keys( + fn add_cross_signing_keys( &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, user_signing_key: &Option>, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn sign_key( + fn sign_key( &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn keys_changed<'a>( + fn keys_changed<'a>( &'a self, user_or_room_id: &str, from: u64, to: Option, ) -> impl Iterator>> + 'a; - pub fn mark_device_key_update( + fn mark_device_key_update( &self, user_id: &UserId, - rooms: &super::rooms::Rooms, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn get_device_keys( + fn get_device_keys( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>>; - pub fn get_master_key bool>( + fn get_master_key bool>( &self, user_id: &UserId, allowed_signatures: F, ) -> Result>>; - pub fn get_self_signing_key bool>( + fn get_self_signing_key bool>( &self, user_id: &UserId, allowed_signatures: F, ) -> Result>>; - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; + fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - pub fn add_to_device_event( + fn add_to_device_event( &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, content: serde_json::Value, - globals: &super::globals::Globals, ) -> Result<()>; - pub fn get_to_device_events( + fn get_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>>; - pub fn remove_to_device_events( + fn remove_to_device_events( &self, user_id: &UserId, device_id: &DeviceId, until: u64, ) -> Result<()>; - pub fn update_device_metadata( + fn update_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, @@ -200,27 +182,27 @@ pub trait Data { ) -> Result<()>; /// Get device metadata. - pub fn get_device_metadata( + fn get_device_metadata( &self, user_id: &UserId, device_id: &DeviceId, ) -> Result>; - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result>; + fn get_devicelist_version(&self, user_id: &UserId) -> Result>; - pub fn all_devices_metadata<'a>( + fn all_devices_metadata<'a>( &'a self, user_id: &UserId, ) -> impl Iterator> + 'a; /// Creates a new sync filter. Returns the filter id. - pub fn create_filter( + fn create_filter( &self, user_id: &UserId, filter: &IncomingFilterDefinition, ) -> Result; - pub fn get_filter( + fn get_filter( &self, user_id: &UserId, filter_id: &str, diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 93d6ea52..bfa4b8e5 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,7 +1,10 @@ mod data; -pub use data::Data; +use std::{collections::BTreeMap, mem}; -use crate::service::*; +pub use data::Data; +use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition}}; + +use crate::{service::*, Error}; pub struct Service { db: D, @@ -19,18 +22,24 @@ impl Service<_> { } /// Check if a user is an admin - pub fn is_admin( + fn is_admin( &self, user_id: &UserId, ) -> Result { - self.db.is_admin(user_id) + let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name())) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap(); + + rooms.is_joined(user_id, &admin_room_id) } /// Create a new user account on this homeserver. - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password) + fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.db.set_password(user_id, password)?; + Ok(()) } + /// Returns the number of users registered on this server. pub fn count(&self) -> Result { self.db.count() @@ -136,7 +145,6 @@ impl Service<_> { device_id: &DeviceId, one_time_key_key: &DeviceKeyId, one_time_key_value: &Raw, - globals: &super::globals::Globals, ) -> Result<()> { self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) } @@ -220,7 +228,7 @@ impl Service<_> { user_id: &UserId, allowed_signatures: F, ) -> Result>> { - self.db.get_master_key(user_id, allow_signatures) + self.db.get_master_key(user_id, allowed_signatures) } pub fn get_self_signing_key bool>( @@ -327,7 +335,7 @@ impl Service<_> { } /// Ensure that a user only sees signatures from themselves and the target user -fn clean_signatures bool>( +pub fn clean_signatures bool>( cross_signing_key: &mut serde_json::Value, user_id: &UserId, allowed_signatures: F, diff --git a/src/utils/utils.rs b/src/utils/mod.rs similarity index 99% rename from src/utils/utils.rs rename to src/utils/mod.rs index 1ad0aa3f..734da2a8 100644 --- a/src/utils/utils.rs +++ b/src/utils/mod.rs @@ -1,3 +1,5 @@ +pub mod error; + use argon2::{Config, Variant}; use cmp::Ordering; use rand::prelude::*;