refactor appservice type stuff
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
7c9c5b1d78
commit
60f2471f59
11 changed files with 125 additions and 133 deletions
|
@ -58,7 +58,6 @@
|
|||
- Basic validation/checks on user-specified room aliases and custom room ID creations
|
||||
- Warn on unknown config options specified
|
||||
- Add support for preventing certain room alias names and usernames using regex (via upstream MR) and extended to custom room IDs
|
||||
- Revamp appservice registration to ruma's `Registration` type which fixes various appservice registration issues, including fixing crashing upon no URL specified (via upstream MR)
|
||||
- URL preview support (via upstream MR) with various improvements
|
||||
- Increased graceful shutdown timeout from a low 60 seconds to 180 seconds to avoid killing connections and let the remaining ones finish processing, and ask systemd for more time to shutdown if needed to prevent systemd's default [`TimeoutStopSec=`](https://www.freedesktop.org/software/systemd/man/latest/systemd.service.html#TimeoutStopSec=) of 90 seconds from killing conduwuit
|
||||
- Bumped default max_concurrent_requests to 500
|
||||
|
|
|
@ -14,81 +14,76 @@ pub(crate) async fn send_request<T>(registration: Registration, request: T) -> O
|
|||
where
|
||||
T: OutgoingRequest + Debug,
|
||||
{
|
||||
if let Some(destination) = registration.url {
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
let destination = registration.url?;
|
||||
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(hs_token),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(BytesMut::freeze);
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(hs_token), &[MatrixVersion::V1_0])
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(BytesMut::freeze);
|
||||
|
||||
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
|
||||
let mut reqwest_request =
|
||||
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services().globals.client.appservice.execute(reqwest_request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
},
|
||||
};
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
Some(response.map_err(|_| {
|
||||
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
None
|
||||
"?"
|
||||
};
|
||||
|
||||
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
|
||||
let mut reqwest_request =
|
||||
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services().globals.client.appservice.execute(reqwest_request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
},
|
||||
};
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
|
||||
Some(response.map_err(|_| {
|
||||
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get
|
|||
match services().rooms.alias.resolve_local_alias(&room_alias)? {
|
||||
Some(r) => room_id = Some(r),
|
||||
None => {
|
||||
for appservice in services().appservice.registration_info.read().await.values() {
|
||||
for appservice in services().appservice.read().await.values() {
|
||||
if appservice.aliases.is_match(room_alias.as_str())
|
||||
&& if let Some(opt_result) = services()
|
||||
.sending
|
||||
|
|
|
@ -76,11 +76,13 @@ where
|
|||
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration =
|
||||
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
let appservice_registration = if let Some(token) = token {
|
||||
services().appservice.find_from_token(token).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some(info) =
|
||||
appservice_registration
|
||||
{
|
||||
match metadata.authentication {
|
||||
|
@ -88,7 +90,7 @@ where
|
|||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
info.registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
|
@ -109,7 +111,7 @@ where
|
|||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
info.registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
|
|
|
@ -7,7 +7,6 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
@ -19,24 +18,17 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
|
|
|
@ -17,7 +17,6 @@ use itertools::Itertools;
|
|||
use lru_cache::LruCache;
|
||||
use rand::thread_rng;
|
||||
use ruma::{
|
||||
api::appservice::Registration,
|
||||
events::{
|
||||
push_rules::{PushRulesEvent, PushRulesEventContent},
|
||||
room::message::RoomMessageEventContent,
|
||||
|
@ -179,7 +178,6 @@ pub struct KeyValueDatabase {
|
|||
//pub pusher: pusher::PushData,
|
||||
pub(super) senderkey_pusher: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, Registration>>>,
|
||||
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
|
||||
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
|
||||
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
|
||||
|
@ -379,7 +377,6 @@ impl KeyValueDatabase {
|
|||
global: builder.open_tree("global")?,
|
||||
server_signingkeys: builder.open_tree("server_signingkeys")?,
|
||||
|
||||
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
||||
pdu_cache: Mutex::new(LruCache::new(
|
||||
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
|
||||
)),
|
||||
|
@ -992,14 +989,6 @@ impl KeyValueDatabase {
|
|||
);
|
||||
}
|
||||
|
||||
// Inserting registrations into cache
|
||||
for appservice in services().appservice.all()? {
|
||||
services().appservice.registration_info.write().await.insert(
|
||||
appservice.0,
|
||||
appservice.1.try_into().expect("Should be validated on registration"),
|
||||
);
|
||||
}
|
||||
|
||||
services().admin.start_handler();
|
||||
|
||||
// Set emergency access for the conduit user
|
||||
|
|
|
@ -623,8 +623,8 @@ impl Service {
|
|||
},
|
||||
AppserviceCommand::Show {
|
||||
appservice_identifier,
|
||||
} => match services().appservice.get_registration(&appservice_identifier) {
|
||||
Ok(Some(config)) => {
|
||||
} => match services().appservice.get_registration(&appservice_identifier).await {
|
||||
Some(config) => {
|
||||
let config_str =
|
||||
serde_yaml::to_string(&config).expect("config should've been validated on register");
|
||||
let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,);
|
||||
|
@ -635,21 +635,12 @@ impl Service {
|
|||
);
|
||||
RoomMessageEventContent::text_html(output, output_html)
|
||||
},
|
||||
Ok(None) => RoomMessageEventContent::text_plain("Appservice does not exist."),
|
||||
Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."),
|
||||
None => RoomMessageEventContent::text_plain("Appservice does not exist."),
|
||||
},
|
||||
AppserviceCommand::List => {
|
||||
if let Ok(appservices) = services().appservice.iter_ids().map(Iterator::collect::<Vec<_>>) {
|
||||
let count = appservices.len();
|
||||
let output = format!(
|
||||
"Appservices ({}): {}",
|
||||
count,
|
||||
appservices.into_iter().filter_map(std::result::Result::ok).collect::<Vec<_>>().join(", ")
|
||||
);
|
||||
RoomMessageEventContent::text_plain(output)
|
||||
} else {
|
||||
RoomMessageEventContent::text_plain("Failed to get appservices.")
|
||||
}
|
||||
let appservices = services().appservice.iter_ids().await;
|
||||
let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", "));
|
||||
RoomMessageEventContent::text_plain(output)
|
||||
},
|
||||
},
|
||||
AdminCommand::Media(command) => {
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
mod data;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use futures_util::Future;
|
||||
use regex::RegexSet;
|
||||
use ruma::api::appservice::{Namespace, Registration};
|
||||
use tokio::sync::RwLock;
|
||||
|
@ -10,6 +11,7 @@ use tokio::sync::RwLock;
|
|||
use crate::{services, Result};
|
||||
|
||||
/// Compiled regular expressions for a namespace
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NamespaceRegex {
|
||||
pub exclusive: Option<RegexSet>,
|
||||
pub non_exclusive: Option<RegexSet>,
|
||||
|
@ -71,7 +73,8 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
|
|||
}
|
||||
}
|
||||
|
||||
/// Compiled regular expressions for an appservice
|
||||
/// Appservice registration combined with its compiled regular expressions.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RegistrationInfo {
|
||||
pub registration: Registration,
|
||||
pub users: NamespaceRegex,
|
||||
|
@ -94,10 +97,26 @@ impl TryFrom<Registration> for RegistrationInfo {
|
|||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub registration_info: RwLock<HashMap<String, RegistrationInfo>>,
|
||||
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn build(db: &'static dyn Data) -> Result<Self> {
|
||||
let mut registration_info = BTreeMap::new();
|
||||
// Inserting registrations into cache
|
||||
for appservice in db.all()? {
|
||||
registration_info.insert(
|
||||
appservice.0,
|
||||
appservice.1.try_into().expect("Should be validated on registration"),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
db,
|
||||
registration_info: RwLock::new(registration_info),
|
||||
})
|
||||
}
|
||||
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?);
|
||||
|
@ -116,9 +135,17 @@ impl Service {
|
|||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
|
||||
pub async fn get_registration(&self, id: &str) -> Option<Registration> {
|
||||
self.registration_info.read().await.get(id).cloned().map(|info| info.registration)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
|
||||
pub async fn iter_ids(&self) -> Vec<String> { self.registration_info.read().await.keys().cloned().collect() }
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
|
||||
pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
|
||||
self.read().await.values().find(|info| info.registration.as_token == token).cloned()
|
||||
}
|
||||
|
||||
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
|
||||
self.registration_info.read()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,10 +55,7 @@ impl Services<'_> {
|
|||
db: &'static D, config: Config,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
appservice: appservice::Service {
|
||||
db,
|
||||
registration_info: RwLock::new(HashMap::new()),
|
||||
},
|
||||
appservice: appservice::Service::build(db)?,
|
||||
pusher: pusher::Service {
|
||||
db,
|
||||
},
|
||||
|
|
|
@ -510,7 +510,7 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
for appservice in services().appservice.registration_info.read().await.values() {
|
||||
for appservice in services().appservice.read().await.values() {
|
||||
if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? {
|
||||
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
|
||||
continue;
|
||||
|
|
|
@ -502,7 +502,7 @@ impl Service {
|
|||
let permit = services().sending.maximum_requests.acquire().await;
|
||||
|
||||
let response = match appservice_server::send_request(
|
||||
services().appservice.get_registration(id).map_err(|e| (kind.clone(), e))?.ok_or_else(|| {
|
||||
services().appservice.get_registration(id).await.ok_or_else(|| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database("[Appservice] Could not load registration from db."),
|
||||
|
|
Loading…
Add table
Reference in a new issue