refactor resolver tuples into structs

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-04 11:44:39 +00:00
parent eeda96d94a
commit eaf1cf38a5
5 changed files with 93 additions and 47 deletions

View file

@ -2,7 +2,7 @@ mod client;
mod data; mod data;
pub(super) mod emerg_access; pub(super) mod emerg_access;
pub(super) mod migrations; pub(super) mod migrations;
mod resolver; pub(crate) mod resolver;
pub(super) mod updates; pub(super) mod updates;
use std::{ use std::{

View file

@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::OwnedServerName; use ruma::OwnedServerName;
use crate::sending::FedDest; use crate::sending::{CachedDest, CachedOverride};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>; type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; type TlsNameMap = HashMap<String, CachedOverride>;
pub struct Resolver { pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<RwLock<TlsNameMap>>, pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>, pub(crate) resolver: Arc<TokioAsyncResolver>,
pub hooked: Arc<Hooked>, pub(crate) hooked: Arc<Hooked>,
} }
pub struct Hooked { pub(crate) struct Hooked {
pub overrides: Arc<RwLock<TlsNameMap>>, overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>, resolver: Arc<TokioAsyncResolver>,
} }
impl Resolver { impl Resolver {
@ -117,15 +117,15 @@ impl Resolve for Resolver {
impl Resolve for Hooked { impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving { fn resolve(&self, name: Name) -> Resolving {
let addr_port = self let cached = self
.overrides .overrides
.read() .read()
.expect("locked for reading") .expect("locked for reading")
.get(name.as_str()) .get(name.as_str())
.cloned(); .cloned();
if let Some((addr, port)) = addr_port { if let Some(cached) = cached {
cached_to_reqwest(&addr, port) cached_to_reqwest(&cached.ips, cached.port)
} else { } else {
resolve_to_reqwest(self.resolver.clone(), name) resolve_to_reqwest(self.resolver.clone(), name)
} }

View file

@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use conduit::{Error, Result}; use conduit::{Error, Result};
use data::Data; use data::Data;
pub use resolve::{resolve_actual_dest, FedDest}; pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
use ruma::{ use ruma::{
api::{appservice::Registration, OutgoingRequest}, api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,

View file

@ -6,7 +6,7 @@ use std::{
use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::ServerName; use ruma::{OwnedServerName, ServerName};
use tracing::{debug, error, trace}; use tracing::{debug, error, trace};
use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
@ -35,6 +35,7 @@ pub enum FedDest {
Named(String, String), Named(String, String),
} }
#[derive(Clone, Debug)]
pub(crate) struct ActualDest { pub(crate) struct ActualDest {
pub(crate) dest: FedDest, pub(crate) dest: FedDest,
pub(crate) host: String, pub(crate) host: String,
@ -42,19 +43,31 @@ pub(crate) struct ActualDest {
pub(crate) cached: bool, pub(crate) cached: bool,
} }
#[derive(Clone, Debug)]
pub struct CachedDest {
pub dest: FedDest,
pub host: String,
}
#[derive(Clone, Debug)]
pub struct CachedOverride {
pub ips: Vec<IpAddr>,
pub port: u16,
}
#[tracing::instrument(skip_all, name = "resolve")] #[tracing::instrument(skip_all, name = "resolve")]
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> { pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
let cached; let cached;
let cached_result = services() let cached_result = services()
.globals .globals
.resolver .resolver
.destinations .get_cached_destination(server_name);
.read()
.expect("locked for reading")
.get(server_name)
.cloned();
let (dest, host) = if let Some(result) = cached_result { let CachedDest {
dest,
host,
..
} = if let Some(result) = cached_result {
cached = true; cached = true;
result result
} else { } else {
@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
/// Numbers in comments below refer to bullet points in linked section of /// Numbers in comments below refer to bullet points in linked section of
/// specification /// specification
#[tracing::instrument(skip_all, name = "actual")] #[tracing::instrument(skip_all, name = "actual")]
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedDest, String)> { pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<CachedDest> {
trace!("Finding actual destination for {dest}"); trace!("Finding actual destination for {dest}");
let mut host = dest.as_str().to_owned(); let mut host = dest.as_str().to_owned();
let actual_dest = match get_ip_with_port(dest.as_str()) { let actual_dest = match get_ip_with_port(dest.as_str()) {
@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD
}; };
debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
Ok((actual_dest, host.into_uri_string())) Ok(CachedDest {
dest: actual_dest,
host: host.into_uri_string(),
})
} }
fn actual_dest_1(host_port: FedDest) -> Result<FedDest> { fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
#[tracing::instrument(skip_all, name = "well-known")] #[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(dest: &str) -> Result<Option<String>> { async fn request_well_known(dest: &str) -> Result<Option<String>> {
trace!("Requesting well known for {dest}"); trace!("Requesting well known for {dest}");
if !services() if !services().globals.resolver.has_cached_override(dest) {
.globals
.resolver
.overrides
.read()
.unwrap()
.contains_key(dest)
{
query_and_cache_override(dest, dest, 8448).await?; query_and_cache_override(dest, dest, 8448).await?;
} }
@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
Err(e) => handle_resolve_error(&e), Err(e) => handle_resolve_error(&e),
Ok(override_ip) => { Ok(override_ip) => {
if hostname != overname { if hostname != overname {
debug_info!("{:?} overriden by {:?}", overname, hostname); debug_info!("{overname:?} overriden by {hostname:?}");
} }
services()
.globals services().globals.resolver.set_cached_override(
.resolver overname.to_owned(),
.overrides CachedOverride {
.write() ips: override_ip.iter().collect(),
.unwrap() port,
.insert(overname.to_owned(), (override_ip.iter().collect(), port)); },
);
Ok(()) Ok(())
}, },
@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest {
FedDest::Named(host.to_owned(), port.to_owned()) FedDest::Named(host.to_owned(), port.to_owned())
} }
impl crate::globals::resolver::Resolver {
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.destinations
.write()
.expect("locked for writing")
.insert(name, dest)
}
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
self.destinations
.read()
.expect("locked for reading")
.get(name)
.cloned()
}
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.overrides
.write()
.expect("locked for writing")
.insert(name, over)
}
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
self.overrides
.read()
.expect("locked for reading")
.contains_key(name)
}
}
impl FedDest { impl FedDest {
fn into_https_string(self) -> String { fn into_https_string(self) -> String {
match self { match self {

View file

@ -8,11 +8,14 @@ use ruma::{
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken, SendAccessToken,
}, },
OwnedServerName, ServerName, ServerName,
}; };
use tracing::{debug, trace}; use tracing::{debug, trace};
use super::{resolve, resolve::ActualDest}; use super::{
resolve,
resolve::{ActualDest, CachedDest},
};
use crate::{debug_error, debug_warn, services, Error, Result}; use crate::{debug_error, debug_warn, services, Error, Result};
#[tracing::instrument(skip_all, name = "send")] #[tracing::instrument(skip_all, name = "send")]
@ -103,13 +106,13 @@ where
let response = T::IncomingResponse::try_from_http_response(http_response); let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached { if response.is_ok() && !actual.cached {
services() services().globals.resolver.set_cached_destination(
.globals dest.to_owned(),
.resolver CachedDest {
.destinations dest: actual.dest.clone(),
.write() host: actual.host.clone(),
.expect("locked for writing") },
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone())); );
} }
match response { match response {