refactor resolver tuples into structs

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-04 11:44:39 +00:00
parent eeda96d94a
commit eaf1cf38a5
5 changed files with 93 additions and 47 deletions

View file

@ -2,7 +2,7 @@ mod client;
mod data;
pub(super) mod emerg_access;
pub(super) mod migrations;
mod resolver;
pub(crate) mod resolver;
pub(super) mod updates;
use std::{

View file

@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::OwnedServerName;
use crate::sending::FedDest;
use crate::sending::{CachedDest, CachedOverride};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
type TlsNameMap = HashMap<String, CachedOverride>;
pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
pub hooked: Arc<Hooked>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
}
pub struct Hooked {
pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>,
pub(crate) struct Hooked {
overrides: Arc<RwLock<TlsNameMap>>,
resolver: Arc<TokioAsyncResolver>,
}
impl Resolver {
@ -117,15 +117,15 @@ impl Resolve for Resolver {
impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving {
let addr_port = self
let cached = self
.overrides
.read()
.expect("locked for reading")
.get(name.as_str())
.cloned();
if let Some((addr, port)) = addr_port {
cached_to_reqwest(&addr, port)
if let Some(cached) = cached {
cached_to_reqwest(&cached.ips, cached.port)
} else {
resolve_to_reqwest(self.resolver.clone(), name)
}

View file

@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use conduit::{Error, Result};
use data::Data;
pub use resolve::{resolve_actual_dest, FedDest};
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,

View file

@ -6,7 +6,7 @@ use std::{
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress;
use ruma::ServerName;
use ruma::{OwnedServerName, ServerName};
use tracing::{debug, error, trace};
use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
@ -35,6 +35,7 @@ pub enum FedDest {
Named(String, String),
}
#[derive(Clone, Debug)]
pub(crate) struct ActualDest {
pub(crate) dest: FedDest,
pub(crate) host: String,
@ -42,19 +43,31 @@ pub(crate) struct ActualDest {
pub(crate) cached: bool,
}
#[derive(Clone, Debug)]
pub struct CachedDest {
pub dest: FedDest,
pub host: String,
}
#[derive(Clone, Debug)]
pub struct CachedOverride {
pub ips: Vec<IpAddr>,
pub port: u16,
}
#[tracing::instrument(skip_all, name = "resolve")]
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
let cached;
let cached_result = services()
.globals
.resolver
.destinations
.read()
.expect("locked for reading")
.get(server_name)
.cloned();
.get_cached_destination(server_name);
let (dest, host) = if let Some(result) = cached_result {
let CachedDest {
dest,
host,
..
} = if let Some(result) = cached_result {
cached = true;
result
} else {
@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
/// Numbers in comments below refer to bullet points in linked section of
/// specification
#[tracing::instrument(skip_all, name = "actual")]
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedDest, String)> {
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<CachedDest> {
trace!("Finding actual destination for {dest}");
let mut host = dest.as_str().to_owned();
let actual_dest = match get_ip_with_port(dest.as_str()) {
@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD
};
debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
Ok((actual_dest, host.into_uri_string()))
Ok(CachedDest {
dest: actual_dest,
host: host.into_uri_string(),
})
}
fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
#[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(dest: &str) -> Result<Option<String>> {
trace!("Requesting well known for {dest}");
if !services()
.globals
.resolver
.overrides
.read()
.unwrap()
.contains_key(dest)
{
if !services().globals.resolver.has_cached_override(dest) {
query_and_cache_override(dest, dest, 8448).await?;
}
@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
Err(e) => handle_resolve_error(&e),
Ok(override_ip) => {
if hostname != overname {
debug_info!("{:?} overriden by {:?}", overname, hostname);
debug_info!("{overname:?} overriden by {hostname:?}");
}
services()
.globals
.resolver
.overrides
.write()
.unwrap()
.insert(overname.to_owned(), (override_ip.iter().collect(), port));
services().globals.resolver.set_cached_override(
overname.to_owned(),
CachedOverride {
ips: override_ip.iter().collect(),
port,
},
);
Ok(())
},
@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest {
FedDest::Named(host.to_owned(), port.to_owned())
}
impl crate::globals::resolver::Resolver {
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.destinations
.write()
.expect("locked for writing")
.insert(name, dest)
}
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
self.destinations
.read()
.expect("locked for reading")
.get(name)
.cloned()
}
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.overrides
.write()
.expect("locked for writing")
.insert(name, over)
}
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
self.overrides
.read()
.expect("locked for reading")
.contains_key(name)
}
}
impl FedDest {
fn into_https_string(self) -> String {
match self {

View file

@ -8,11 +8,14 @@ use ruma::{
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
SendAccessToken,
},
OwnedServerName, ServerName,
ServerName,
};
use tracing::{debug, trace};
use super::{resolve, resolve::ActualDest};
use super::{
resolve,
resolve::{ActualDest, CachedDest},
};
use crate::{debug_error, debug_warn, services, Error, Result};
#[tracing::instrument(skip_all, name = "send")]
@ -103,13 +106,13 @@ where
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached {
services()
.globals
.resolver
.destinations
.write()
.expect("locked for writing")
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
services().globals.resolver.set_cached_destination(
dest.to_owned(),
CachedDest {
dest: actual.dest.clone(),
host: actual.host.clone(),
},
);
}
match response {