refactor resolver tuples into structs
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
eeda96d94a
commit
eaf1cf38a5
5 changed files with 93 additions and 47 deletions
|
@ -2,7 +2,7 @@ mod client;
|
||||||
mod data;
|
mod data;
|
||||||
pub(super) mod emerg_access;
|
pub(super) mod emerg_access;
|
||||||
pub(super) mod migrations;
|
pub(super) mod migrations;
|
||||||
mod resolver;
|
pub(crate) mod resolver;
|
||||||
pub(super) mod updates;
|
pub(super) mod updates;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
|
|
|
@ -12,21 +12,21 @@ use hickory_resolver::TokioAsyncResolver;
|
||||||
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
||||||
use ruma::OwnedServerName;
|
use ruma::OwnedServerName;
|
||||||
|
|
||||||
use crate::sending::FedDest;
|
use crate::sending::{CachedDest, CachedOverride};
|
||||||
|
|
||||||
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
|
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
|
||||||
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
|
type TlsNameMap = HashMap<String, CachedOverride>;
|
||||||
|
|
||||||
pub struct Resolver {
|
pub struct Resolver {
|
||||||
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
pub overrides: Arc<RwLock<TlsNameMap>>,
|
||||||
pub resolver: Arc<TokioAsyncResolver>,
|
pub(crate) resolver: Arc<TokioAsyncResolver>,
|
||||||
pub hooked: Arc<Hooked>,
|
pub(crate) hooked: Arc<Hooked>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct Hooked {
|
pub(crate) struct Hooked {
|
||||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
overrides: Arc<RwLock<TlsNameMap>>,
|
||||||
pub resolver: Arc<TokioAsyncResolver>,
|
resolver: Arc<TokioAsyncResolver>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Resolver {
|
impl Resolver {
|
||||||
|
@ -117,15 +117,15 @@ impl Resolve for Resolver {
|
||||||
|
|
||||||
impl Resolve for Hooked {
|
impl Resolve for Hooked {
|
||||||
fn resolve(&self, name: Name) -> Resolving {
|
fn resolve(&self, name: Name) -> Resolving {
|
||||||
let addr_port = self
|
let cached = self
|
||||||
.overrides
|
.overrides
|
||||||
.read()
|
.read()
|
||||||
.expect("locked for reading")
|
.expect("locked for reading")
|
||||||
.get(name.as_str())
|
.get(name.as_str())
|
||||||
.cloned();
|
.cloned();
|
||||||
|
|
||||||
if let Some((addr, port)) = addr_port {
|
if let Some(cached) = cached {
|
||||||
cached_to_reqwest(&addr, port)
|
cached_to_reqwest(&cached.ips, cached.port)
|
||||||
} else {
|
} else {
|
||||||
resolve_to_reqwest(self.resolver.clone(), name)
|
resolve_to_reqwest(self.resolver.clone(), name)
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,7 +9,7 @@ use std::{fmt::Debug, sync::Arc};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use conduit::{Error, Result};
|
use conduit::{Error, Result};
|
||||||
use data::Data;
|
use data::Data;
|
||||||
pub use resolve::{resolve_actual_dest, FedDest};
|
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{appservice::Registration, OutgoingRequest},
|
api::{appservice::Registration, OutgoingRequest},
|
||||||
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
|
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
|
||||||
|
|
|
@ -6,7 +6,7 @@ use std::{
|
||||||
|
|
||||||
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
|
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
|
||||||
use ipaddress::IPAddress;
|
use ipaddress::IPAddress;
|
||||||
use ruma::ServerName;
|
use ruma::{OwnedServerName, ServerName};
|
||||||
use tracing::{debug, error, trace};
|
use tracing::{debug, error, trace};
|
||||||
|
|
||||||
use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
|
use crate::{debug_error, debug_info, debug_warn, services, Error, Result};
|
||||||
|
@ -35,6 +35,7 @@ pub enum FedDest {
|
||||||
Named(String, String),
|
Named(String, String),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
pub(crate) struct ActualDest {
|
pub(crate) struct ActualDest {
|
||||||
pub(crate) dest: FedDest,
|
pub(crate) dest: FedDest,
|
||||||
pub(crate) host: String,
|
pub(crate) host: String,
|
||||||
|
@ -42,19 +43,31 @@ pub(crate) struct ActualDest {
|
||||||
pub(crate) cached: bool,
|
pub(crate) cached: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct CachedDest {
|
||||||
|
pub dest: FedDest,
|
||||||
|
pub host: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct CachedOverride {
|
||||||
|
pub ips: Vec<IpAddr>,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, name = "resolve")]
|
#[tracing::instrument(skip_all, name = "resolve")]
|
||||||
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
|
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
|
||||||
let cached;
|
let cached;
|
||||||
let cached_result = services()
|
let cached_result = services()
|
||||||
.globals
|
.globals
|
||||||
.resolver
|
.resolver
|
||||||
.destinations
|
.get_cached_destination(server_name);
|
||||||
.read()
|
|
||||||
.expect("locked for reading")
|
|
||||||
.get(server_name)
|
|
||||||
.cloned();
|
|
||||||
|
|
||||||
let (dest, host) = if let Some(result) = cached_result {
|
let CachedDest {
|
||||||
|
dest,
|
||||||
|
host,
|
||||||
|
..
|
||||||
|
} = if let Some(result) = cached_result {
|
||||||
cached = true;
|
cached = true;
|
||||||
result
|
result
|
||||||
} else {
|
} else {
|
||||||
|
@ -77,7 +90,7 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
|
||||||
/// Numbers in comments below refer to bullet points in linked section of
|
/// Numbers in comments below refer to bullet points in linked section of
|
||||||
/// specification
|
/// specification
|
||||||
#[tracing::instrument(skip_all, name = "actual")]
|
#[tracing::instrument(skip_all, name = "actual")]
|
||||||
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedDest, String)> {
|
pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<CachedDest> {
|
||||||
trace!("Finding actual destination for {dest}");
|
trace!("Finding actual destination for {dest}");
|
||||||
let mut host = dest.as_str().to_owned();
|
let mut host = dest.as_str().to_owned();
|
||||||
let actual_dest = match get_ip_with_port(dest.as_str()) {
|
let actual_dest = match get_ip_with_port(dest.as_str()) {
|
||||||
|
@ -109,7 +122,10 @@ pub async fn resolve_actual_dest(dest: &ServerName, cache: bool) -> Result<(FedD
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
|
debug!("Actual destination: {actual_dest:?} hostname: {host:?}");
|
||||||
Ok((actual_dest, host.into_uri_string()))
|
Ok(CachedDest {
|
||||||
|
dest: actual_dest,
|
||||||
|
host: host.into_uri_string(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
|
fn actual_dest_1(host_port: FedDest) -> Result<FedDest> {
|
||||||
|
@ -193,14 +209,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
|
||||||
#[tracing::instrument(skip_all, name = "well-known")]
|
#[tracing::instrument(skip_all, name = "well-known")]
|
||||||
async fn request_well_known(dest: &str) -> Result<Option<String>> {
|
async fn request_well_known(dest: &str) -> Result<Option<String>> {
|
||||||
trace!("Requesting well known for {dest}");
|
trace!("Requesting well known for {dest}");
|
||||||
if !services()
|
if !services().globals.resolver.has_cached_override(dest) {
|
||||||
.globals
|
|
||||||
.resolver
|
|
||||||
.overrides
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.contains_key(dest)
|
|
||||||
{
|
|
||||||
query_and_cache_override(dest, dest, 8448).await?;
|
query_and_cache_override(dest, dest, 8448).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,15 +278,16 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
|
||||||
Err(e) => handle_resolve_error(&e),
|
Err(e) => handle_resolve_error(&e),
|
||||||
Ok(override_ip) => {
|
Ok(override_ip) => {
|
||||||
if hostname != overname {
|
if hostname != overname {
|
||||||
debug_info!("{:?} overriden by {:?}", overname, hostname);
|
debug_info!("{overname:?} overriden by {hostname:?}");
|
||||||
}
|
}
|
||||||
services()
|
|
||||||
.globals
|
services().globals.resolver.set_cached_override(
|
||||||
.resolver
|
overname.to_owned(),
|
||||||
.overrides
|
CachedOverride {
|
||||||
.write()
|
ips: override_ip.iter().collect(),
|
||||||
.unwrap()
|
port,
|
||||||
.insert(overname.to_owned(), (override_ip.iter().collect(), port));
|
},
|
||||||
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
|
@ -392,6 +402,39 @@ fn add_port_to_hostname(dest_str: &str) -> FedDest {
|
||||||
FedDest::Named(host.to_owned(), port.to_owned())
|
FedDest::Named(host.to_owned(), port.to_owned())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl crate::globals::resolver::Resolver {
|
||||||
|
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
|
||||||
|
trace!(?name, ?dest, "set cached destination");
|
||||||
|
self.destinations
|
||||||
|
.write()
|
||||||
|
.expect("locked for writing")
|
||||||
|
.insert(name, dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
|
||||||
|
self.destinations
|
||||||
|
.read()
|
||||||
|
.expect("locked for reading")
|
||||||
|
.get(name)
|
||||||
|
.cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
|
||||||
|
trace!(?name, ?over, "set cached override");
|
||||||
|
self.overrides
|
||||||
|
.write()
|
||||||
|
.expect("locked for writing")
|
||||||
|
.insert(name, over)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
|
||||||
|
self.overrides
|
||||||
|
.read()
|
||||||
|
.expect("locked for reading")
|
||||||
|
.contains_key(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl FedDest {
|
impl FedDest {
|
||||||
fn into_https_string(self) -> String {
|
fn into_https_string(self) -> String {
|
||||||
match self {
|
match self {
|
||||||
|
|
|
@ -8,11 +8,14 @@ use ruma::{
|
||||||
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
|
client::error::Error as RumaError, EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest,
|
||||||
SendAccessToken,
|
SendAccessToken,
|
||||||
},
|
},
|
||||||
OwnedServerName, ServerName,
|
ServerName,
|
||||||
};
|
};
|
||||||
use tracing::{debug, trace};
|
use tracing::{debug, trace};
|
||||||
|
|
||||||
use super::{resolve, resolve::ActualDest};
|
use super::{
|
||||||
|
resolve,
|
||||||
|
resolve::{ActualDest, CachedDest},
|
||||||
|
};
|
||||||
use crate::{debug_error, debug_warn, services, Error, Result};
|
use crate::{debug_error, debug_warn, services, Error, Result};
|
||||||
|
|
||||||
#[tracing::instrument(skip_all, name = "send")]
|
#[tracing::instrument(skip_all, name = "send")]
|
||||||
|
@ -103,13 +106,13 @@ where
|
||||||
|
|
||||||
let response = T::IncomingResponse::try_from_http_response(http_response);
|
let response = T::IncomingResponse::try_from_http_response(http_response);
|
||||||
if response.is_ok() && !actual.cached {
|
if response.is_ok() && !actual.cached {
|
||||||
services()
|
services().globals.resolver.set_cached_destination(
|
||||||
.globals
|
dest.to_owned(),
|
||||||
.resolver
|
CachedDest {
|
||||||
.destinations
|
dest: actual.dest.clone(),
|
||||||
.write()
|
host: actual.host.clone(),
|
||||||
.expect("locked for writing")
|
},
|
||||||
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
|
|
Loading…
Add table
Reference in a new issue