From 784d3074258248ebf9550e6972a3ab5cdbcc3b0c Mon Sep 17 00:00:00 2001 From: Matthias Ahouansou Date: Wed, 7 Feb 2024 17:48:01 -0500 Subject: [PATCH] revamp appservice registration to ruma's Registration type squashed from https://gitlab.com/famedly/conduit/-/merge_requests/583 Signed-off-by: strawberry --- src/api/appservice_server.rs | 173 ++++++++++---------- src/api/client_server/alias.rs | 22 +-- src/api/ruma_wrapper/axum.rs | 15 +- src/database/key_value/appservice.rs | 11 +- src/database/key_value/rooms/state_cache.rs | 32 ++-- src/database/mod.rs | 3 +- src/service/admin/mod.rs | 4 +- src/service/appservice/data.rs | 8 +- src/service/appservice/mod.rs | 7 +- src/service/rooms/state_cache/data.rs | 3 +- src/service/rooms/state_cache/mod.rs | 3 +- src/service/rooms/timeline/mod.rs | 110 ++++++------- src/service/sending/mod.rs | 20 ++- 13 files changed, 203 insertions(+), 208 deletions(-) diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index e296ef26..2bc430d9 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,107 +1,114 @@ use crate::{services, utils, Error, Result}; use bytes::BytesMut; -use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; +use ruma::api::{ + appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, +}; use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; -#[tracing::instrument(skip(request))] +/// Sends a request to an appservice +/// +/// Only returns None if there is no url specified in the appservice registration file pub(crate) async fn send_request( - registration: serde_yaml::Value, + registration: Registration, request: T, -) -> Result +) -> Option> where T: Debug, { - let destination = registration.get("url").unwrap().as_str().unwrap(); - let hs_token = registration.get("hs_token").unwrap().as_str().unwrap(); + if let Some(destination) = registration.url { + let hs_token = registration.hs_token.as_str(); - let mut http_request = request - .try_into_http_request::( - destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(|body| body.freeze()); + let mut http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(hs_token), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + }) + .unwrap() + .map(|body| body.freeze()); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" + } else { + "?" + }; - parts.path_and_query = Some( - (old_path_and_query + symbol + "access_token=" + hs_token) - .parse() - .unwrap(), - ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); + parts.path_and_query = Some( + (old_path_and_query + symbol + "access_token=" + hs_token) + .parse() + .unwrap(), + ); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - let mut reqwest_request = reqwest::Request::try_from(http_request)?; + let mut reqwest_request = reqwest::Request::try_from(http_request) + .expect("all http requests are valid reqwest requests"); - *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); + *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); - let url = reqwest_request.url().clone(); - let mut response = match services() - .globals - .default_client() - .execute(reqwest_request) - .await - { - Ok(r) => r, - Err(e) => { + let url = reqwest_request.url().clone(); + let mut response = match services() + .globals + .default_client() + .execute(reqwest_request) + .await + { + Ok(r) => r, + Err(e) => { + warn!( + "Could not send request to appservice {} at {}: {}", + registration.id, destination, e + ); + return Some(Err(e.into())); + } + }; + + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder() + .status(status) + .version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder + .headers_mut() + .expect("http::response::Builder is usable"), + ); + + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout + + if !status.is_success() { warn!( - "Could not send request to appservice {:?} at {}: {}", - registration.get("id"), + "Appservice returned bad response {} {}\n{}\n{:?}", destination, - e + status, + url, + utils::string_from_bytes(&body) ); - return Err(e.into()); } - }; - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); - - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout - - if !status.is_success() { - warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) + let response = T::IncomingResponse::try_from_http_response( + http_response_builder + .body(body) + .expect("reqwest body is valid http body"), ); + Some(response.map_err(|_| { + warn!( + "Appservice returned invalid response bytes {}\n{}", + destination, url + ); + Error::BadServerResponse("Server returned bad response.") + })) + } else { + None } - - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - warn!( - "Appservice returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Server returned bad response.") - }) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index b0f5a289..bf903445 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -157,20 +157,16 @@ pub(crate) async fn get_alias_helper( None => { for (_id, registration) in services().appservice.all()? { let aliases = registration - .get("namespaces") - .and_then(|ns| ns.get("aliases")) - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); + .namespaces + .aliases + .iter() + .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) + .collect::>(); if aliases .iter() .any(|aliases| aliases.is_match(room_alias.as_str())) - && services() + && if let Some(opt_result) = services() .sending .send_appservice_request( registration, @@ -179,7 +175,11 @@ pub(crate) async fn get_alias_helper( }, ) .await - .is_ok() + { + opt_result.is_ok() + } else { + false + } { room_id = Some( services() diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 39cabaf7..fda13531 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -81,12 +81,9 @@ where let mut json_body = serde_json::from_slice::(&body).ok(); let appservices = services().appservice.all().unwrap(); - let appservice_registration = appservices.iter().find(|(_id, registration)| { - registration - .get("as_token") - .and_then(|as_token| as_token.as_str()) - .map_or(false, |as_token| token == Some(as_token)) - }); + let appservice_registration = appservices + .iter() + .find(|(_id, registration)| Some(registration.as_token.as_str()) == token); let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = appservice_registration { @@ -95,11 +92,7 @@ where let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( - registration - .get("sender_localpart") - .unwrap() - .as_str() - .unwrap(), + registration.sender_localpart.as_str(), services().globals.server_name(), ) .unwrap() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 9a821a65..3243183d 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,10 +1,11 @@ +use ruma::api::appservice::Registration; + use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { - // TODO: Rumaify - let id = yaml.get("id").unwrap().as_str().unwrap(); + fn register_appservice(&self, yaml: Registration) -> Result { + let id = yaml.id.as_str(); self.id_appserviceregistrations.insert( id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes(), @@ -32,7 +33,7 @@ impl service::appservice::Data for KeyValueDatabase { Ok(()) } - fn get_registration(&self, id: &str) -> Result> { + fn get_registration(&self, id: &str) -> Result> { self.cached_registrations .read() .unwrap() @@ -64,7 +65,7 @@ impl service::appservice::Data for KeyValueDatabase { ))) } - fn all(&self) -> Result> { + fn all(&self) -> Result> { self.iter_ids()? .filter_map(|id| id.ok()) .map(move |id| { diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 7894b0f3..b23db30c 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -2,6 +2,7 @@ use std::{collections::HashSet, sync::Arc}; use regex::Regex; use ruma::{ + api::appservice::Registration, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -193,7 +194,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { fn appservice_in_room( &self, room_id: &RoomId, - appservice: &(String, serde_yaml::Value), + appservice: &(String, Registration), ) -> Result { let maybe = self .appservice_in_room_cache @@ -205,24 +206,19 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { if let Some(b) = maybe { Ok(b) - } else if let Some(namespaces) = appservice.1.get("namespaces") { + } else { + let namespaces = &appservice.1.namespaces; let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); + .users + .iter() + .filter_map(|users| Regex::new(users.regex.as_str()).ok()) + .collect::>(); - let bridge_user_id = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, services().globals.server_name()).ok() - }); + let bridge_user_id = UserId::parse_with_server_name( + appservice.1.sender_localpart.as_str(), + services().globals.server_name(), + ) + .ok(); let in_room = bridge_user_id .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) @@ -240,8 +236,6 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { .insert(appservice.0.clone(), in_room); Ok(in_room) - } else { - Ok(false) } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 739f5b20..32921b13 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -11,6 +11,7 @@ use directories::ProjectDirs; use lru_cache::LruCache; use rand::thread_rng; use ruma::{ + api::appservice::Registration, events::{ push_rules::{PushRulesEvent, PushRulesEventContent}, room::message::RoomMessageEventContent, @@ -163,7 +164,7 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, + pub(super) cached_registrations: Arc>>, pub(super) pdu_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>>, pub(super) auth_chain_cache: Mutex, Arc>>>, diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 524099d6..c066c42c 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -10,7 +10,7 @@ use std::fmt::Write; use clap::{Parser, Subcommand}; use regex::Regex; use ruma::{ - api::client::error::ErrorKind, + api::{appservice::Registration, client::error::ErrorKind}, events::{ relation::InReplyTo, room::{ @@ -472,7 +472,7 @@ impl Service { { let appservice_config = body[1..body.len() - 1].join("\n"); let parsed_config = - serde_yaml::from_str::(&appservice_config); + serde_yaml::from_str::(&appservice_config); match parsed_config { Ok(yaml) => match services().appservice.register_appservice(yaml) { Ok(id) => RoomMessageEventContent::text_plain(format!( diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index 744f0f94..ab19a50c 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -1,8 +1,10 @@ +use ruma::api::appservice::Registration; + use crate::Result; pub trait Data: Send + Sync { /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: serde_yaml::Value) -> Result; + fn register_appservice(&self, yaml: Registration) -> Result; /// Remove an appservice registration /// @@ -11,9 +13,9 @@ pub trait Data: Send + Sync { /// * `service_name` - the name you send to register the service previously fn unregister_appservice(&self, service_name: &str) -> Result<()>; - fn get_registration(&self, id: &str) -> Result>; + fn get_registration(&self, id: &str) -> Result>; fn iter_ids<'a>(&'a self) -> Result> + 'a>>; - fn all(&self) -> Result>; + fn all(&self) -> Result>; } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 7789968c..ed6f37bf 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,6 +1,7 @@ mod data; pub(crate) use data::Data; +use ruma::api::appservice::Registration; use crate::Result; @@ -10,7 +11,7 @@ pub struct Service { impl Service { /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result { + pub fn register_appservice(&self, yaml: Registration) -> Result { self.db.register_appservice(yaml) } @@ -23,7 +24,7 @@ impl Service { self.db.unregister_appservice(service_name) } - pub fn get_registration(&self, id: &str) -> Result> { + pub fn get_registration(&self, id: &str) -> Result> { self.db.get_registration(id) } @@ -31,7 +32,7 @@ impl Service { self.db.iter_ids() } - pub fn all(&self) -> Result> { + pub fn all(&self) -> Result> { self.db.all() } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index ed868529..e36ac03e 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -2,6 +2,7 @@ use std::{collections::HashSet, sync::Arc}; use crate::Result; use ruma::{ + api::appservice::Registration, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -31,7 +32,7 @@ pub trait Data: Send + Sync { fn appservice_in_room( &self, room_id: &RoomId, - appservice: &(String, serde_yaml::Value), + appservice: &(String, Registration), ) -> Result; /// Makes a user forget a room. diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 546767a9..6498b86e 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -4,6 +4,7 @@ use std::{collections::HashSet, sync::Arc}; pub use data::Data; use ruma::{ + api::appservice::Registration, events::{ direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, @@ -240,7 +241,7 @@ impl Service { pub fn appservice_in_room( &self, room_id: &RoomId, - appservice: &(String, serde_yaml::Value), + appservice: &(String, Registration), ) -> Result { self.db.appservice_in_room(room_id, appservice) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index be65815a..96e95d7f 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -621,73 +621,61 @@ impl Service { .as_ref() .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) { - if let Some(appservice_uid) = appservice - .1 - .get("sender_localpart") - .and_then(|string| string.as_str()) - .and_then(|string| { - UserId::parse_with_server_name(string, services().globals.server_name()) - .ok() - }) - { - if state_key_uid == &appservice_uid { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - continue; - } + let appservice_uid = appservice.1.sender_localpart.as_str(); + if state_key_uid == appservice_uid { + services() + .sending + .send_pdu_appservice(appservice.0, pdu_id.clone())?; + continue; } } } - if let Some(namespaces) = appservice.1.get("namespaces") { - let users = namespaces - .get("users") - .and_then(|users| users.as_sequence()) - .map_or_else(Vec::new, |users| { - users - .iter() - .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let aliases = namespaces - .get("aliases") - .and_then(|aliases| aliases.as_sequence()) - .map_or_else(Vec::new, |aliases| { - aliases - .iter() - .filter_map(|aliases| Regex::new(aliases.get("regex")?.as_str()?).ok()) - .collect::>() - }); - let rooms = namespaces - .get("rooms") - .and_then(|rooms| rooms.as_sequence()); + let namespaces = appservice.1.namespaces; - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == TimelineEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - services() - .rooms - .alias - .local_aliases_for_room(&pdu.room_id) - .filter_map(|r| r.ok()) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; + // TODO: create some helper function to change from Strings to Regexes + let users = namespaces + .users + .iter() + .filter_map(|user| Regex::new(user.regex.as_str()).ok()) + .collect::>(); + let aliases = namespaces + .aliases + .iter() + .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) + .collect::>(); + let rooms = namespaces + .rooms + .iter() + .filter_map(|room| Regex::new(room.regex.as_str()).ok()) + .collect::>(); - if aliases.iter().any(matching_aliases) - || rooms.map_or(false, |rooms| rooms.contains(&pdu.room_id.as_str().into())) - || users.iter().any(matching_users) - { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - } + let matching_users = |users: &Regex| { + users.is_match(pdu.sender.as_str()) + || pdu.kind == TimelineEventType::RoomMember + && pdu + .state_key + .as_ref() + .map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &Regex| { + services() + .rooms + .alias + .local_aliases_for_room(&pdu.room_id) + .filter_map(|r| r.ok()) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if aliases.iter().any(matching_aliases) + || rooms + .iter() + .any(|namespace| namespace.is_match(pdu.room_id.as_str())) + || users.iter().any(matching_users) + { + services() + .sending + .send_pdu_appservice(appservice.0, pdu_id.clone())?; } } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 785c50fe..7f3ba124 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -23,7 +23,7 @@ use base64::{engine::general_purpose, Engine as _}; use ruma::{ api::{ - appservice, + appservice::{self, Registration}, federation::{ self, transactions::edu::{ @@ -520,7 +520,7 @@ impl Service { let permit = services().sending.maximum_requests.acquire().await; - let response = appservice_server::send_request( + let response = match appservice_server::send_request( services() .appservice .get_registration(id) @@ -547,8 +547,12 @@ impl Service { }, ) .await - .map(|_response| kind.clone()) - .map_err(|e| (kind, e)); + { + None => Ok(kind.clone()), + Some(op_resp) => op_resp + .map(|_response| kind.clone()) + .map_err(|e| (kind.clone(), e)), + }; drop(permit); @@ -764,12 +768,14 @@ impl Service { response } - #[tracing::instrument(skip(self, registration, request))] + /// Sends a request to an appservice + /// + /// Only returns None if there is no url specified in the appservice registration file pub async fn send_appservice_request( &self, - registration: serde_yaml::Value, + registration: Registration, request: T, - ) -> Result + ) -> Option> where T: Debug, {