From 14e6afc45e320fc420273cd5de9ceb8a690a1806 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 01:18:06 +0200 Subject: [PATCH] remove eldrich being and install good being --- .gitignore | 1 + src/client_server/account.rs | 10 +- src/client_server/alias.rs | 8 +- src/client_server/backup.rs | 30 ++--- src/client_server/config.rs | 10 +- src/client_server/context.rs | 4 +- src/client_server/device.rs | 12 +- src/client_server/directory.rs | 10 +- src/client_server/keys.rs | 14 +- src/client_server/media.rs | 14 +- src/client_server/membership.rs | 23 ++-- src/client_server/message.rs | 6 +- src/client_server/presence.rs | 6 +- src/client_server/profile.rs | 12 +- src/client_server/push.rs | 22 +-- src/client_server/read_marker.rs | 6 +- src/client_server/redact.rs | 4 +- src/client_server/room.rs | 11 +- src/client_server/search.rs | 4 +- src/client_server/session.rs | 8 +- src/client_server/state.rs | 12 +- src/client_server/sync.rs | 16 ++- src/client_server/tag.rs | 8 +- src/client_server/to_device.rs | 4 +- src/client_server/typing.rs | 4 +- src/client_server/user_directory.rs | 4 +- src/database.rs | 201 +++++++++++++++++----------- src/database/abstraction/sqlite.rs | 44 +----- src/database/admin.rs | 54 ++++---- src/database/sending.rs | 43 +++--- src/lib.rs | 3 +- src/main.rs | 33 ++++- src/ruma_wrapper.rs | 4 +- src/server_server.rs | 41 +++--- 34 files changed, 371 insertions(+), 315 deletions(-) diff --git a/.gitignore b/.gitignore index e2f4e882..1f5f395f 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ $RECYCLE.BIN/ # Conduit Rocket.toml conduit.toml +conduit.db # Etc. **/*.rs.bk diff --git a/src/client_server/account.rs b/src/client_server/account.rs index f495e287..dad1b2a9 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10; )] #[tracing::instrument(skip(db, body))] pub async fn get_register_available_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { // Validate user id @@ -85,7 +85,7 @@ pub async fn get_register_available_route( )] #[tracing::instrument(skip(db, body))] pub async fn register_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_registration() && !body.from_appservice { @@ -496,7 +496,7 @@ pub async fn register_route( )] #[tracing::instrument(skip(db, body))] pub async fn change_password_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -588,7 +588,7 @@ pub async fn whoami_route(body: Ruma) -> ConduitResult>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index a54bd36d..e9e3a23e 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use regex::Regex; use ruma::{ api::{ @@ -24,7 +24,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if db.rooms.id_from_alias(&body.room_alias)?.is_some() { @@ -45,7 +45,7 @@ pub async fn create_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; @@ -61,7 +61,7 @@ pub async fn delete_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { get_alias_helper(&db, &body.room_alias).await diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index fcca676f..a50e71e8 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::backup::{ @@ -21,7 +21,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +40,7 @@ pub async fn create_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -58,7 +58,7 @@ pub async fn update_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_latest_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn get_latest_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -113,7 +113,7 @@ pub async fn get_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -132,7 +132,7 @@ pub async fn delete_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -166,7 +166,7 @@ pub async fn add_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -198,7 +198,7 @@ pub async fn add_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -227,7 +227,7 @@ pub async fn add_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -243,7 +243,7 @@ pub async fn get_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -261,7 +261,7 @@ pub async fn get_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -283,7 +283,7 @@ pub async fn get_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -306,7 +306,7 @@ pub async fn delete_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -329,7 +329,7 @@ pub async fn delete_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/config.rs b/src/client_server/config.rs index 829bf94a..ed2628a9 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -25,7 +25,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_global_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -60,7 +60,7 @@ pub async fn set_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -92,7 +92,7 @@ pub async fn set_room_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_global_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -119,7 +119,7 @@ pub async fn get_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/context.rs b/src/client_server/context.rs index b86fd0bf..7a3c083a 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; use std::{convert::TryFrom, sync::Arc}; @@ -12,7 +12,7 @@ use rocket::get; )] #[tracing::instrument(skip(db, body))] pub async fn get_context_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 2441524d..361af68e 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::{ @@ -20,7 +20,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +40,7 @@ pub async fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -59,7 +59,7 @@ pub async fn get_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +85,7 @@ pub async fn update_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -139,7 +139,7 @@ pub async fn delete_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 1b6b1d7b..3c96ec14 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ api::{ @@ -35,7 +35,7 @@ use rocket::{get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { get_public_rooms_filtered_helper( @@ -55,7 +55,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let response = get_public_rooms_filtered_helper( @@ -84,7 +84,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_visibility_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -114,7 +114,7 @@ pub async fn set_room_visibility_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_visibility_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_room_visibility::Response { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index f80a3294..310fd620 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,5 +1,5 @@ use super::{State, SESSION_ID_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -28,7 +28,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn upload_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -77,7 +77,7 @@ pub async fn upload_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -98,7 +98,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db)?; @@ -114,7 +114,7 @@ pub async fn claim_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signing_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +177,7 @@ pub async fn upload_signing_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signatures_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -238,7 +238,7 @@ pub async fn upload_signatures_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_key_changes_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 0b1fbd7b..cd7f714c 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,5 +1,7 @@ use super::State; -use crate::{database::media::FileMeta, utils, ConduitResult, Database, Error, Ruma}; +use crate::{ + database::media::FileMeta, database::ReadGuard, utils, ConduitResult, Database, Error, Ruma, +}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, @@ -13,9 +15,7 @@ const MXC_LENGTH: usize = 32; #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[tracing::instrument(skip(db))] -pub async fn get_media_config_route( - db: State<'_, Arc>, -) -> ConduitResult { +pub async fn get_media_config_route(db: ReadGuard) -> ConduitResult { Ok(get_media_config::Response { upload_size: db.globals.max_request_size().into(), } @@ -28,7 +28,7 @@ pub async fn get_media_config_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_content_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!( @@ -66,7 +66,7 @@ pub async fn create_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -119,7 +119,7 @@ pub async fn get_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_thumbnail_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 5c57b68a..b3f0a0e3 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -1,6 +1,7 @@ use super::State; use crate::{ client_server, + database::ReadGuard, pdu::{PduBuilder, PduEvent}, server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; @@ -44,7 +45,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -81,7 +82,7 @@ pub async fn join_room_by_id_route( )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_or_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -135,7 +136,7 @@ pub async fn join_room_by_id_or_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn leave_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -153,7 +154,7 @@ pub async fn leave_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn invite_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -173,7 +174,7 @@ pub async fn invite_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn kick_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -223,7 +224,7 @@ pub async fn kick_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn ban_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -281,7 +282,7 @@ pub async fn ban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn unban_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -330,7 +331,7 @@ pub async fn unban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn forget_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -348,7 +349,7 @@ pub async fn forget_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -369,7 +370,7 @@ pub async fn joined_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_member_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -399,7 +400,7 @@ pub async fn get_member_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_members_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 0d19f347..9764d53b 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -23,7 +23,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_message_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn send_message_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_message_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index ce80dfd7..69cde563 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -12,7 +12,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_presence_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +53,7 @@ pub async fn set_presence_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_presence_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 4e9a37b6..b7e79985 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -21,7 +21,7 @@ use std::{convert::TryInto, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_displayname_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -108,7 +108,7 @@ pub async fn set_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_displayname_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_display_name::Response { @@ -123,7 +123,7 @@ pub async fn get_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_avatar_url_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -210,7 +210,7 @@ pub async fn set_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_avatar_url_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_avatar_url::Response { @@ -225,7 +225,7 @@ pub async fn get_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_profile_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.users.exists(&body.user_id)? { diff --git a/src/client_server/push.rs b/src/client_server/push.rs index d6f62126..8d4564cb 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,7 +24,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrules_all_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -49,7 +49,7 @@ pub async fn get_pushrules_all_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -103,7 +103,7 @@ pub async fn get_pushrule_route( )] #[tracing::instrument(skip(db, req))] pub async fn set_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, req: Ruma>, ) -> ConduitResult { let sender_user = req.sender_user.as_ref().expect("user is authenticated"); @@ -206,7 +206,7 @@ pub async fn set_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_actions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -265,7 +265,7 @@ pub async fn get_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_actions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -339,7 +339,7 @@ pub async fn set_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_enabled_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -400,7 +400,7 @@ pub async fn get_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_enabled_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -479,7 +479,7 @@ pub async fn set_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -548,7 +548,7 @@ pub async fn delete_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushers_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -565,7 +565,7 @@ pub async fn get_pushers_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushers_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 837170ff..7ab367f3 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -20,7 +20,7 @@ use std::{collections::BTreeMap, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_read_marker_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -87,7 +87,7 @@ pub async fn set_read_marker_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_receipt_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index e1930823..98e01e75 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Ruma}; use ruma::{ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, @@ -15,7 +15,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn redact_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/room.rs b/src/client_server/room.rs index b33b5500..a0b7eb80 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -1,5 +1,8 @@ use super::State; -use crate::{client_server::invite_helper, pdu::PduBuilder, ConduitResult, Database, Error, Ruma}; +use crate::{ + client_server::invite_helper, database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, + Error, Ruma, +}; use log::info; use ruma::{ api::client::{ @@ -24,7 +27,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn create_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -294,7 +297,7 @@ pub async fn create_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -322,7 +325,7 @@ pub async fn get_room_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, _room_id: String, ) -> ConduitResult { diff --git a/src/client_server/search.rs b/src/client_server/search.rs index 5fc64d09..25b04588 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; use std::sync::Arc; @@ -14,7 +14,7 @@ use std::collections::BTreeMap; )] #[tracing::instrument(skip(db, body))] pub async fn search_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/session.rs b/src/client_server/session.rs index dd504f18..ff018d2d 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -52,7 +52,7 @@ pub async fn get_login_types_route() -> ConduitResult )] #[tracing::instrument(skip(db, body))] pub async fn login_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { // Validate login method @@ -169,7 +169,7 @@ pub async fn login_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -197,7 +197,7 @@ pub async fn logout_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_all_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/state.rs b/src/client_server/state.rs index be52834a..1798536b 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -27,7 +27,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +53,7 @@ pub async fn send_state_event_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_empty_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -79,7 +79,7 @@ pub async fn send_state_event_for_empty_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -126,7 +126,7 @@ pub async fn get_state_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +177,7 @@ pub async fn get_state_events_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_empty_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 69511fa1..6df8af84 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma, RumaResponse}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, @@ -35,13 +35,15 @@ use rocket::{get, tokio}; )] #[tracing::instrument(skip(db, body))] pub async fn sync_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let mut rx = match db + let arc_db = Arc::new(db); + + let mut rx = match arc_db .globals .sync_receivers .write() @@ -52,7 +54,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -68,7 +70,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -104,7 +106,7 @@ pub async fn sync_events_route( } pub async fn sync_helper_wrapper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, @@ -146,7 +148,7 @@ pub async fn sync_helper_wrapper( } async fn sync_helper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 2382fe0a..cc0d487a 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; use ruma::{ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, @@ -15,7 +15,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn update_tag_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -52,7 +52,7 @@ pub async fn update_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_tag_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn delete_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_tags_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index ada0c9a1..2814a9d8 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, to_device::DeviceIdOrAllDevices, @@ -16,7 +16,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn send_event_to_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index a0a5d430..f39ef374 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; use create_typing_event::Typing; use ruma::api::client::r0::typing::create_typing_event; @@ -14,7 +14,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub fn create_typing_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index d7c16d7c..ce382b00 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; use ruma::api::client::r0::user_directory::search_users; #[cfg(feature = "conduit_bin")] @@ -13,7 +13,7 @@ use rocket::post; )] #[tracing::instrument(skip(db, body))] pub async fn search_users_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let limit = u64::from(body.limit) as usize; diff --git a/src/database.rs b/src/database.rs index de4f4412..fda8a955 100644 --- a/src/database.rs +++ b/src/database.rs @@ -19,16 +19,22 @@ use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; use lru_cache::LruCache; -use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; +use rocket::{ + futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}, + outcome::IntoOutcome, + request::{FromRequest, Request}, + try_outcome, State, +}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; use std::{ collections::HashMap, fs::{self, remove_dir_all}, io::Write, + ops::Deref, sync::{Arc, RwLock}, }; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, RwLockReadGuard, Semaphore}; use self::proxy::ProxyConfig; @@ -122,7 +128,7 @@ impl Database { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result> { + pub async fn load_or_create(config: Config) -> Result>> { let builder = Engine::open(&config)?; if config.max_request_size < 1024 { @@ -132,7 +138,7 @@ impl Database { let (admin_sender, admin_receiver) = mpsc::unbounded(); let (sending_sender, sending_receiver) = mpsc::unbounded(); - let db = Arc::new(Self { + let db = Arc::new(TokioRwLock::from(Self { _db: builder.clone(), users: users::Users { userid_password: builder.open_tree("userid_password")?, @@ -238,98 +244,105 @@ impl Database { builder.open_tree("server_signingkeys")?, config, )?, - }); + })); - // MIGRATIONS - // TODO: database versions of new dbs should probably not be 0 - if db.globals.database_version()? < 1 { - for (roomserverid, _) in db.rooms.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); - let servername = match parts.next() { - Some(s) => s, - None => { - error!("Migration: Invalid roomserverid in db."); + { + let db = db.read().await; + // MIGRATIONS + // TODO: database versions of new dbs should probably not be 0 + if db.globals.database_version()? < 1 { + for (roomserverid, _) in db.rooms.roomserverids.iter() { + let mut parts = roomserverid.split(|&b| b == 0xff); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + } + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xff); + serverroomid.extend_from_slice(room_id); + + db.rooms.serverroomids.insert(&serverroomid, &[])?; + } + + db.globals.bump_database_version(1)?; + + println!("Migration: 0 -> 1 finished"); + } + + if db.globals.database_version()? < 2 { + // We accidentally inserted hashed versions of "" into the db instead of just "" + for (userid, password) in db.users.userid_password.iter() { + let password = utils::string_from_bytes(&password); + + let empty_hashed_password = password.map_or(false, |password| { + argon2::verify_encoded(&password, b"").unwrap_or(false) + }); + + if empty_hashed_password { + db.users.userid_password.insert(&userid, b"")?; + } + } + + db.globals.bump_database_version(2)?; + + println!("Migration: 1 -> 2 finished"); + } + + if db.globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.media.mediaid_file.iter() { + if content.len() == 0 { continue; } - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); - serverroomid.extend_from_slice(room_id); - db.rooms.serverroomids.insert(&serverroomid, &[])?; - } - - db.globals.bump_database_version(1)?; - - println!("Migration: 0 -> 1 finished"); - } - - if db.globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.users.userid_password.iter() { - let password = utils::string_from_bytes(&password); - - let empty_hashed_password = password.map_or(false, |password| { - argon2::verify_encoded(&password, b"").unwrap_or(false) - }); - - if empty_hashed_password { - db.users.userid_password.insert(&userid, b"")?; - } - } - - db.globals.bump_database_version(2)?; - - println!("Migration: 1 -> 2 finished"); - } - - if db.globals.database_version()? < 3 { - // Move media to filesystem - for (key, content) in db.media.mediaid_file.iter() { - if content.len() == 0 { - continue; + let path = db.globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.media.mediaid_file.insert(&key, &[])?; } - let path = db.globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.media.mediaid_file.insert(&key, &[])?; + db.globals.bump_database_version(3)?; + + println!("Migration: 2 -> 3 finished"); } - db.globals.bump_database_version(3)?; - - println!("Migration: 2 -> 3 finished"); - } - - if db.globals.database_version()? < 4 { - // Add federated users to db as deactivated - for our_user in db.users.iter() { - let our_user = our_user?; - if db.users.is_deactivated(&our_user)? { - continue; - } - for room in db.rooms.rooms_joined(&our_user) { - for user in db.rooms.room_members(&room?) { - let user = user?; - if user.server_name() != db.globals.server_name() { - println!("Migration: Creating user {}", user); - db.users.create(&user, None)?; + if db.globals.database_version()? < 4 { + // Add federated users to db as deactivated + for our_user in db.users.iter() { + let our_user = our_user?; + if db.users.is_deactivated(&our_user)? { + continue; + } + for room in db.rooms.rooms_joined(&our_user) { + for user in db.rooms.room_members(&room?) { + let user = user?; + if user.server_name() != db.globals.server_name() { + println!("Migration: Creating user {}", user); + db.users.create(&user, None)?; + } } } } + + db.globals.bump_database_version(4)?; + + println!("Migration: 3 -> 4 finished"); } - - db.globals.bump_database_version(4)?; - - println!("Migration: 3 -> 4 finished"); } - // This data is probably outdated - db.rooms.edus.presenceid_presence.clear()?; + let guard = db.read().await; - db.admin.start_handler(Arc::clone(&db), admin_receiver); - db.sending.start_handler(Arc::clone(&db), sending_receiver); + // This data is probably outdated + guard.rooms.edus.presenceid_presence.clear()?; + + guard.admin.start_handler(Arc::clone(&db), admin_receiver); + guard.sending.start_handler(Arc::clone(&db), sending_receiver); + + drop(guard); Ok(db) } @@ -431,4 +444,30 @@ impl Database { res } + + pub fn flush_wal(&self) -> Result<()> { + self._db.flush_wal() + } +} + +pub struct ReadGuard(OwnedRwLockReadGuard); + +impl Deref for ReadGuard { + type Target = OwnedRwLockReadGuard; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(feature = "conduit_bin")] +#[rocket::async_trait] +impl<'r> FromRequest<'r> for ReadGuard { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome { + let db = try_outcome!(req.guard::>>>().await); + + Ok(ReadGuard(Arc::clone(&db).read_owned().await)).or_forward(()) + } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index be5ce6c2..164d9858 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -3,7 +3,7 @@ use std::{ ops::Deref, path::{Path, PathBuf}, pin::Pin, - sync::{Arc, Weak}, + sync::Arc, thread, time::{Duration, Instant}, }; @@ -36,7 +36,6 @@ use tokio::sync::oneshot::Sender; struct Pool { writer: Mutex, readers: Vec>, - reader_rwlock: RwLock<()>, path: PathBuf, } @@ -71,7 +70,6 @@ impl Pool { Ok(Self { writer, readers, - reader_rwlock: RwLock::new(()), path: path.as_ref().to_path_buf(), }) } @@ -95,16 +93,12 @@ impl Pool { } fn read_lock(&self) -> HoldingConn<'_> { - let _guard = self.reader_rwlock.read(); - for r in &self.readers { if let Some(reader) = r.try_lock() { return HoldingConn::FromGuard(reader); } } - drop(_guard); - log::warn!("all readers locked, creating spillover reader..."); let spilled = Self::prepare_conn(&self.path).unwrap(); @@ -115,7 +109,6 @@ impl Pool { pub struct SqliteEngine { pool: Pool, - iterator_lock: RwLock<()>, } impl DatabaseEngine for SqliteEngine { @@ -128,36 +121,7 @@ impl DatabaseEngine for SqliteEngine { pool.write_lock() .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; - let arc = Arc::new(SqliteEngine { - pool, - iterator_lock: RwLock::new(()), - }); - - let weak: Weak = Arc::downgrade(&arc); - - thread::spawn(move || { - let r = crossbeam::channel::tick(Duration::from_secs(60)); - - let weak = weak; - - loop { - let _ = r.recv(); - - if let Some(arc) = Weak::upgrade(&weak) { - log::warn!("wal-trunc: locking..."); - let iterator_guard = arc.iterator_lock.write(); - let read_guard = arc.pool.reader_rwlock.write(); - log::warn!("wal-trunc: locked, flushing..."); - let start = Instant::now(); - arc.flush_wal().unwrap(); - log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); - drop(read_guard); - drop(iterator_guard); - } else { - break; - } - } - }); + let arc = Arc::new(SqliteEngine { pool }); Ok(arc) } @@ -190,7 +154,7 @@ impl DatabaseEngine for SqliteEngine { } impl SqliteEngine { - fn flush_wal(self: &Arc) -> Result<()> { + pub fn flush_wal(self: &Arc) -> Result<()> { self.pool .write_lock() .execute_batch( @@ -244,9 +208,7 @@ impl SqliteTable { let engine = self.engine.clone(); thread::spawn(move || { - let guard = engine.iterator_lock.read(); let _ = f(&engine.pool.read_lock(), s); - drop(guard); }); Box::new(r.into_iter()) diff --git a/src/database/admin.rs b/src/database/admin.rs index 7826cfea..f79b789a 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -10,6 +10,7 @@ use ruma::{ events::{room::message, EventType}, UserId, }; +use tokio::sync::RwLock; pub enum AdminCommand { RegisterAppservice(serde_yaml::Value), @@ -25,20 +26,22 @@ pub struct Admin { impl Admin { pub fn start_handler( &self, - db: Arc, + db: Arc>, mut receiver: mpsc::UnboundedReceiver, ) { tokio::spawn(async move { // TODO: Use futures when we have long admin commands //let mut futures = FuturesUnordered::new(); - let conduit_user = UserId::try_from(format!("@conduit:{}", db.globals.server_name())) + let guard = db.read().await; + + let conduit_user = UserId::try_from(format!("@conduit:{}", guard.globals.server_name())) .expect("@conduit:server_name is valid"); - let conduit_room = db + let conduit_room = guard .rooms .id_from_alias( - &format!("#admins:{}", db.globals.server_name()) + &format!("#admins:{}", guard.globals.server_name()) .try_into() .expect("#admins:server_name is a valid room alias"), ) @@ -48,24 +51,10 @@ impl Admin { warn!("Conduit instance does not have an #admins room. Logging to that room will not work. Restart Conduit after creating a user to fix this."); } + drop(guard); + let send_message = |message: message::MessageEventContent| { - if let Some(conduit_room) = &conduit_room { - db.rooms - .build_and_append_pdu( - PduBuilder { - event_type: EventType::RoomMessage, - content: serde_json::to_value(message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &db, - ) - .unwrap(); - } + }; loop { @@ -73,10 +62,10 @@ impl Admin { Some(event) = receiver.next() => { match event { AdminCommand::RegisterAppservice(yaml) => { - db.appservice.register_appservice(yaml).unwrap(); // TODO handle error + db.read().await.appservice.register_appservice(yaml).unwrap(); // TODO handle error } AdminCommand::ListAppservices => { - if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::>()) { + if let Ok(appservices) = db.read().await.appservice.iter_ids().map(|ids| ids.collect::>()) { let count = appservices.len(); let output = format!( "Appservices ({}): {}", @@ -89,7 +78,24 @@ impl Admin { } } AdminCommand::SendMessage(message) => { - send_message(message); + if let Some(conduit_room) = &conduit_room { + let guard = db.read().await; + guard.rooms + .build_and_append_pdu( + PduBuilder { + event_type: EventType::RoomMessage, + content: serde_json::to_value(message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &guard, + ) + .unwrap(); + } } } } diff --git a/src/database/sending.rs b/src/database/sending.rs index c2e13976..102fb15f 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -30,7 +30,7 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, }; -use tokio::{select, sync::Semaphore}; +use tokio::{select, sync::{Semaphore, RwLock}}; use super::abstraction::Tree; @@ -90,7 +90,7 @@ enum TransactionStatus { } impl Sending { - pub fn start_handler(&self, db: Arc, mut receiver: mpsc::UnboundedReceiver>) { + pub fn start_handler(&self, db: Arc>, mut receiver: mpsc::UnboundedReceiver>) { tokio::spawn(async move { let mut futures = FuturesUnordered::new(); @@ -98,8 +98,11 @@ impl Sending { // Retry requests we could not finish yet let mut initial_transactions = HashMap::>::new(); + + let guard = db.read().await; + for (key, outgoing_kind, event) in - db.sending + guard.sending .servercurrentevents .iter() .filter_map(|(key, _)| { @@ -117,17 +120,19 @@ impl Sending { "Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event ); - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); continue; } entry.push(event); } + drop(guard); + for (outgoing_kind, events) in initial_transactions { current_transaction_status .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events, &db)); + futures.push(Self::handle_events(outgoing_kind.clone(), events, Arc::clone(&db))); } loop { @@ -135,15 +140,17 @@ impl Sending { Some(response) = futures.next() => { match response { Ok(outgoing_kind) => { + let guard = db.read().await; + let prefix = outgoing_kind.get_prefix(); - for (key, _) in db.sending.servercurrentevents + for (key, _) in guard.sending.servercurrentevents .scan_prefix(prefix.clone()) { - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); } // Find events that have been added since starting the last request - let new_events = db.sending.servernamepduids + let new_events = guard.sending.servernamepduids .scan_prefix(prefix.clone()) .map(|(k, _)| { SendingEventType::Pdu(k[prefix.len()..].to_vec()) @@ -161,17 +168,19 @@ impl Sending { SendingEventType::Pdu(b) | SendingEventType::Edu(b) => { current_key.extend_from_slice(&b); - db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); - db.sending.servernamepduids.remove(¤t_key).unwrap(); + guard.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); + guard.sending.servernamepduids.remove(¤t_key).unwrap(); } } } + drop(guard); + futures.push( Self::handle_events( outgoing_kind.clone(), new_events, - &db, + Arc::clone(&db), ) ); } else { @@ -192,13 +201,15 @@ impl Sending { }, Some(key) = receiver.next() => { if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { + let guard = db.read().await; + if let Ok(Some(events)) = Self::select_events( &outgoing_kind, vec![(event, key)], &mut current_transaction_status, - &db + &guard ) { - futures.push(Self::handle_events(outgoing_kind, events, &db)); + futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db))); } } } @@ -403,8 +414,10 @@ impl Sending { async fn handle_events( kind: OutgoingKind, events: Vec, - db: &Database, + db: Arc>, ) -> std::result::Result { + let db = db.read().await; + match &kind { OutgoingKind::Appservice(server) => { let mut pdu_jsons = Vec::new(); @@ -543,7 +556,7 @@ impl Sending { &pusher, rules_for_user, &pdu, - db, + &db, ) .await .map(|_response| kind.clone()) diff --git a/src/lib.rs b/src/lib.rs index 50ca6ea6..c1f0c4b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,8 @@ pub use error::{Error, Result}; pub use pdu::PduEvent; pub use rocket::Config; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; +use tokio::sync::RwLock; pub struct State<'r, T: Send + Sync + 'static>(pub &'r T); diff --git a/src/main.rs b/src/main.rs index 99d45607..826f182e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ mod pdu; mod ruma_wrapper; mod utils; -use std::sync::Arc; +use std::{sync::{Arc, Weak}, time::{Duration, Instant}}; use database::Config; pub use database::Database; @@ -30,10 +30,14 @@ use rocket::{ }, routes, Request, }; +use tokio::{ + sync::RwLock, + time::{interval, Interval}, +}; use tracing::span; use tracing_subscriber::{prelude::*, Registry}; -fn setup_rocket(config: Figment, data: Arc) -> rocket::Rocket { +fn setup_rocket(config: Figment, data: Arc>) -> rocket::Rocket { rocket::custom(config) .manage(data) .mount( @@ -201,6 +205,31 @@ async fn main() { .await .expect("config is valid"); + { + let weak: Weak> = Arc::downgrade(&db); + + tokio::spawn(async { + let weak = weak; + + let mut i = interval(Duration::from_secs(10)); + + loop { + i.tick().await; + + if let Some(arc) = Weak::upgrade(&weak) { + log::warn!("wal-trunc: locking..."); + let guard = arc.write().await; + log::warn!("wal-trunc: locked, flushing..."); + let start = Instant::now(); + guard.flush_wal(); + log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); + } else { + break; + } + } + }); + } + if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() .with_service_name("conduit") diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 8c22f79b..647d25e6 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,4 +1,4 @@ -use crate::Error; +use crate::{database::ReadGuard, Error}; use ruma::{ api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, @@ -49,7 +49,7 @@ where async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome { let metadata = T::Incoming::METADATA; let db = request - .guard::>>() + .guard::() .await .expect("database was loaded"); diff --git a/src/server_server.rs b/src/server_server.rs index 2bcfd2b3..b3e8f1b6 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,5 +1,6 @@ use crate::{ client_server::{self, claim_keys_helper, get_keys_helper}, + database::ReadGuard, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, }; use get_profile_information::v1::ProfileField; @@ -431,9 +432,7 @@ pub async fn request_well_known( #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[tracing::instrument(skip(db))] -pub fn get_server_version_route( - db: State<'_, Arc>, -) -> ConduitResult { +pub fn get_server_version_route(db: ReadGuard) -> ConduitResult { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -450,7 +449,7 @@ pub fn get_server_version_route( // Response type for this endpoint is Json because we need to calculate a signature for the response #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_route(db: ReadGuard) -> Json { if !db.globals.allow_federation() { // TODO: Use proper types return Json("Federation is disabled.".to_owned()); @@ -497,7 +496,7 @@ pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_deprecated_route(db: ReadGuard) -> Json { get_server_keys_route(db) } @@ -507,7 +506,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -551,7 +550,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -595,7 +594,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_transaction_message_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1674,7 +1673,7 @@ pub(crate) fn append_incoming_pdu( )] #[tracing::instrument(skip(db, body))] pub fn get_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1699,7 +1698,7 @@ pub fn get_event_route( )] #[tracing::instrument(skip(db, body))] pub fn get_missing_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1748,7 +1747,7 @@ pub fn get_missing_events_route( )] #[tracing::instrument(skip(db, body))] pub fn get_event_authorization_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1792,7 +1791,7 @@ pub fn get_event_authorization_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1855,7 +1854,7 @@ pub fn get_room_state_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_ids_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1907,7 +1906,7 @@ pub fn get_room_state_ids_route( )] #[tracing::instrument(skip(db, body))] pub fn create_join_event_template_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2076,7 +2075,7 @@ pub fn create_join_event_template_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_join_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2184,7 +2183,7 @@ pub async fn create_join_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_invite_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2289,7 +2288,7 @@ pub async fn create_invite_route( )] #[tracing::instrument(skip(db, body))] pub fn get_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2329,7 +2328,7 @@ pub fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_information_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2357,7 +2356,7 @@ pub fn get_room_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_profile_information_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2391,7 +2390,7 @@ pub fn get_profile_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2419,7 +2418,7 @@ pub fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() {