convert Resolver into a Service.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-16 22:00:54 +00:00
parent 2fd6f6b0ff
commit f465d77ad3
11 changed files with 381 additions and 409 deletions

View file

@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result<RoomMessageEventCon
}
async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> {
use service::sending::CachedDest;
use service::resolver::CachedDest;
let mut out = String::new();
writeln!(out, "| Server Name | Destination | Hostname | Expires |")?;
@ -36,12 +36,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room
writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line");
};
let map = services()
.globals
.resolver
.destinations
.read()
.expect("locked");
let map = services().resolver.destinations.read().expect("locked");
if let Some(server_name) = server_name.as_ref() {
map.get_key_value(server_name).map(row);
@ -53,7 +48,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room
}
async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEventContent> {
use service::sending::CachedOverride;
use service::resolver::CachedOverride;
let mut out = String::new();
writeln!(out, "| Server Name | IP | Port | Expires |")?;
@ -70,12 +65,7 @@ async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEvent
writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line");
};
let map = services()
.globals
.resolver
.overrides
.read()
.expect("locked");
let map = services().resolver.overrides.read().expect("locked");
if let Some(server_name) = server_name.as_ref() {
map.get_key_value(server_name).map(row);

View file

@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration};
use reqwest::redirect;
use crate::{globals::resolver, Config, Result};
use crate::{resolver, Config, Result};
pub struct Client {
pub default: reqwest::Client,
@ -15,7 +15,7 @@ pub struct Client {
}
impl Client {
pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Self {
pub fn new(config: &Config, resolver: &Arc<resolver::Service>) -> Self {
Self {
default: Self::base(config)
.unwrap()

View file

@ -2,7 +2,6 @@ mod client;
mod data;
mod emerg_access;
pub(super) mod migrations;
pub(crate) mod resolver;
use std::{
collections::{BTreeMap, HashMap},
@ -25,7 +24,7 @@ use ruma::{
use tokio::sync::Mutex;
use url::Url;
use crate::services;
use crate::{resolver, service, services};
pub struct Service {
pub db: Data,
@ -34,7 +33,6 @@ pub struct Service {
pub cidr_range_denylist: Vec<IPAddress>,
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
pub resolver: Arc<resolver::Resolver>,
pub client: client::Client,
pub stable_room_versions: Vec<RoomVersionId>,
pub unstable_room_versions: Vec<RoomVersionId>,
@ -68,8 +66,6 @@ impl crate::Service for Service {
.as_ref()
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
let resolver = Arc::new(resolver::Resolver::new(config));
// Supported and stable room versions
let stable_room_versions = vec![
RoomVersionId::V6,
@ -89,12 +85,14 @@ impl crate::Service for Service {
cidr_range_denylist.push(cidr);
}
let resolver = service::get::<resolver::Service>(args.service, "resolver")
.expect("resolver must be built prior to globals");
let mut s = Self {
db,
config: config.clone(),
cidr_range_denylist,
keypair: Arc::new(keypair),
resolver: resolver.clone(),
client: client::Client::new(config, &resolver),
jwt_decoding_key,
stable_room_versions,
@ -126,8 +124,6 @@ impl crate::Service for Service {
}
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
self.resolver.memory_usage(out)?;
let bad_event_ratelimiter = self
.bad_event_ratelimiter
.read()
@ -146,8 +142,6 @@ impl crate::Service for Service {
}
fn clear_cache(&self) {
self.resolver.clear_cache();
self.bad_event_ratelimiter
.write()
.expect("locked for writing")
@ -159,7 +153,7 @@ impl crate::Service for Service {
.clear();
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
fn name(&self) -> &str { service::make_name(std::module_path!()) }
}
impl Service {

View file

@ -1,156 +0,0 @@
use std::{
collections::HashMap,
fmt::Write,
future, iter,
net::{IpAddr, SocketAddr},
sync::{Arc, RwLock},
time::Duration,
};
use conduit::{error, Config, Result};
use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::OwnedServerName;
use crate::sending::{CachedDest, CachedOverride};
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
type TlsNameMap = HashMap<String, CachedOverride>;
pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<RwLock<TlsNameMap>>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
}
pub(crate) struct Hooked {
overrides: Arc<RwLock<TlsNameMap>>,
resolver: Arc<TokioAsyncResolver>,
}
impl Resolver {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
pub(super) fn new(config: &Config) -> Self {
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
.inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}"))
.expect("DNS system config must be valid");
let mut conf = hickory_resolver::config::ResolverConfig::new();
if let Some(domain) = sys_conf.domain() {
conf.set_domain(domain.clone());
}
for sys_conf in sys_conf.search() {
conf.add_search(sys_conf.clone());
}
for sys_conf in sys_conf.name_servers() {
let mut ns = sys_conf.clone();
if config.query_over_tcp_only {
ns.protocol = hickory_resolver::config::Protocol::Tcp;
}
ns.trust_negative_responses = !config.query_all_nameservers;
conf.add_name_server(ns);
}
opts.cache_size = config.dns_cache_entries as usize;
opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30));
opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7));
opts.timeout = Duration::from_secs(config.dns_timeout);
opts.attempts = config.dns_attempts as usize;
opts.try_tcp_on_error = config.dns_tcp_fallback;
opts.num_concurrent_reqs = 1;
opts.shuffle_dns_servers = true;
opts.rotate = true;
opts.ip_strategy = match config.ip_lookup_strategy {
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
};
opts.authentic_data = false;
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
let overrides = Arc::new(RwLock::new(TlsNameMap::new()));
Self {
destinations: Arc::new(RwLock::new(WellKnownMap::new())),
overrides: overrides.clone(),
resolver: resolver.clone(),
hooked: Arc::new(Hooked {
overrides,
resolver,
}),
}
}
pub(super) fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len();
writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?;
let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len();
writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?;
Ok(())
}
pub(super) fn clear_cache(&self) {
self.overrides.write().expect("write locked").clear();
self.destinations.write().expect("write locked").clear();
self.resolver.clear_cache();
}
}
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
}
impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving {
let cached = self
.overrides
.read()
.expect("locked for reading")
.get(name.as_str())
.filter(|cached| cached.valid())
.cloned();
if let Some(cached) = cached {
cached_to_reqwest(&cached.ips, cached.port)
} else {
resolve_to_reqwest(self.resolver.clone(), name)
}
}
}
fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
override_name
.first()
.map(|first_name| -> Resolving {
let saddr = SocketAddr::new(*first_name, port);
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
Box::pin(future::ready(Ok(result)))
})
.expect("must provide at least one override name")
}
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {
Box::pin(async move {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
let results: Addrs = Box::new(results);
Ok(results)
})
}

View file

@ -12,6 +12,7 @@ pub mod key_backups;
pub mod media;
pub mod presence;
pub mod pusher;
pub mod resolver;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;

349
src/service/resolver/mod.rs Normal file
View file

@ -0,0 +1,349 @@
use std::{
collections::HashMap,
fmt,
fmt::Write,
future, iter,
net::{IpAddr, SocketAddr},
sync::{Arc, RwLock},
time::{Duration, SystemTime},
};
use conduit::{err, trace, Result};
use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::{OwnedServerName, ServerName};
use crate::utils::rand;
pub struct Service {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<RwLock<TlsNameMap>>,
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
}
pub(crate) struct Hooked {
overrides: Arc<RwLock<TlsNameMap>>,
resolver: Arc<TokioAsyncResolver>,
}
#[derive(Clone, Debug)]
pub struct CachedDest {
pub dest: FedDest,
pub host: String,
pub expire: SystemTime,
}
#[derive(Clone, Debug)]
pub struct CachedOverride {
pub ips: Vec<IpAddr>,
pub port: u16,
pub expire: SystemTime,
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FedDest {
Literal(SocketAddr),
Named(String, String),
}
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
type TlsNameMap = HashMap<String, CachedOverride>;
impl crate::Service for Service {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
.map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?;
let mut conf = hickory_resolver::config::ResolverConfig::new();
if let Some(domain) = sys_conf.domain() {
conf.set_domain(domain.clone());
}
for sys_conf in sys_conf.search() {
conf.add_search(sys_conf.clone());
}
for sys_conf in sys_conf.name_servers() {
let mut ns = sys_conf.clone();
if config.query_over_tcp_only {
ns.protocol = hickory_resolver::config::Protocol::Tcp;
}
ns.trust_negative_responses = !config.query_all_nameservers;
conf.add_name_server(ns);
}
opts.cache_size = config.dns_cache_entries as usize;
opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30));
opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7));
opts.timeout = Duration::from_secs(config.dns_timeout);
opts.attempts = config.dns_attempts as usize;
opts.try_tcp_on_error = config.dns_tcp_fallback;
opts.num_concurrent_reqs = 1;
opts.shuffle_dns_servers = true;
opts.rotate = true;
opts.ip_strategy = match config.ip_lookup_strategy {
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
};
opts.authentic_data = false;
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
let overrides = Arc::new(RwLock::new(TlsNameMap::new()));
Ok(Arc::new(Self {
destinations: Arc::new(RwLock::new(WellKnownMap::new())),
overrides: overrides.clone(),
resolver: resolver.clone(),
hooked: Arc::new(Hooked {
overrides,
resolver,
}),
}))
}
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len();
writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?;
let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len();
writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?;
Ok(())
}
fn clear_cache(&self) {
self.overrides.write().expect("write locked").clear();
self.destinations.write().expect("write locked").clear();
self.resolver.clear_cache();
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.destinations
.write()
.expect("locked for writing")
.insert(name, dest)
}
#[must_use]
pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
self.destinations
.read()
.expect("locked for reading")
.get(name)
.cloned()
}
pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.overrides
.write()
.expect("locked for writing")
.insert(name, over)
}
#[must_use]
pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> {
self.overrides
.read()
.expect("locked for reading")
.get(name)
.cloned()
}
#[must_use]
pub fn has_cached_override(&self, name: &str) -> bool {
self.overrides
.read()
.expect("locked for reading")
.contains_key(name)
}
}
impl Resolve for Service {
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
}
impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving {
let cached = self
.overrides
.read()
.expect("locked for reading")
.get(name.as_str())
.cloned();
if let Some(cached) = cached {
cached_to_reqwest(&cached.ips, cached.port)
} else {
resolve_to_reqwest(self.resolver.clone(), name)
}
}
}
fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
override_name
.first()
.map(|first_name| -> Resolving {
let saddr = SocketAddr::new(*first_name, port);
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
Box::pin(future::ready(Ok(result)))
})
.expect("must provide at least one override name")
}
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {
Box::pin(async move {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
let results: Addrs = Box::new(results);
Ok(results)
})
}
impl CachedDest {
#[inline]
#[must_use]
pub fn valid(&self) -> bool { true }
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
}
impl CachedOverride {
#[inline]
#[must_use]
pub fn valid(&self) -> bool { true }
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
}
pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
if let Ok(dest) = dest_str.parse::<SocketAddr>() {
Some(FedDest::Literal(dest))
} else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() {
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
} else {
None
}
}
pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest {
let (host, port) = match dest_str.find(':') {
None => (dest_str, ":8448"),
Some(pos) => dest_str.split_at(pos),
};
FedDest::Named(host.to_owned(), port.to_owned())
}
impl FedDest {
pub(crate) fn into_https_string(self) -> String {
match self {
Self::Literal(addr) => format!("https://{addr}"),
Self::Named(host, port) => format!("https://{host}{port}"),
}
}
pub(crate) fn into_uri_string(self) -> String {
match self {
Self::Literal(addr) => addr.to_string(),
Self::Named(host, port) => format!("{host}{port}"),
}
}
pub(crate) fn hostname(&self) -> String {
match &self {
Self::Literal(addr) => addr.ip().to_string(),
Self::Named(host, _) => host.clone(),
}
}
#[inline]
#[allow(clippy::string_slice)]
pub(crate) fn port(&self) -> Option<u16> {
match &self {
Self::Literal(addr) => Some(addr.port()),
Self::Named(_, port) => port[1..].parse().ok(),
}
}
}
impl fmt::Display for FedDest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Named(host, port) => write!(f, "{host}{port}"),
Self::Literal(addr) => write!(f, "{addr}"),
}
}
}
#[cfg(test)]
mod tests {
use super::{add_port_to_hostname, get_ip_with_port, FedDest};
#[test]
fn ips_get_default_ports() {
assert_eq!(
get_ip_with_port("1.1.1.1"),
Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap()))
);
assert_eq!(
get_ip_with_port("dead:beef::"),
Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap()))
);
}
#[test]
fn ips_keep_custom_ports() {
assert_eq!(
get_ip_with_port("1.1.1.1:1234"),
Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap()))
);
assert_eq!(
get_ip_with_port("[dead::beef]:8933"),
Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap()))
);
}
#[test]
fn hostnames_get_default_ports() {
assert_eq!(
add_port_to_hostname("example.com"),
FedDest::Named(String::from("example.com"), String::from(":8448"))
);
}
#[test]
fn hostnames_keep_custom_ports() {
assert_eq!(
add_port_to_hostname("example.com:1337"),
FedDest::Named(String::from("example.com"), String::from(":1337"))
);
}
}

View file

@ -7,7 +7,7 @@ mod sender;
use std::fmt::Debug;
use conduit::{err, Result};
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
pub use resolve::resolve_actual_dest;
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,

View file

@ -1,40 +1,17 @@
use std::{
fmt,
fmt::Debug,
net::{IpAddr, SocketAddr},
time::SystemTime,
};
use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result};
use conduit::{debug, debug_error, debug_info, debug_warn, trace, Err, Error, Result};
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress;
use ruma::{OwnedServerName, ServerName};
use ruma::ServerName;
use crate::services;
/// Wraps either an literal IP address plus port, or a hostname plus complement
/// (colon-plus-port if it was specified).
///
/// Note: A `FedDest::Named` might contain an IP address in string form if there
/// was no port specified to construct a `SocketAddr` with.
///
/// # Examples:
/// ```rust
/// # use conduit_service::sending::FedDest;
/// # fn main() -> Result<(), std::net::AddrParseError> {
/// FedDest::Literal("198.51.100.3:8448".parse()?);
/// FedDest::Literal("[2001:db8::4:5]:443".parse()?);
/// FedDest::Named("matrix.example.org".to_owned(), String::new());
/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned());
/// FedDest::Named("198.51.100.5".to_owned(), String::new());
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FedDest {
Literal(SocketAddr),
Named(String, String),
}
use crate::{
resolver::{add_port_to_hostname, get_ip_with_port, CachedDest, CachedOverride, FedDest},
services,
};
#[derive(Clone, Debug)]
pub(crate) struct ActualDest {
@ -44,27 +21,10 @@ pub(crate) struct ActualDest {
pub(crate) cached: bool,
}
#[derive(Clone, Debug)]
pub struct CachedDest {
pub dest: FedDest,
pub host: String,
pub expire: SystemTime,
}
#[derive(Clone, Debug)]
pub struct CachedOverride {
pub ips: Vec<IpAddr>,
pub port: u16,
pub expire: SystemTime,
}
#[tracing::instrument(skip_all, name = "resolve")]
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
let cached;
let cached_result = services()
.globals
.resolver
.get_cached_destination(server_name);
let cached_result = services().resolver.get_cached_destination(server_name);
let CachedDest {
dest,
@ -213,7 +173,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
#[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(dest: &str) -> Result<Option<String>> {
trace!("Requesting well known for {dest}");
if !services().globals.resolver.has_cached_override(dest) {
if !services().resolver.has_cached_override(dest) {
query_and_cache_override(dest, dest, 8448).await?;
}
@ -273,7 +233,6 @@ async fn conditional_query_and_cache_override(overname: &str, hostname: &str, po
#[tracing::instrument(skip_all, name = "ip")]
async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
match services()
.globals
.resolver
.resolver
.lookup_ip(hostname.to_owned())
@ -285,7 +244,7 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
debug_info!("{overname:?} overriden by {hostname:?}");
}
services().globals.resolver.set_cached_override(
services().resolver.set_cached_override(
overname.to_owned(),
CachedOverride {
ips: override_ip.iter().collect(),
@ -314,7 +273,6 @@ async fn query_srv_record(hostname: &'_ str) -> Result<Option<FedDest>> {
debug!("querying SRV for {:?}", hostname);
let hostname = hostname.trim_end_matches('.');
services()
.globals
.resolver
.resolver
.srv_lookup(hostname.to_owned())
@ -384,166 +342,3 @@ pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> {
Ok(())
}
fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
if let Ok(dest) = dest_str.parse::<SocketAddr>() {
Some(FedDest::Literal(dest))
} else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() {
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
} else {
None
}
}
fn add_port_to_hostname(dest_str: &str) -> FedDest {
let (host, port) = match dest_str.find(':') {
None => (dest_str, ":8448"),
Some(pos) => dest_str.split_at(pos),
};
FedDest::Named(host.to_owned(), port.to_owned())
}
impl crate::globals::resolver::Resolver {
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
trace!(?name, ?dest, "set cached destination");
self.destinations
.write()
.expect("locked for writing")
.insert(name, dest)
}
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
self.destinations
.read()
.expect("locked for reading")
.get(name)
.filter(|cached| cached.valid())
.cloned()
}
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
trace!(?name, ?over, "set cached override");
self.overrides
.write()
.expect("locked for writing")
.insert(name, over)
}
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
self.overrides
.read()
.expect("locked for reading")
.get(name)
.filter(|cached| cached.valid())
.is_some()
}
}
impl CachedDest {
#[inline]
#[must_use]
pub fn valid(&self) -> bool { true }
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
}
impl CachedOverride {
#[inline]
#[must_use]
pub fn valid(&self) -> bool { true }
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
#[must_use]
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
}
impl FedDest {
fn into_https_string(self) -> String {
match self {
Self::Literal(addr) => format!("https://{addr}"),
Self::Named(host, port) => format!("https://{host}{port}"),
}
}
fn into_uri_string(self) -> String {
match self {
Self::Literal(addr) => addr.to_string(),
Self::Named(host, port) => format!("{host}{port}"),
}
}
fn hostname(&self) -> String {
match &self {
Self::Literal(addr) => addr.ip().to_string(),
Self::Named(host, _) => host.clone(),
}
}
#[inline]
#[allow(clippy::string_slice)]
fn port(&self) -> Option<u16> {
match &self {
Self::Literal(addr) => Some(addr.port()),
Self::Named(_, port) => port[1..].parse().ok(),
}
}
}
impl fmt::Display for FedDest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Named(host, port) => write!(f, "{host}{port}"),
Self::Literal(addr) => write!(f, "{addr}"),
}
}
}
#[cfg(test)]
mod tests {
use super::{add_port_to_hostname, get_ip_with_port, FedDest};
#[test]
fn ips_get_default_ports() {
assert_eq!(
get_ip_with_port("1.1.1.1"),
Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap()))
);
assert_eq!(
get_ip_with_port("dead:beef::"),
Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap()))
);
}
#[test]
fn ips_keep_custom_ports() {
assert_eq!(
get_ip_with_port("1.1.1.1:1234"),
Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap()))
);
assert_eq!(
get_ip_with_port("[dead::beef]:8933"),
Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap()))
);
}
#[test]
fn hostnames_get_default_ports() {
assert_eq!(
add_port_to_hostname("example.com"),
FedDest::Named(String::from("example.com"), String::from(":8448"))
);
}
#[test]
fn hostnames_keep_custom_ports() {
assert_eq!(
add_port_to_hostname("example.com:1337"),
FedDest::Named(String::from("example.com"), String::from(":1337"))
);
}
}

View file

@ -15,11 +15,8 @@ use ruma::{
};
use tracing::{debug, trace};
use super::{
resolve,
resolve::{ActualDest, CachedDest},
};
use crate::{debug_error, debug_warn, services, Error, Result};
use super::{resolve, resolve::ActualDest};
use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result};
#[tracing::instrument(skip_all, name = "send")]
pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
@ -109,7 +106,7 @@ where
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached {
services().globals.resolver.set_cached_destination(
services().resolver.set_cached_destination(
dest.to_owned(),
CachedDest {
dest: actual.dest.clone(),

View file

@ -37,7 +37,7 @@ pub(crate) trait Service: Any + Send + Sync {
pub(crate) struct Args<'a> {
pub(crate) server: &'a Arc<Server>,
pub(crate) db: &'a Arc<Database>,
pub(crate) _service: &'a Map,
pub(crate) service: &'a Map,
}
pub(crate) type Map = BTreeMap<String, MapVal>;

View file

@ -7,12 +7,14 @@ use tokio::sync::Mutex;
use crate::{
account_data, admin, appservice, globals, key_backups,
manager::Manager,
media, presence, pusher, rooms, sending, service,
media, presence, pusher, resolver, rooms, sending, service,
service::{Args, Map, Service},
transaction_ids, uiaa, updates, users,
};
pub struct Services {
pub resolver: Arc<resolver::Service>,
pub globals: Arc<globals::Service>,
pub rooms: rooms::Service,
pub appservice: Arc<appservice::Service>,
pub pusher: Arc<pusher::Service>,
@ -26,7 +28,6 @@ pub struct Services {
pub media: Arc<media::Service>,
pub sending: Arc<sending::Service>,
pub updates: Arc<updates::Service>,
pub globals: Arc<globals::Service>,
manager: Mutex<Option<Arc<Manager>>>,
pub(crate) service: Map,
@ -42,7 +43,7 @@ impl Services {
let built = <$tyname>::build(Args {
server: &server,
db: &db,
_service: &service,
service: &service,
})?;
service.insert(built.name().to_owned(), (built.clone(), built.clone()));
built
@ -50,6 +51,8 @@ impl Services {
}
Ok(Self {
resolver: build!(resolver::Service),
globals: build!(globals::Service),
rooms: rooms::Service {
alias: build!(rooms::alias::Service),
auth_chain: build!(rooms::auth_chain::Service),
@ -84,7 +87,6 @@ impl Services {
media: build!(media::Service),
sending: build!(sending::Service),
updates: build!(updates::Service),
globals: build!(globals::Service),
manager: Mutex::new(None),
service,
server,