refactor resolver tuples into structs
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
eeda96d94a
commit
eaf1cf38a5
5 changed files with 93 additions and 47 deletions
|
@ -2,7 +2,7 @@ mod client;
|
|||
mod data;
|
||||
pub(super) mod emerg_access;
|
||||
pub(super) mod migrations;
|
||||
mod resolver;
|
||||
pub(crate) mod resolver;
|
||||
pub(super) mod updates;
|
||||
|
||||
use std::{
|
||||
|
|
|
@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver;
|
|||
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
||||
use ruma::OwnedServerName;
|
||||
|
||||
use crate::sending::FedDest;
|
||||
use crate::sending::{CachedDest, CachedOverride};
|
||||
|
||||
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
|
||||
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
|
||||
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
|
||||
type TlsNameMap = HashMap<String, CachedOverride>;
|
||||
|
||||
pub struct Resolver {
|
||||
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
||||
pub resolver: Arc<TokioAsyncResolver>,
|
||||
pub hooked: Arc<Hooked>,
|
||||
pub(crate) resolver: Arc<TokioAsyncResolver>,
|
||||
pub(crate) hooked: Arc<Hooked>,
|
||||
}
|
||||
|
||||
pub struct Hooked {
|
||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
||||
pub resolver: Arc<TokioAsyncResolver>,
|
||||
pub(crate) struct Hooked {
|
||||
overrides: Arc<RwLock<TlsNameMap>>,
|
||||
resolver: Arc<TokioAsyncResolver>,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
|
@ -117,15 +117,15 @@ impl Resolve for Resolver {
|
|||
|
||||
impl Resolve for Hooked {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let addr_port = self
|
||||
let cached = self
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name.as_str())
|
||||
.cloned();
|
||||
|
||||
if let Some((addr, port)) = addr_port {
|
||||
cached_to_reqwest(&addr, port)
|
||||
if let Some(cached) = cached {
|
||||
cached_to_reqwest(&cached.ips, cached.port)
|
||||
} else {
|
||||
resolve_to_reqwest(self.resolver.clone(), name)
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc};
|
|||
use async_trait::async_trait;
|
||||
use conduit::{Error, Result};
|
||||
use data::Data;
|
||||
pub use resolve::{resolve_actual_dest, FedDest};
|
||||
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
|
||||
use ruma::{
|
||||
api::{appservice::Registration, OutgoingRequest},
|
||||
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
|
||||
|
|
|
@ -6,7 +6,7 @@ use std::{
|
|||
|
||||
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
|
||||
use ipaddress::IPAddress;
|
||||
use ruma::ServerName;
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
use tracing::{debug, error, trace};
|
||||
|
||||
use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
|
||||
|
@ -35,6 +35,7 @@ pub enum FedDest {
|
|||
Named(String, String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ActualDest {
|
||||
pub(crate) dest: FedDest,
|
||||
pub(crate) host: String,
|
||||
|
@ -42,19 +43,31 @@ pub(crate) struct ActualDest {
|
|||
pub(crate) cached: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedDest {
|
||||
pub dest: FedDest,
|
||||
pub host: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedOverride {
|
||||
pub ips: Vec<IpAddr>,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "resolve")]
|
||||
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
|
||||
let cached;
|
||||
let cached_result = services()
|
||||
.globals
|
||||
.resolver
|
||||
.destinations
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(server_name)
|
||||
.cloned();
|
||||
.get_cached_destination(server_name);
|
||||
|
||||
let (dest, host) = if let Some(result) = cached_result {
|
||||
let CachedDest {
|
||||
dest,
|
||||
host,
|
||||
..
|
||||
} = if let Some(result) = cached_result {
|
||||
cached = true;
|
||||
result
|
||||
} else {
|
||||
|
@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
|
|||
/// Numbers in comments below refer to bullet points in linked section of
|
||||
/// specification
|
||||
#[tracing::instrument(skip_all, name = "actual")]
|
||||
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedDest, String)> {
|
||||
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<CachedDest> {
|
||||
trace!("Finding actual destination for {dest}");
|
||||
let mut host = dest.as_str().to_owned();
|
||||
let actual_dest = match get_ip_with_port(dest.as_str()) {
|
||||
|
@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD
|
|||
};
|
||||
|
||||
debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
|
||||
Ok((actual_dest, host.into_uri_string()))
|
||||
Ok(CachedDest {
|
||||
dest: actual_dest,
|
||||
host: host.into_uri_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
|
||||
|
@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
|
|||
#[tracing::instrument(skip_all, name = "well-known")]
|
||||
async fn request_well_known(dest: &str) -> Result<Option<String>> {
|
||||
trace!("Requesting well known for {dest}");
|
||||
if !services()
|
||||
.globals
|
||||
.resolver
|
||||
.overrides
|
||||
.read()
|
||||
.unwrap()
|
||||
.contains_key(dest)
|
||||
{
|
||||
if !services().globals.resolver.has_cached_override(dest) {
|
||||
query_and_cache_override(dest, dest, 8448).await?;
|
||||
}
|
||||
|
||||
|
@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
|
|||
Err(e) => handle_resolve_error(&e),
|
||||
Ok(override_ip) => {
|
||||
if hostname != overname {
|
||||
debug_info!("{:?} overriden by {:?}", overname, hostname);
|
||||
debug_info!("{overname:?} overriden by {hostname:?}");
|
||||
}
|
||||
services()
|
||||
.globals
|
||||
.resolver
|
||||
.overrides
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(overname.to_owned(), (override_ip.iter().collect(), port));
|
||||
|
||||
services().globals.resolver.set_cached_override(
|
||||
overname.to_owned(),
|
||||
CachedOverride {
|
||||
ips: override_ip.iter().collect(),
|
||||
port,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|
@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest {
|
|||
FedDest::Named(host.to_owned(), port.to_owned())
|
||||
}
|
||||
|
||||
impl crate::globals::resolver::Resolver {
|
||||
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
|
||||
trace!(?name, ?dest, "set cached destination");
|
||||
self.destinations
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, dest)
|
||||
}
|
||||
|
||||
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
|
||||
self.destinations
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
|
||||
trace!(?name, ?over, "set cached override");
|
||||
self.overrides
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, over)
|
||||
}
|
||||
|
||||
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
|
||||
self.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl FedDest {
|
||||
fn into_https_string(self) -> String {
|
||||
match self {
|
||||
|
|
|
@ -8,11 +8,14 @@ use ruma::{
|
|||
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
|
||||
SendAccessToken,
|
||||
},
|
||||
OwnedServerName, ServerName,
|
||||
ServerName,
|
||||
};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use super::{resolve, resolve::ActualDest};
|
||||
use super::{
|
||||
resolve,
|
||||
resolve::{ActualDest, CachedDest},
|
||||
};
|
||||
use crate::{debug_error, debug_warn, services, Error, Result};
|
||||
|
||||
#[tracing::instrument(skip_all, name = "send")]
|
||||
|
@ -103,13 +106,13 @@ where
|
|||
|
||||
let response = T::IncomingResponse::try_from_http_response(http_response);
|
||||
if response.is_ok() && !actual.cached {
|
||||
services()
|
||||
.globals
|
||||
.resolver
|
||||
.destinations
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
|
||||
services().globals.resolver.set_cached_destination(
|
||||
dest.to_owned(),
|
||||
CachedDest {
|
||||
dest: actual.dest.clone(),
|
||||
host: actual.host.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
match response {
|
||||
|
|
Loading…
Add table
Reference in a new issue