From eaf1cf38a52caa6e82d5beb9539fa35d8b818a3e Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Thu, 4 Jul 2024 11:44:39 +0000 Subject: [PATCH] refactor resolver tuples into structs Signed-off-by: Jason Volk --- src/service/globals/mod.rs | 2 +- src/service/globals/resolver.rs | 22 ++++---- src/service/sending/mod.rs | 2 +- src/service/sending/resolve.rs | 93 ++++++++++++++++++++++++--------- src/service/sending/send.rs | 21 ++++---- 5 files changed, 93 insertions(+), 47 deletions(-) diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 12843b9d..3c0e66c0 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -2,7 +2,7 @@ mod client; mod data; pub(super) mod emerg_access; pub(super) mod migrations; -mod resolver; +pub(crate) mod resolver; pub(super) mod updates; use std::{ diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index e5199162..0cfaa1b9 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::OwnedServerName; -use crate::sending::FedDest; +use crate::sending::{CachedDest, CachedOverride}; -type WellKnownMap = HashMap; -type TlsNameMap = HashMap, u16)>; +type WellKnownMap = HashMap; +type TlsNameMap = HashMap; pub struct Resolver { pub destinations: Arc>, // actual_destination, host pub overrides: Arc>, - pub resolver: Arc, - pub hooked: Arc, + pub(crate) resolver: Arc, + pub(crate) hooked: Arc, } -pub struct Hooked { - pub overrides: Arc>, - pub resolver: Arc, +pub(crate) struct Hooked { + overrides: Arc>, + resolver: Arc, } impl Resolver { @@ -117,15 +117,15 @@ impl Resolve for Resolver { impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let addr_port = self + let cached = self .overrides .read() .expect("locked for reading") .get(name.as_str()) .cloned(); - if let Some((addr, port)) = addr_port { - cached_to_reqwest(&addr, port) + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) } else { resolve_to_reqwest(self.resolver.clone(), name) } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index d3f3cbdd..0dcebcec 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc}; use async_trait::async_trait; use conduit::{Error, Result}; use data::Data; -pub use resolve::{resolve_actual_dest, FedDest}; +pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 5f257217..01943cc0 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -6,7 +6,7 @@ use std::{ use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; -use ruma::ServerName; +use ruma::{OwnedServerName, ServerName}; use tracing::{debug, error, trace}; use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; @@ -35,6 +35,7 @@ pub enum FedDest { Named(String, String), } +#[derive(Clone, Debug)] pub(crate) struct ActualDest { pub(crate) dest: FedDest, pub(crate) host: String, @@ -42,19 +43,31 @@ pub(crate) struct ActualDest { pub(crate) cached: bool, } +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec, + pub port: u16, +} + #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result { let cached; let cached_result = services() .globals .resolver - .destinations - .read() - .expect("locked for reading") - .get(server_name) - .cloned(); + .get_cached_destination(server_name); - let (dest, host) = if let Some(result) = cached_result { + let CachedDest { + dest, + host, + .. + } = if let Some(result) = cached_result { cached = true; result } else { @@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result Result<(FedDest, String)> { +pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result { trace!("Finding actual destination for {dest}"); let mut host = dest.as_str().to_owned(); let actual_dest = match get_ip_with_port(dest.as_str()) { @@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD }; debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); - Ok((actual_dest, host.into_uri_string())) + Ok(CachedDest { + dest: actual_dest, + host: host.into_uri_string(), + }) } fn actual_dest_1(host_port: FedDest) -> Result { @@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(dest: &str) -> Result> { trace!("Requesting well known for {dest}"); - if !services() - .globals - .resolver - .overrides - .read() - .unwrap() - .contains_key(dest) - { + if !services().globals.resolver.has_cached_override(dest) { query_and_cache_override(dest, dest, 8448).await?; } @@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 Err(e) => handle_resolve_error(&e), Ok(override_ip) => { if hostname != overname { - debug_info!("{:?} overriden by {:?}", overname, hostname); + debug_info!("{overname:?} overriden by {hostname:?}"); } - services() - .globals - .resolver - .overrides - .write() - .unwrap() - .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + + services().globals.resolver.set_cached_override( + overname.to_owned(), + CachedOverride { + ips: override_ip.iter().collect(), + port, + }, + ); Ok(()) }, @@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest { FedDest::Named(host.to_owned(), port.to_owned()) } +impl crate::globals::resolver::Resolver { + pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { + trace!(?name, ?dest, "set cached destination"); + self.destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option { + self.destinations + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { + trace!(?name, ?over, "set cached override"); + self.overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + pub(crate) fn has_cached_override(&self, name: &str) -> bool { + self.overrides + .read() + .expect("locked for reading") + .contains_key(name) + } +} + impl FedDest { fn into_https_string(self) -> String { match self { diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 57ff9127..1b977a73 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -8,11 +8,14 @@ use ruma::{ client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, }, - OwnedServerName, ServerName, + ServerName, }; use tracing::{debug, trace}; -use super::{resolve, resolve::ActualDest}; +use super::{ + resolve, + resolve::{ActualDest, CachedDest}, +}; use crate::{debug_error, debug_warn, services, Error, Result}; #[tracing::instrument(skip_all, name = "send")] @@ -103,13 +106,13 @@ where let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services() - .globals - .resolver - .destinations - .write() - .expect("locked for writing") - .insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone())); + services().globals.resolver.set_cached_destination( + dest.to_owned(), + CachedDest { + dest: actual.dest.clone(), + host: actual.host.clone(), + }, + ); } match response {