diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 2a2554b5..06cd8ba9 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result) -> Result { - use service::sending::CachedDest; + use service::resolver::CachedDest; let mut out = String::new(); writeln!(out, "| Server Name | Destination | Hostname | Expires |")?; @@ -36,12 +36,7 @@ async fn destinations_cache(server_name: Option) -> Result) -> Result) -> Result { - use service::sending::CachedOverride; + use service::resolver::CachedOverride; let mut out = String::new(); writeln!(out, "| Server Name | IP | Port | Expires |")?; @@ -70,12 +65,7 @@ async fn overrides_cache(server_name: Option) -> Result) -> Self { + pub fn new(config: &Config, resolver: &Arc) -> Self { Self { default: Self::base(config) .unwrap() diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 75cae822..a2fba4d9 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -2,7 +2,6 @@ mod client; mod data; mod emerg_access; pub(super) mod migrations; -pub(crate) mod resolver; use std::{ collections::{BTreeMap, HashMap}, @@ -25,7 +24,7 @@ use ruma::{ use tokio::sync::Mutex; use url::Url; -use crate::services; +use crate::{resolver, service, services}; pub struct Service { pub db: Data, @@ -34,7 +33,6 @@ pub struct Service { pub cidr_range_denylist: Vec, keypair: Arc, jwt_decoding_key: Option, - pub resolver: Arc, pub client: client::Client, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, @@ -68,8 +66,6 @@ impl crate::Service for Service { .as_ref() .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); - let resolver = Arc::new(resolver::Resolver::new(config)); - // Supported and stable room versions let stable_room_versions = vec![ RoomVersionId::V6, @@ -89,12 +85,14 @@ impl crate::Service for Service { cidr_range_denylist.push(cidr); } + let resolver = service::get::(args.service, "resolver") + .expect("resolver must be built prior to globals"); + let mut s = Self { db, config: config.clone(), cidr_range_denylist, keypair: Arc::new(keypair), - resolver: resolver.clone(), client: client::Client::new(config, &resolver), jwt_decoding_key, stable_room_versions, @@ -126,8 +124,6 @@ impl crate::Service for Service { } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - self.resolver.memory_usage(out)?; - let bad_event_ratelimiter = self .bad_event_ratelimiter .read() @@ -146,8 +142,6 @@ impl crate::Service for Service { } fn clear_cache(&self) { - self.resolver.clear_cache(); - self.bad_event_ratelimiter .write() .expect("locked for writing") @@ -159,7 +153,7 @@ impl crate::Service for Service { .clear(); } - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } + fn name(&self) -> &str { service::make_name(std::module_path!()) } } impl Service { diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs deleted file mode 100644 index 3002decf..00000000 --- a/src/service/globals/resolver.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::{ - collections::HashMap, - fmt::Write, - future, iter, - net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, - time::Duration, -}; - -use conduit::{error, Config, Result}; -use hickory_resolver::TokioAsyncResolver; -use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use ruma::OwnedServerName; - -use crate::sending::{CachedDest, CachedOverride}; - -type WellKnownMap = HashMap; -type TlsNameMap = HashMap; - -pub struct Resolver { - pub destinations: Arc>, // actual_destination, host - pub overrides: Arc>, - pub(crate) resolver: Arc, - pub(crate) hooked: Arc, -} - -pub(crate) struct Hooked { - overrides: Arc>, - resolver: Arc, -} - -impl Resolver { - #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] - pub(super) fn new(config: &Config) -> Self { - let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}")) - .expect("DNS system config must be valid"); - - let mut conf = hickory_resolver::config::ResolverConfig::new(); - - if let Some(domain) = sys_conf.domain() { - conf.set_domain(domain.clone()); - } - - for sys_conf in sys_conf.search() { - conf.add_search(sys_conf.clone()); - } - - for sys_conf in sys_conf.name_servers() { - let mut ns = sys_conf.clone(); - - if config.query_over_tcp_only { - ns.protocol = hickory_resolver::config::Protocol::Tcp; - } - - ns.trust_negative_responses = !config.query_all_nameservers; - - conf.add_name_server(ns); - } - - opts.cache_size = config.dns_cache_entries as usize; - opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); - opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); - opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); - opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); - opts.timeout = Duration::from_secs(config.dns_timeout); - opts.attempts = config.dns_attempts as usize; - opts.try_tcp_on_error = config.dns_tcp_fallback; - opts.num_concurrent_reqs = 1; - opts.shuffle_dns_servers = true; - opts.rotate = true; - opts.ip_strategy = match config.ip_lookup_strategy { - 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, - 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, - 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, - 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, - _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, - }; - opts.authentic_data = false; - - let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); - let overrides = Arc::new(RwLock::new(TlsNameMap::new())); - Self { - destinations: Arc::new(RwLock::new(WellKnownMap::new())), - overrides: overrides.clone(), - resolver: resolver.clone(), - hooked: Arc::new(Hooked { - overrides, - resolver, - }), - } - } - - pub(super) fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); - writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; - - let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); - writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; - - Ok(()) - } - - pub(super) fn clear_cache(&self) { - self.overrides.write().expect("write locked").clear(); - self.destinations.write().expect("write locked").clear(); - self.resolver.clear_cache(); - } -} - -impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } -} - -impl Resolve for Hooked { - fn resolve(&self, name: Name) -> Resolving { - let cached = self - .overrides - .read() - .expect("locked for reading") - .get(name.as_str()) - .filter(|cached| cached.valid()) - .cloned(); - - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } - } -} - -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name - .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") -} - -fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); - - let results: Addrs = Box::new(results); - - Ok(results) - }) -} diff --git a/src/service/mod.rs b/src/service/mod.rs index 81e0be3b..870d865b 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -12,6 +12,7 @@ pub mod key_backups; pub mod media; pub mod presence; pub mod pusher; +pub mod resolver; pub mod rooms; pub mod sending; pub mod transaction_ids; diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs new file mode 100644 index 00000000..62fd1625 --- /dev/null +++ b/src/service/resolver/mod.rs @@ -0,0 +1,349 @@ +use std::{ + collections::HashMap, + fmt, + fmt::Write, + future, iter, + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; + +use conduit::{err, trace, Result}; +use hickory_resolver::TokioAsyncResolver; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; +use ruma::{OwnedServerName, ServerName}; + +use crate::utils::rand; + +pub struct Service { + pub destinations: Arc>, // actual_destination, host + pub overrides: Arc>, + pub(crate) resolver: Arc, + pub(crate) hooked: Arc, +} + +pub(crate) struct Hooked { + overrides: Arc>, + resolver: Arc, +} + +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, + pub expire: SystemTime, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec, + pub port: u16, + pub expire: SystemTime, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +type WellKnownMap = HashMap; +type TlsNameMap = HashMap; + +impl crate::Service for Service { + #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] + fn build(args: crate::Args<'_>) -> Result> { + let config = &args.server.config; + let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() + .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; + + let mut conf = hickory_resolver::config::ResolverConfig::new(); + + if let Some(domain) = sys_conf.domain() { + conf.set_domain(domain.clone()); + } + + for sys_conf in sys_conf.search() { + conf.add_search(sys_conf.clone()); + } + + for sys_conf in sys_conf.name_servers() { + let mut ns = sys_conf.clone(); + + if config.query_over_tcp_only { + ns.protocol = hickory_resolver::config::Protocol::Tcp; + } + + ns.trust_negative_responses = !config.query_all_nameservers; + + conf.add_name_server(ns); + } + + opts.cache_size = config.dns_cache_entries as usize; + opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); + opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); + opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); + opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); + opts.timeout = Duration::from_secs(config.dns_timeout); + opts.attempts = config.dns_attempts as usize; + opts.try_tcp_on_error = config.dns_tcp_fallback; + opts.num_concurrent_reqs = 1; + opts.shuffle_dns_servers = true; + opts.rotate = true; + opts.ip_strategy = match config.ip_lookup_strategy { + 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, + 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, + 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, + 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, + _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, + }; + opts.authentic_data = false; + + let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); + let overrides = Arc::new(RwLock::new(TlsNameMap::new())); + Ok(Arc::new(Self { + destinations: Arc::new(RwLock::new(WellKnownMap::new())), + overrides: overrides.clone(), + resolver: resolver.clone(), + hooked: Arc::new(Hooked { + overrides, + resolver, + }), + })) + } + + fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { + let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); + writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; + + let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); + writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; + + Ok(()) + } + + fn clear_cache(&self) { + self.overrides.write().expect("write locked").clear(); + self.destinations.write().expect("write locked").clear(); + self.resolver.clear_cache(); + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { + trace!(?name, ?dest, "set cached destination"); + self.destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + #[must_use] + pub fn get_cached_destination(&self, name: &ServerName) -> Option { + self.destinations + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { + trace!(?name, ?over, "set cached override"); + self.overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + #[must_use] + pub fn get_cached_override(&self, name: &str) -> Option { + self.overrides + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + #[must_use] + pub fn has_cached_override(&self, name: &str) -> bool { + self.overrides + .read() + .expect("locked for reading") + .contains_key(name) + } +} + +impl Resolve for Service { + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } +} + +impl Resolve for Hooked { + fn resolve(&self, name: Name) -> Resolving { + let cached = self + .overrides + .read() + .expect("locked for reading") + .get(name.as_str()) + .cloned(); + + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) + } else { + resolve_to_reqwest(self.resolver.clone(), name) + } + } +} + +fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { + override_name + .first() + .map(|first_name| -> Resolving { + let saddr = SocketAddr::new(*first_name, port); + let result: Box + Send> = Box::new(iter::once(saddr)); + Box::pin(future::ready(Ok(result))) + }) + .expect("must provide at least one override name") +} + +fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { + Box::pin(async move { + let results = resolver + .lookup_ip(name.as_str()) + .await? + .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); + + let results: Addrs = Box::new(results); + + Ok(results) + }) +} + +impl CachedDest { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } +} + +impl CachedOverride { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } +} + +pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { + if let Ok(dest) = dest_str.parse::() { + Some(FedDest::Literal(dest)) + } else if let Ok(ip_addr) = dest_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { + let (host, port) = match dest_str.find(':') { + None => (dest_str, ":8448"), + Some(pos) => dest_str.split_at(pos), + }; + + FedDest::Named(host.to_owned(), port.to_owned()) +} + +impl FedDest { + pub(crate) fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + pub(crate) fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => format!("{host}{port}"), + } + } + + pub(crate) fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + #[inline] + #[allow(clippy::string_slice)] + pub(crate) fn port(&self) -> Option { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +impl fmt::Display for FedDest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Named(host, port) => write!(f, "{host}{port}"), + Self::Literal(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ); + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); + } +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 88b8b189..fd9acd3d 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -7,7 +7,7 @@ mod sender; use std::fmt::Debug; use conduit::{err, Result}; -pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest}; +pub use resolve::resolve_actual_dest; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 48dddf35..4c1e4758 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -1,40 +1,17 @@ use std::{ - fmt, fmt::Debug, net::{IpAddr, SocketAddr}, - time::SystemTime, }; -use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result}; +use conduit::{debug, debug_error, debug_info, debug_warn, trace, Err, Error, Result}; use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use ipaddress::IPAddress; -use ruma::{OwnedServerName, ServerName}; +use ruma::ServerName; -use crate::services; - -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A `FedDest::Named` might contain an IP address in string form if there -/// was no port specified to construct a `SocketAddr` with. -/// -/// # Examples: -/// ```rust -/// # use conduit_service::sending::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), String::new()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), String::new()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} +use crate::{ + resolver::{add_port_to_hostname, get_ip_with_port, CachedDest, CachedOverride, FedDest}, + services, +}; #[derive(Clone, Debug)] pub(crate) struct ActualDest { @@ -44,27 +21,10 @@ pub(crate) struct ActualDest { pub(crate) cached: bool, } -#[derive(Clone, Debug)] -pub struct CachedDest { - pub dest: FedDest, - pub host: String, - pub expire: SystemTime, -} - -#[derive(Clone, Debug)] -pub struct CachedOverride { - pub ips: Vec, - pub port: u16, - pub expire: SystemTime, -} - #[tracing::instrument(skip_all, name = "resolve")] pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result { let cached; - let cached_result = services() - .globals - .resolver - .get_cached_destination(server_name); + let cached_result = services().resolver.get_cached_destination(server_name); let CachedDest { dest, @@ -213,7 +173,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(dest: &str) -> Result> { trace!("Requesting well known for {dest}"); - if !services().globals.resolver.has_cached_override(dest) { + if !services().resolver.has_cached_override(dest) { query_and_cache_override(dest, dest, 8448).await?; } @@ -273,7 +233,6 @@ async fn conditional_query_and_cache_override(overname: &str, hostname: &str, po #[tracing::instrument(skip_all, name = "ip")] async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { match services() - .globals .resolver .resolver .lookup_ip(hostname.to_owned()) @@ -285,7 +244,7 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 debug_info!("{overname:?} overriden by {hostname:?}"); } - services().globals.resolver.set_cached_override( + services().resolver.set_cached_override( overname.to_owned(), CachedOverride { ips: override_ip.iter().collect(), @@ -314,7 +273,6 @@ async fn query_srv_record(hostname: &'_ str) -> Result> { debug!("querying SRV for {:?}", hostname); let hostname = hostname.trim_end_matches('.'); services() - .globals .resolver .resolver .srv_lookup(hostname.to_owned()) @@ -384,166 +342,3 @@ pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { Ok(()) } - -fn get_ip_with_port(dest_str: &str) -> Option { - if let Ok(dest) = dest_str.parse::() { - Some(FedDest::Literal(dest)) - } else if let Ok(ip_addr) = dest_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), - }; - - FedDest::Named(host.to_owned(), port.to_owned()) -} - -impl crate::globals::resolver::Resolver { - pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { - trace!(?name, ?dest, "set cached destination"); - self.destinations - .write() - .expect("locked for writing") - .insert(name, dest) - } - - pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option { - self.destinations - .read() - .expect("locked for reading") - .get(name) - .filter(|cached| cached.valid()) - .cloned() - } - - pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { - trace!(?name, ?over, "set cached override"); - self.overrides - .write() - .expect("locked for writing") - .insert(name, over) - } - - pub(crate) fn has_cached_override(&self, name: &str) -> bool { - self.overrides - .read() - .expect("locked for reading") - .get(name) - .filter(|cached| cached.valid()) - .is_some() - } -} - -impl CachedDest { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } -} - -impl CachedOverride { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } -} - -impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - #[inline] - #[allow(clippy::string_slice)] - fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -impl fmt::Display for FedDest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } -} diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 18a98828..dc4541d6 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -15,11 +15,8 @@ use ruma::{ }; use tracing::{debug, trace}; -use super::{ - resolve, - resolve::{ActualDest, CachedDest}, -}; -use crate::{debug_error, debug_warn, services, Error, Result}; +use super::{resolve, resolve::ActualDest}; +use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result}; #[tracing::instrument(skip_all, name = "send")] pub async fn send(client: &Client, dest: &ServerName, req: T) -> Result @@ -109,7 +106,7 @@ where let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && !actual.cached { - services().globals.resolver.set_cached_destination( + services().resolver.set_cached_destination( dest.to_owned(), CachedDest { dest: actual.dest.clone(), diff --git a/src/service/service.rs b/src/service/service.rs index ab41753a..8b49d455 100644 --- a/src/service/service.rs +++ b/src/service/service.rs @@ -37,7 +37,7 @@ pub(crate) trait Service: Any + Send + Sync { pub(crate) struct Args<'a> { pub(crate) server: &'a Arc, pub(crate) db: &'a Arc, - pub(crate) _service: &'a Map, + pub(crate) service: &'a Map, } pub(crate) type Map = BTreeMap; diff --git a/src/service/services.rs b/src/service/services.rs index 0aac10dc..fb13f24a 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -7,12 +7,14 @@ use tokio::sync::Mutex; use crate::{ account_data, admin, appservice, globals, key_backups, manager::Manager, - media, presence, pusher, rooms, sending, service, + media, presence, pusher, resolver, rooms, sending, service, service::{Args, Map, Service}, transaction_ids, uiaa, updates, users, }; pub struct Services { + pub resolver: Arc, + pub globals: Arc, pub rooms: rooms::Service, pub appservice: Arc, pub pusher: Arc, @@ -26,7 +28,6 @@ pub struct Services { pub media: Arc, pub sending: Arc, pub updates: Arc, - pub globals: Arc, manager: Mutex>>, pub(crate) service: Map, @@ -42,7 +43,7 @@ impl Services { let built = <$tyname>::build(Args { server: &server, db: &db, - _service: &service, + service: &service, })?; service.insert(built.name().to_owned(), (built.clone(), built.clone())); built @@ -50,6 +51,8 @@ impl Services { } Ok(Self { + resolver: build!(resolver::Service), + globals: build!(globals::Service), rooms: rooms::Service { alias: build!(rooms::alias::Service), auth_chain: build!(rooms::auth_chain::Service), @@ -84,7 +87,6 @@ impl Services { media: build!(media::Service), sending: build!(sending::Service), updates: build!(updates::Service), - globals: build!(globals::Service), manager: Mutex::new(None), service, server,