refactor: minor appservice code cleanup

This commit is contained in:
Timo Kösters 2024-03-22 08:52:39 +01:00
parent fa930182ae
commit 0bb28f60cf
No known key found for this signature in database
GPG key ID: 0B25E636FBA7E4CB
8 changed files with 136 additions and 163 deletions

View file

@ -17,95 +17,96 @@ pub(crate) async fn send_request<T: OutgoingRequest>(
where where
T: Debug, T: Debug,
{ {
if let Some(destination) = registration.url { let Some(destination) = registration.url else {
let hs_token = registration.hs_token.as_str(); return None;
};
let mut http_request = request let hs_token = registration.hs_token.as_str();
.try_into_http_request::<BytesMut>(
&destination,
SendAccessToken::IfRequired(hs_token),
&[MatrixVersion::V1_0],
)
.unwrap()
.map(|body| body.freeze());
let mut parts = http_request.uri().clone().into_parts(); let mut http_request = request
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); .try_into_http_request::<BytesMut>(
let symbol = if old_path_and_query.contains('?') { &destination,
"&" SendAccessToken::IfRequired(hs_token),
} else { &[MatrixVersion::V1_0],
"?" )
}; .unwrap()
.map(|body| body.freeze());
parts.path_and_query = Some( let mut parts = http_request.uri().clone().into_parts();
(old_path_and_query + symbol + "access_token=" + hs_token) let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
.parse() let symbol = if old_path_and_query.contains('?') {
.unwrap(), "&"
);
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
let mut reqwest_request = reqwest::Request::try_from(http_request)
.expect("all http requests are valid reqwest requests");
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
let url = reqwest_request.url().clone();
let mut response = match services()
.globals
.default_client()
.execute(reqwest_request)
.await
{
Ok(r) => r,
Err(e) => {
warn!(
"Could not send request to appservice {:?} at {}: {}",
registration.id, destination, e
);
return Some(Err(e.into()));
}
};
// reqwest::Response -> http::Response conversion
let status = response.status();
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error: {}", e);
Vec::new().into()
}); // TODO: handle timeout
if status != 200 {
warn!(
"Appservice returned bad response {} {}\n{}\n{:?}",
destination,
status,
url,
utils::string_from_bytes(&body)
);
}
let response = T::IncomingResponse::try_from_http_response(
http_response_builder
.body(body)
.expect("reqwest body is valid http body"),
);
Some(response.map_err(|_| {
warn!(
"Appservice returned invalid response bytes {}\n{}",
destination, url
);
Error::BadServerResponse("Server returned bad response.")
}))
} else { } else {
None "?"
};
parts.path_and_query = Some(
(old_path_and_query + symbol + "access_token=" + hs_token)
.parse()
.unwrap(),
);
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
let mut reqwest_request = reqwest::Request::try_from(http_request)
.expect("all http requests are valid reqwest requests");
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
let url = reqwest_request.url().clone();
let mut response = match services()
.globals
.default_client()
.execute(reqwest_request)
.await
{
Ok(r) => r,
Err(e) => {
warn!(
"Could not send request to appservice {:?} at {}: {}",
registration.id, destination, e
);
return Some(Err(e.into()));
}
};
// reqwest::Response -> http::Response conversion
let status = response.status();
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error: {}", e);
Vec::new().into()
}); // TODO: handle timeout
if status != 200 {
warn!(
"Appservice returned bad response {} {}\n{}\n{:?}",
destination,
status,
url,
utils::string_from_bytes(&body)
);
} }
let response = T::IncomingResponse::try_from_http_response(
http_response_builder
.body(body)
.expect("reqwest body is valid http body"),
);
Some(response.map_err(|_| {
warn!(
"Appservice returned invalid response bytes {}\n{}",
destination, url
);
Error::BadServerResponse("Server returned bad response.")
}))
} }

View file

@ -100,13 +100,7 @@ pub(crate) async fn get_alias_helper(
match services().rooms.alias.resolve_local_alias(&room_alias)? { match services().rooms.alias.resolve_local_alias(&room_alias)? {
Some(r) => room_id = Some(r), Some(r) => room_id = Some(r),
None => { None => {
for appservice in services() for appservice in services().appservice.all().await {
.appservice
.registration_info
.read()
.await
.values()
{
if appservice.aliases.is_match(room_alias.as_str()) if appservice.aliases.is_match(room_alias.as_str())
&& if let Some(opt_result) = services() && if let Some(opt_result) = services()
.sending .sending

View file

@ -80,19 +80,19 @@ where
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = services().appservice.all().unwrap(); let appservices = services().appservice.all().await;
let appservice_registration = appservices let appservice_registration = appservices
.iter() .iter()
.find(|(_id, registration)| Some(registration.as_token.as_str()) == token); .find(|info| Some(info.registration.as_token.as_str()) == token);
let (sender_user, sender_device, sender_servername, from_appservice) = let (sender_user, sender_device, sender_servername, from_appservice) =
if let Some((_id, registration)) = appservice_registration { if let Some(info) = appservice_registration {
match metadata.authentication { match metadata.authentication {
AuthScheme::AccessToken => { AuthScheme::AccessToken => {
let user_id = query_params.user_id.map_or_else( let user_id = query_params.user_id.map_or_else(
|| { || {
UserId::parse_with_server_name( UserId::parse_with_server_name(
registration.sender_localpart.as_str(), info.registration.sender_localpart.as_str(),
services().globals.server_name(), services().globals.server_name(),
) )
.unwrap() .unwrap()

View file

@ -10,10 +10,6 @@ impl service::appservice::Data for KeyValueDatabase {
id.as_bytes(), id.as_bytes(),
serde_yaml::to_string(&yaml).unwrap().as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes(),
)?; )?;
self.cached_registrations
.write()
.unwrap()
.insert(id.to_owned(), yaml.to_owned());
Ok(id.to_owned()) Ok(id.to_owned())
} }
@ -26,33 +22,18 @@ impl service::appservice::Data for KeyValueDatabase {
fn unregister_appservice(&self, service_name: &str) -> Result<()> { fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations self.id_appserviceregistrations
.remove(service_name.as_bytes())?; .remove(service_name.as_bytes())?;
self.cached_registrations
.write()
.unwrap()
.remove(service_name);
Ok(()) Ok(())
} }
fn get_registration(&self, id: &str) -> Result<Option<Registration>> { fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.cached_registrations self.id_appserviceregistrations
.read() .get(id.as_bytes())?
.unwrap() .map(|bytes| {
.get(id) serde_yaml::from_slice(&bytes).map_err(|_| {
.map_or_else( Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|| { })
self.id_appserviceregistrations })
.get(id.as_bytes())? .transpose()
.map(|bytes| {
serde_yaml::from_slice(&bytes).map_err(|_| {
Error::bad_database(
"Invalid registration bytes in id_appserviceregistrations.",
)
})
})
.transpose()
},
|r| Ok(Some(r.clone())),
)
} }
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {

View file

@ -10,7 +10,6 @@ use directories::ProjectDirs;
use lru_cache::LruCache; use lru_cache::LruCache;
use ruma::{ use ruma::{
api::appservice::Registration,
events::{ events::{
push_rules::{PushRulesEvent, PushRulesEventContent}, push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent, room::message::RoomMessageEventContent,
@ -164,7 +163,6 @@ pub struct KeyValueDatabase {
//pub pusher: pusher::PushData, //pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>, pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) cached_registrations: Arc<RwLock<HashMap<String, Registration>>>,
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>, pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
@ -374,7 +372,6 @@ impl KeyValueDatabase {
global: builder.open_tree("global")?, global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
pdu_cache: Mutex::new(LruCache::new( pdu_cache: Mutex::new(LruCache::new(
config config
.pdu_cache_capacity .pdu_cache_capacity
@ -969,22 +966,6 @@ impl KeyValueDatabase {
); );
} }
// Inserting registrations into cache
for appservice in services().appservice.all()? {
services()
.appservice
.registration_info
.write()
.await
.insert(
appservice.0,
appservice
.1
.try_into()
.expect("Should be validated on registration"),
);
}
// This data is probably outdated // This data is probably outdated
db.presenceid_presence.clear()?; db.presenceid_presence.clear()?;

View file

@ -10,7 +10,8 @@ use tokio::sync::RwLock;
use crate::{services, Result}; use crate::{services, Result};
/// Compiled regular expressions for a namespace /// Compiled regular expressions for a namespace.
#[derive(Clone, Debug)]
pub struct NamespaceRegex { pub struct NamespaceRegex {
pub exclusive: Option<RegexSet>, pub exclusive: Option<RegexSet>,
pub non_exclusive: Option<RegexSet>, pub non_exclusive: Option<RegexSet>,
@ -72,7 +73,8 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
type Error = regex::Error; type Error = regex::Error;
} }
/// Compiled regular expressions for an appservice /// Appservice registration combined with its compiled regular expressions.
#[derive(Clone, Debug)]
pub struct RegistrationInfo { pub struct RegistrationInfo {
pub registration: Registration, pub registration: Registration,
pub users: NamespaceRegex, pub users: NamespaceRegex,
@ -95,11 +97,29 @@ impl TryFrom<Registration> for RegistrationInfo {
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub registration_info: RwLock<HashMap<String, RegistrationInfo>>, registration_info: RwLock<HashMap<String, RegistrationInfo>>,
} }
impl Service { impl Service {
/// Registers an appservice and returns the ID to the caller pub fn build(db: &'static dyn Data) -> Result<Self> {
let mut registration_info = HashMap::new();
// Inserting registrations into cache
for appservice in db.all()? {
registration_info.insert(
appservice.0,
appservice
.1
.try_into()
.expect("Should be validated on registration"),
);
}
Ok(Self {
db,
registration_info: RwLock::new(registration_info),
})
}
/// Registers an appservice and returns the ID to the caller.
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
services() services()
.appservice .appservice
@ -111,7 +131,7 @@ impl Service {
self.db.register_appservice(yaml) self.db.register_appservice(yaml)
} }
/// Remove an appservice registration /// Removes an appservice registration.
/// ///
/// # Arguments /// # Arguments
/// ///
@ -135,7 +155,12 @@ impl Service {
self.db.iter_ids() self.db.iter_ids()
} }
pub fn all(&self) -> Result<Vec<(String, Registration)>> { pub async fn all(&self) -> Vec<RegistrationInfo> {
self.db.all() self.registration_info
.read()
.await
.values()
.cloned()
.collect()
} }
} }

View file

@ -4,7 +4,7 @@ use std::{
}; };
use lru_cache::LruCache; use lru_cache::LruCache;
use tokio::sync::{Mutex, RwLock}; use tokio::sync::Mutex;
use crate::{Config, Result}; use crate::{Config, Result};
@ -56,10 +56,7 @@ impl Services {
config: Config, config: Config,
) -> Result<Self> { ) -> Result<Self> {
Ok(Self { Ok(Self {
appservice: appservice::Service { appservice: appservice::Service::build(db)?,
db,
registration_info: RwLock::new(HashMap::new()),
},
pusher: pusher::Service { db }, pusher: pusher::Service { db },
rooms: rooms::Service { rooms: rooms::Service {
alias: rooms::alias::Service { db }, alias: rooms::alias::Service { db },

View file

@ -524,17 +524,11 @@ impl Service {
} }
} }
for appservice in services() for appservice in services().appservice.all().await {
.appservice
.registration_info
.read()
.await
.values()
{
if services() if services()
.rooms .rooms
.state_cache .state_cache
.appservice_in_room(&pdu.room_id, appservice)? .appservice_in_room(&pdu.room_id, &appservice)?
{ {
services() services()
.sending .sending