From 60f2471f59a246e71b4785ff2c5d0fb348206e94 Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Fri, 22 Mar 2024 19:21:51 -0400 Subject: [PATCH] refactor appservice type stuff Signed-off-by: strawberry --- docs/differences.md | 1 - src/api/appservice_server.rs | 139 +++++++++++++-------------- src/api/client_server/alias.rs | 2 +- src/api/ruma_wrapper/axum.rs | 14 +-- src/database/key_value/appservice.rs | 22 ++--- src/database/mod.rs | 11 --- src/service/admin/mod.rs | 21 ++-- src/service/appservice/mod.rs | 39 ++++++-- src/service/mod.rs | 5 +- src/service/rooms/timeline/mod.rs | 2 +- src/service/sending/mod.rs | 2 +- 11 files changed, 125 insertions(+), 133 deletions(-) diff --git a/docs/differences.md b/docs/differences.md index 306c542e..fb239b98 100644 --- a/docs/differences.md +++ b/docs/differences.md @@ -58,7 +58,6 @@ - Basic validation/checks on user-specified room aliases and custom room ID creations - Warn on unknown config options specified - Add support for preventing certain room alias names and usernames using regex (via upstream MR) and extended to custom room IDs -- Revamp appservice registration to ruma's `Registration` type which fixes various appservice registration issues, including fixing crashing upon no URL specified (via upstream MR) - URL preview support (via upstream MR) with various improvements - Increased graceful shutdown timeout from a low 60 seconds to 180 seconds to avoid killing connections and let the remaining ones finish processing, and ask systemd for more time to shutdown if needed to prevent systemd's default [`TimeoutStopSec=`](https://www.freedesktop.org/software/systemd/man/latest/systemd.service.html#TimeoutStopSec=) of 90 seconds from killing conduwuit - Bumped default max_concurrent_requests to 500 diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index 217acf32..6b3a4152 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -14,81 +14,76 @@ pub(crate) async fn send_request(registration: Registration, request: T) -> O where T: OutgoingRequest + Debug, { - if let Some(destination) = registration.url { - let hs_token = registration.hs_token.as_str(); + let destination = registration.url?; - let mut http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - }) - .unwrap() - .map(BytesMut::freeze); + let hs_token = registration.hs_token.as_str(); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut http_request = request + .try_into_http_request::(&destination, SendAccessToken::IfRequired(hs_token), &[MatrixVersion::V1_0]) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + }) + .unwrap() + .map(BytesMut::freeze); - parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap()); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - - let mut reqwest_request = - reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests"); - - *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); - - let url = reqwest_request.url().clone(); - let mut response = match services().globals.client.appservice.execute(reqwest_request).await { - Ok(r) => r, - Err(e) => { - warn!( - "Could not send request to appservice {} at {}: {}", - registration.id, destination, e - ); - return Some(Err(e.into())); - }, - }; - - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder().status(status).version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder.headers_mut().expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if !status.is_success() { - warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) - ); - } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder.body(body).expect("reqwest body is valid http body"), - ); - Some(response.map_err(|_| { - warn!("Appservice returned invalid response bytes {}\n{}", destination, url); - Error::BadServerResponse("Server returned bad response.") - })) + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" } else { - None + "?" + }; + + parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap()); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); + + let mut reqwest_request = + reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests"); + + *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); + + let url = reqwest_request.url().clone(); + let mut response = match services().globals.client.appservice.execute(reqwest_request).await { + Ok(r) => r, + Err(e) => { + warn!( + "Could not send request to appservice {} at {}: {}", + registration.id, destination, e + ); + return Some(Err(e.into())); + }, + }; + + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder.headers_mut().expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if !status.is_success() { + warn!( + "Appservice returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + utils::string_from_bytes(&body) + ); } + + let response = T::IncomingResponse::try_from_http_response( + http_response_builder.body(body).expect("reqwest body is valid http body"), + ); + + Some(response.map_err(|_| { + warn!("Appservice returned invalid response bytes {}\n{}", destination, url); + Error::BadServerResponse("Server returned bad response.") + })) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index dc37c3a7..477a5e73 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -115,7 +115,7 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result room_id = Some(r), None => { - for appservice in services().appservice.registration_info.read().await.values() { + for appservice in services().appservice.read().await.values() { if appservice.aliases.is_match(room_alias.as_str()) && if let Some(opt_result) = services() .sending diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 9040ae0d..0a571105 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -76,11 +76,13 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = services().appservice.all().unwrap(); - let appservice_registration = - appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token); + let appservice_registration = if let Some(token) = token { + services().appservice.find_from_token(token).await + } else { + None + }; - let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = + let (sender_user, sender_device, sender_servername, from_appservice) = if let Some(info) = appservice_registration { match metadata.authentication { @@ -88,7 +90,7 @@ where let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration.sender_localpart.as_str(), + info.registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() @@ -109,7 +111,7 @@ where let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration.sender_localpart.as_str(), + info.registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 598fb9cd..0379504f 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -7,7 +7,6 @@ impl service::appservice::Data for KeyValueDatabase { fn register_appservice(&self, yaml: Registration) -> Result { let id = yaml.id.as_str(); self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; - self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone()); Ok(id.to_owned()) } @@ -19,24 +18,17 @@ impl service::appservice::Data for KeyValueDatabase { /// * `service_name` - the name you send to register the service previously fn unregister_appservice(&self, service_name: &str) -> Result<()> { self.id_appserviceregistrations.remove(service_name.as_bytes())?; - self.cached_registrations.write().unwrap().remove(service_name); Ok(()) } fn get_registration(&self, id: &str) -> Result> { - self.cached_registrations.read().unwrap().get(id).map_or_else( - || { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes).map_err(|_| { - Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") - }) - }) - .transpose() - }, - |r| Ok(Some(r.clone())), - ) + self.id_appserviceregistrations + .get(id.as_bytes())? + .map(|bytes| { + serde_yaml::from_slice(&bytes) + .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")) + }) + .transpose() } fn iter_ids<'a>(&'a self) -> Result> + 'a>> { diff --git a/src/database/mod.rs b/src/database/mod.rs index 0822fb80..88641935 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -17,7 +17,6 @@ use itertools::Itertools; use lru_cache::LruCache; use rand::thread_rng; use ruma::{ - api::appservice::Registration, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, room::message::RoomMessageEventContent, @@ -179,7 +178,6 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, @@ -379,7 +377,6 @@ impl KeyValueDatabase { global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, - cached_registrations: Arc::new(RwLock::new(HashMap::new())), pdu_cache: Mutex::new(LruCache::new( config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"), )), @@ -992,14 +989,6 @@ impl KeyValueDatabase { ); } - // Inserting registrations into cache - for appservice in services().appservice.all()? { - services().appservice.registration_info.write().await.insert( - appservice.0, - appservice.1.try_into().expect("Should be validated on registration"), - ); - } - services().admin.start_handler(); // Set emergency access for the conduit user diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 1b30a33c..f760c0a2 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -623,8 +623,8 @@ impl Service { }, AppserviceCommand::Show { appservice_identifier, - } => match services().appservice.get_registration(&appservice_identifier) { - Ok(Some(config)) => { + } => match services().appservice.get_registration(&appservice_identifier).await { + Some(config) => { let config_str = serde_yaml::to_string(&config).expect("config should've been validated on register"); let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,); @@ -635,21 +635,12 @@ impl Service { ); RoomMessageEventContent::text_html(output, output_html) }, - Ok(None) => RoomMessageEventContent::text_plain("Appservice does not exist."), - Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."), + None => RoomMessageEventContent::text_plain("Appservice does not exist."), }, AppserviceCommand::List => { - if let Ok(appservices) = services().appservice.iter_ids().map(Iterator::collect::>) { - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices.into_iter().filter_map(std::result::Result::ok).collect::>().join(", ") - ); - RoomMessageEventContent::text_plain(output) - } else { - RoomMessageEventContent::text_plain("Failed to get appservices.") - } + let appservices = services().appservice.iter_ids().await; + let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", ")); + RoomMessageEventContent::text_plain(output) }, }, AdminCommand::Media(command) => { diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 5884e4a8..d0a37961 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,8 +1,9 @@ mod data; -use std::collections::HashMap; +use std::collections::BTreeMap; pub(crate) use data::Data; +use futures_util::Future; use regex::RegexSet; use ruma::api::appservice::{Namespace, Registration}; use tokio::sync::RwLock; @@ -10,6 +11,7 @@ use tokio::sync::RwLock; use crate::{services, Result}; /// Compiled regular expressions for a namespace +#[derive(Clone, Debug)] pub struct NamespaceRegex { pub exclusive: Option, pub non_exclusive: Option, @@ -71,7 +73,8 @@ impl TryFrom> for NamespaceRegex { } } -/// Compiled regular expressions for an appservice +/// Appservice registration combined with its compiled regular expressions. +#[derive(Clone, Debug)] pub struct RegistrationInfo { pub registration: Registration, pub users: NamespaceRegex, @@ -94,10 +97,26 @@ impl TryFrom for RegistrationInfo { pub struct Service { pub db: &'static dyn Data, - pub registration_info: RwLock>, + registration_info: RwLock>, } impl Service { + pub fn build(db: &'static dyn Data) -> Result { + let mut registration_info = BTreeMap::new(); + // Inserting registrations into cache + for appservice in db.all()? { + registration_info.insert( + appservice.0, + appservice.1.try_into().expect("Should be validated on registration"), + ); + } + + Ok(Self { + db, + registration_info: RwLock::new(registration_info), + }) + } + /// Registers an appservice and returns the ID to the caller pub async fn register_appservice(&self, yaml: Registration) -> Result { services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?); @@ -116,9 +135,17 @@ impl Service { self.db.unregister_appservice(service_name) } - pub fn get_registration(&self, id: &str) -> Result> { self.db.get_registration(id) } + pub async fn get_registration(&self, id: &str) -> Option { + self.registration_info.read().await.get(id).cloned().map(|info| info.registration) + } - pub fn iter_ids(&self) -> Result> + '_> { self.db.iter_ids() } + pub async fn iter_ids(&self) -> Vec { self.registration_info.read().await.keys().cloned().collect() } - pub fn all(&self) -> Result> { self.db.all() } + pub async fn find_from_token(&self, token: &str) -> Option { + self.read().await.values().find(|info| info.registration.as_token == token).cloned() + } + + pub fn read(&self) -> impl Future>> { + self.registration_info.read() + } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 1667dc81..8a6afe50 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -55,10 +55,7 @@ impl Services<'_> { db: &'static D, config: Config, ) -> Result { Ok(Self { - appservice: appservice::Service { - db, - registration_info: RwLock::new(HashMap::new()), - }, + appservice: appservice::Service::build(db)?, pusher: pusher::Service { db, }, diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index e5342e09..b191076c 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -510,7 +510,7 @@ impl Service { } } - for appservice in services().appservice.registration_info.read().await.values() { + for appservice in services().appservice.read().await.values() { if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? { services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; continue; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index fc5ce586..1c9d4564 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -502,7 +502,7 @@ impl Service { let permit = services().sending.maximum_requests.acquire().await; let response = match appservice_server::send_request( - services().appservice.get_registration(id).map_err(|e| (kind.clone(), e))?.ok_or_else(|| { + services().appservice.get_registration(id).await.ok_or_else(|| { ( kind.clone(), Error::bad_database("[Appservice] Could not load registration from db."),