convert Resolver into a Service.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
2fd6f6b0ff
commit
f465d77ad3
11 changed files with 381 additions and 409 deletions
|
@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result<RoomMessageEventCon
|
|||
}
|
||||
|
||||
async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<RoomMessageEventContent> {
|
||||
use service::sending::CachedDest;
|
||||
use service::resolver::CachedDest;
|
||||
|
||||
let mut out = String::new();
|
||||
writeln!(out, "| Server Name | Destination | Hostname | Expires |")?;
|
||||
|
@ -36,12 +36,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room
|
|||
writeln!(out, "| {name} | {dest} | {host} | {expire} |").expect("wrote line");
|
||||
};
|
||||
|
||||
let map = services()
|
||||
.globals
|
||||
.resolver
|
||||
.destinations
|
||||
.read()
|
||||
.expect("locked");
|
||||
let map = services().resolver.destinations.read().expect("locked");
|
||||
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
map.get_key_value(server_name).map(row);
|
||||
|
@ -53,7 +48,7 @@ async fn destinations_cache(server_name: Option<OwnedServerName>) -> Result<Room
|
|||
}
|
||||
|
||||
async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEventContent> {
|
||||
use service::sending::CachedOverride;
|
||||
use service::resolver::CachedOverride;
|
||||
|
||||
let mut out = String::new();
|
||||
writeln!(out, "| Server Name | IP | Port | Expires |")?;
|
||||
|
@ -70,12 +65,7 @@ async fn overrides_cache(server_name: Option<String>) -> Result<RoomMessageEvent
|
|||
writeln!(out, "| {name} | {ips:?} | {port} | {expire} |").expect("wrote line");
|
||||
};
|
||||
|
||||
let map = services()
|
||||
.globals
|
||||
.resolver
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked");
|
||||
let map = services().resolver.overrides.read().expect("locked");
|
||||
|
||||
if let Some(server_name) = server_name.as_ref() {
|
||||
map.get_key_value(server_name).map(row);
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration};
|
|||
|
||||
use reqwest::redirect;
|
||||
|
||||
use crate::{globals::resolver, Config, Result};
|
||||
use crate::{resolver, Config, Result};
|
||||
|
||||
pub struct Client {
|
||||
pub default: reqwest::Client,
|
||||
|
@ -15,7 +15,7 @@ pub struct Client {
|
|||
}
|
||||
|
||||
impl Client {
|
||||
pub fn new(config: &Config, resolver: &Arc<resolver::Resolver>) -> Self {
|
||||
pub fn new(config: &Config, resolver: &Arc<resolver::Service>) -> Self {
|
||||
Self {
|
||||
default: Self::base(config)
|
||||
.unwrap()
|
||||
|
|
|
@ -2,7 +2,6 @@ mod client;
|
|||
mod data;
|
||||
mod emerg_access;
|
||||
pub(super) mod migrations;
|
||||
pub(crate) mod resolver;
|
||||
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
|
@ -25,7 +24,7 @@ use ruma::{
|
|||
use tokio::sync::Mutex;
|
||||
use url::Url;
|
||||
|
||||
use crate::services;
|
||||
use crate::{resolver, service, services};
|
||||
|
||||
pub struct Service {
|
||||
pub db: Data,
|
||||
|
@ -34,7 +33,6 @@ pub struct Service {
|
|||
pub cidr_range_denylist: Vec<IPAddress>,
|
||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||
jwt_decoding_key: Option<jsonwebtoken::DecodingKey>,
|
||||
pub resolver: Arc<resolver::Resolver>,
|
||||
pub client: client::Client,
|
||||
pub stable_room_versions: Vec<RoomVersionId>,
|
||||
pub unstable_room_versions: Vec<RoomVersionId>,
|
||||
|
@ -68,8 +66,6 @@ impl crate::Service for Service {
|
|||
.as_ref()
|
||||
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
||||
|
||||
let resolver = Arc::new(resolver::Resolver::new(config));
|
||||
|
||||
// Supported and stable room versions
|
||||
let stable_room_versions = vec![
|
||||
RoomVersionId::V6,
|
||||
|
@ -89,12 +85,14 @@ impl crate::Service for Service {
|
|||
cidr_range_denylist.push(cidr);
|
||||
}
|
||||
|
||||
let resolver = service::get::<resolver::Service>(args.service, "resolver")
|
||||
.expect("resolver must be built prior to globals");
|
||||
|
||||
let mut s = Self {
|
||||
db,
|
||||
config: config.clone(),
|
||||
cidr_range_denylist,
|
||||
keypair: Arc::new(keypair),
|
||||
resolver: resolver.clone(),
|
||||
client: client::Client::new(config, &resolver),
|
||||
jwt_decoding_key,
|
||||
stable_room_versions,
|
||||
|
@ -126,8 +124,6 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
|
||||
self.resolver.memory_usage(out)?;
|
||||
|
||||
let bad_event_ratelimiter = self
|
||||
.bad_event_ratelimiter
|
||||
.read()
|
||||
|
@ -146,8 +142,6 @@ impl crate::Service for Service {
|
|||
}
|
||||
|
||||
fn clear_cache(&self) {
|
||||
self.resolver.clear_cache();
|
||||
|
||||
self.bad_event_ratelimiter
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
|
@ -159,7 +153,7 @@ impl crate::Service for Service {
|
|||
.clear();
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
fn name(&self) -> &str { service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
impl Service {
|
||||
|
|
|
@ -1,156 +0,0 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
fmt::Write,
|
||||
future, iter,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{Arc, RwLock},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use conduit::{error, Config, Result};
|
||||
use hickory_resolver::TokioAsyncResolver;
|
||||
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
||||
use ruma::OwnedServerName;
|
||||
|
||||
use crate::sending::{CachedDest, CachedOverride};
|
||||
|
||||
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
|
||||
type TlsNameMap = HashMap<String, CachedOverride>;
|
||||
|
||||
pub struct Resolver {
|
||||
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
||||
pub(crate) resolver: Arc<TokioAsyncResolver>,
|
||||
pub(crate) hooked: Arc<Hooked>,
|
||||
}
|
||||
|
||||
pub(crate) struct Hooked {
|
||||
overrides: Arc<RwLock<TlsNameMap>>,
|
||||
resolver: Arc<TokioAsyncResolver>,
|
||||
}
|
||||
|
||||
impl Resolver {
|
||||
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
pub(super) fn new(config: &Config) -> Self {
|
||||
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
|
||||
.inspect_err(|e| error!("Failed to set up hickory dns resolver with system config: {e}"))
|
||||
.expect("DNS system config must be valid");
|
||||
|
||||
let mut conf = hickory_resolver::config::ResolverConfig::new();
|
||||
|
||||
if let Some(domain) = sys_conf.domain() {
|
||||
conf.set_domain(domain.clone());
|
||||
}
|
||||
|
||||
for sys_conf in sys_conf.search() {
|
||||
conf.add_search(sys_conf.clone());
|
||||
}
|
||||
|
||||
for sys_conf in sys_conf.name_servers() {
|
||||
let mut ns = sys_conf.clone();
|
||||
|
||||
if config.query_over_tcp_only {
|
||||
ns.protocol = hickory_resolver::config::Protocol::Tcp;
|
||||
}
|
||||
|
||||
ns.trust_negative_responses = !config.query_all_nameservers;
|
||||
|
||||
conf.add_name_server(ns);
|
||||
}
|
||||
|
||||
opts.cache_size = config.dns_cache_entries as usize;
|
||||
opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
|
||||
opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30));
|
||||
opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
|
||||
opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7));
|
||||
opts.timeout = Duration::from_secs(config.dns_timeout);
|
||||
opts.attempts = config.dns_attempts as usize;
|
||||
opts.try_tcp_on_error = config.dns_tcp_fallback;
|
||||
opts.num_concurrent_reqs = 1;
|
||||
opts.shuffle_dns_servers = true;
|
||||
opts.rotate = true;
|
||||
opts.ip_strategy = match config.ip_lookup_strategy {
|
||||
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
|
||||
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
|
||||
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
|
||||
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
|
||||
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
|
||||
};
|
||||
opts.authentic_data = false;
|
||||
|
||||
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
|
||||
let overrides = Arc::new(RwLock::new(TlsNameMap::new()));
|
||||
Self {
|
||||
destinations: Arc::new(RwLock::new(WellKnownMap::new())),
|
||||
overrides: overrides.clone(),
|
||||
resolver: resolver.clone(),
|
||||
hooked: Arc::new(Hooked {
|
||||
overrides,
|
||||
resolver,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
|
||||
let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len();
|
||||
writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?;
|
||||
|
||||
let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len();
|
||||
writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(super) fn clear_cache(&self) {
|
||||
self.overrides.write().expect("write locked").clear();
|
||||
self.destinations.write().expect("write locked").clear();
|
||||
self.resolver.clear_cache();
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for Resolver {
|
||||
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
|
||||
}
|
||||
|
||||
impl Resolve for Hooked {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let cached = self
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name.as_str())
|
||||
.filter(|cached| cached.valid())
|
||||
.cloned();
|
||||
|
||||
if let Some(cached) = cached {
|
||||
cached_to_reqwest(&cached.ips, cached.port)
|
||||
} else {
|
||||
resolve_to_reqwest(self.resolver.clone(), name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
|
||||
override_name
|
||||
.first()
|
||||
.map(|first_name| -> Resolving {
|
||||
let saddr = SocketAddr::new(*first_name, port);
|
||||
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
|
||||
Box::pin(future::ready(Ok(result)))
|
||||
})
|
||||
.expect("must provide at least one override name")
|
||||
}
|
||||
|
||||
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {
|
||||
Box::pin(async move {
|
||||
let results = resolver
|
||||
.lookup_ip(name.as_str())
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|ip| SocketAddr::new(ip, 0));
|
||||
|
||||
let results: Addrs = Box::new(results);
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
|
@ -12,6 +12,7 @@ pub mod key_backups;
|
|||
pub mod media;
|
||||
pub mod presence;
|
||||
pub mod pusher;
|
||||
pub mod resolver;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
pub mod transaction_ids;
|
||||
|
|
349
src/service/resolver/mod.rs
Normal file
349
src/service/resolver/mod.rs
Normal file
|
@ -0,0 +1,349 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
fmt::Write,
|
||||
future, iter,
|
||||
net::{IpAddr, SocketAddr},
|
||||
sync::{Arc, RwLock},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use conduit::{err, trace, Result};
|
||||
use hickory_resolver::TokioAsyncResolver;
|
||||
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
|
||||
use crate::utils::rand;
|
||||
|
||||
pub struct Service {
|
||||
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub overrides: Arc<RwLock<TlsNameMap>>,
|
||||
pub(crate) resolver: Arc<TokioAsyncResolver>,
|
||||
pub(crate) hooked: Arc<Hooked>,
|
||||
}
|
||||
|
||||
pub(crate) struct Hooked {
|
||||
overrides: Arc<RwLock<TlsNameMap>>,
|
||||
resolver: Arc<TokioAsyncResolver>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedDest {
|
||||
pub dest: FedDest,
|
||||
pub host: String,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedOverride {
|
||||
pub ips: Vec<IpAddr>,
|
||||
pub port: u16,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum FedDest {
|
||||
Literal(SocketAddr),
|
||||
Named(String, String),
|
||||
}
|
||||
|
||||
type WellKnownMap = HashMap<OwnedServerName, CachedDest>;
|
||||
type TlsNameMap = HashMap<String, CachedOverride>;
|
||||
|
||||
impl crate::Service for Service {
|
||||
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let config = &args.server.config;
|
||||
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
|
||||
.map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?;
|
||||
|
||||
let mut conf = hickory_resolver::config::ResolverConfig::new();
|
||||
|
||||
if let Some(domain) = sys_conf.domain() {
|
||||
conf.set_domain(domain.clone());
|
||||
}
|
||||
|
||||
for sys_conf in sys_conf.search() {
|
||||
conf.add_search(sys_conf.clone());
|
||||
}
|
||||
|
||||
for sys_conf in sys_conf.name_servers() {
|
||||
let mut ns = sys_conf.clone();
|
||||
|
||||
if config.query_over_tcp_only {
|
||||
ns.protocol = hickory_resolver::config::Protocol::Tcp;
|
||||
}
|
||||
|
||||
ns.trust_negative_responses = !config.query_all_nameservers;
|
||||
|
||||
conf.add_name_server(ns);
|
||||
}
|
||||
|
||||
opts.cache_size = config.dns_cache_entries as usize;
|
||||
opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain));
|
||||
opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30));
|
||||
opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl));
|
||||
opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7));
|
||||
opts.timeout = Duration::from_secs(config.dns_timeout);
|
||||
opts.attempts = config.dns_attempts as usize;
|
||||
opts.try_tcp_on_error = config.dns_tcp_fallback;
|
||||
opts.num_concurrent_reqs = 1;
|
||||
opts.shuffle_dns_servers = true;
|
||||
opts.rotate = true;
|
||||
opts.ip_strategy = match config.ip_lookup_strategy {
|
||||
1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only,
|
||||
2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only,
|
||||
3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6,
|
||||
4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4,
|
||||
_ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6,
|
||||
};
|
||||
opts.authentic_data = false;
|
||||
|
||||
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
|
||||
let overrides = Arc::new(RwLock::new(TlsNameMap::new()));
|
||||
Ok(Arc::new(Self {
|
||||
destinations: Arc::new(RwLock::new(WellKnownMap::new())),
|
||||
overrides: overrides.clone(),
|
||||
resolver: resolver.clone(),
|
||||
hooked: Arc::new(Hooked {
|
||||
overrides,
|
||||
resolver,
|
||||
}),
|
||||
}))
|
||||
}
|
||||
|
||||
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
|
||||
let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len();
|
||||
writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?;
|
||||
|
||||
let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len();
|
||||
writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clear_cache(&self) {
|
||||
self.overrides.write().expect("write locked").clear();
|
||||
self.destinations.write().expect("write locked").clear();
|
||||
self.resolver.clear_cache();
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
|
||||
trace!(?name, ?dest, "set cached destination");
|
||||
self.destinations
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, dest)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
|
||||
self.destinations
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
|
||||
trace!(?name, ?over, "set cached override");
|
||||
self.overrides
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, over)
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn get_cached_override(&self, name: &str) -> Option<CachedOverride> {
|
||||
self.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn has_cached_override(&self, name: &str) -> bool {
|
||||
self.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.contains_key(name)
|
||||
}
|
||||
}
|
||||
|
||||
impl Resolve for Service {
|
||||
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
|
||||
}
|
||||
|
||||
impl Resolve for Hooked {
|
||||
fn resolve(&self, name: Name) -> Resolving {
|
||||
let cached = self
|
||||
.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name.as_str())
|
||||
.cloned();
|
||||
|
||||
if let Some(cached) = cached {
|
||||
cached_to_reqwest(&cached.ips, cached.port)
|
||||
} else {
|
||||
resolve_to_reqwest(self.resolver.clone(), name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
|
||||
override_name
|
||||
.first()
|
||||
.map(|first_name| -> Resolving {
|
||||
let saddr = SocketAddr::new(*first_name, port);
|
||||
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
|
||||
Box::pin(future::ready(Ok(result)))
|
||||
})
|
||||
.expect("must provide at least one override name")
|
||||
}
|
||||
|
||||
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {
|
||||
Box::pin(async move {
|
||||
let results = resolver
|
||||
.lookup_ip(name.as_str())
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|ip| SocketAddr::new(ip, 0));
|
||||
|
||||
let results: Addrs = Box::new(results);
|
||||
|
||||
Ok(results)
|
||||
})
|
||||
}
|
||||
|
||||
impl CachedDest {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool { true }
|
||||
|
||||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
|
||||
}
|
||||
|
||||
impl CachedOverride {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool { true }
|
||||
|
||||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
|
||||
}
|
||||
|
||||
pub(crate) fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
|
||||
if let Ok(dest) = dest_str.parse::<SocketAddr>() {
|
||||
Some(FedDest::Literal(dest))
|
||||
} else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() {
|
||||
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest {
|
||||
let (host, port) = match dest_str.find(':') {
|
||||
None => (dest_str, ":8448"),
|
||||
Some(pos) => dest_str.split_at(pos),
|
||||
};
|
||||
|
||||
FedDest::Named(host.to_owned(), port.to_owned())
|
||||
}
|
||||
|
||||
impl FedDest {
|
||||
pub(crate) fn into_https_string(self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => format!("https://{addr}"),
|
||||
Self::Named(host, port) => format!("https://{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn into_uri_string(self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => addr.to_string(),
|
||||
Self::Named(host, port) => format!("{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn hostname(&self) -> String {
|
||||
match &self {
|
||||
Self::Literal(addr) => addr.ip().to_string(),
|
||||
Self::Named(host, _) => host.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::string_slice)]
|
||||
pub(crate) fn port(&self) -> Option<u16> {
|
||||
match &self {
|
||||
Self::Literal(addr) => Some(addr.port()),
|
||||
Self::Named(_, port) => port[1..].parse().ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FedDest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Named(host, port) => write!(f, "{host}{port}"),
|
||||
Self::Literal(addr) => write!(f, "{addr}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{add_port_to_hostname, get_ip_with_port, FedDest};
|
||||
|
||||
#[test]
|
||||
fn ips_get_default_ports() {
|
||||
assert_eq!(
|
||||
get_ip_with_port("1.1.1.1"),
|
||||
Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap()))
|
||||
);
|
||||
assert_eq!(
|
||||
get_ip_with_port("dead:beef::"),
|
||||
Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ips_keep_custom_ports() {
|
||||
assert_eq!(
|
||||
get_ip_with_port("1.1.1.1:1234"),
|
||||
Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap()))
|
||||
);
|
||||
assert_eq!(
|
||||
get_ip_with_port("[dead::beef]:8933"),
|
||||
Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hostnames_get_default_ports() {
|
||||
assert_eq!(
|
||||
add_port_to_hostname("example.com"),
|
||||
FedDest::Named(String::from("example.com"), String::from(":8448"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hostnames_keep_custom_ports() {
|
||||
assert_eq!(
|
||||
add_port_to_hostname("example.com:1337"),
|
||||
FedDest::Named(String::from("example.com"), String::from(":1337"))
|
||||
);
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@ mod sender;
|
|||
use std::fmt::Debug;
|
||||
|
||||
use conduit::{err, Result};
|
||||
pub use resolve::{resolve_actual_dest, CachedDest, CachedOverride, FedDest};
|
||||
pub use resolve::resolve_actual_dest;
|
||||
use ruma::{
|
||||
api::{appservice::Registration, OutgoingRequest},
|
||||
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
|
||||
|
|
|
@ -1,40 +1,17 @@
|
|||
use std::{
|
||||
fmt,
|
||||
fmt::Debug,
|
||||
net::{IpAddr, SocketAddr},
|
||||
time::SystemTime,
|
||||
};
|
||||
|
||||
use conduit::{debug, debug_error, debug_info, debug_warn, trace, utils::rand, Err, Error, Result};
|
||||
use conduit::{debug, debug_error, debug_info, debug_warn, trace, Err, Error, Result};
|
||||
use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
|
||||
use ipaddress::IPAddress;
|
||||
use ruma::{OwnedServerName, ServerName};
|
||||
use ruma::ServerName;
|
||||
|
||||
use crate::services;
|
||||
|
||||
/// Wraps either an literal IP address plus port, or a hostname plus complement
|
||||
/// (colon-plus-port if it was specified).
|
||||
///
|
||||
/// Note: A `FedDest::Named` might contain an IP address in string form if there
|
||||
/// was no port specified to construct a `SocketAddr` with.
|
||||
///
|
||||
/// # Examples:
|
||||
/// ```rust
|
||||
/// # use conduit_service::sending::FedDest;
|
||||
/// # fn main() -> Result<(), std::net::AddrParseError> {
|
||||
/// FedDest::Literal("198.51.100.3:8448".parse()?);
|
||||
/// FedDest::Literal("[2001:db8::4:5]:443".parse()?);
|
||||
/// FedDest::Named("matrix.example.org".to_owned(), String::new());
|
||||
/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned());
|
||||
/// FedDest::Named("198.51.100.5".to_owned(), String::new());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||
pub enum FedDest {
|
||||
Literal(SocketAddr),
|
||||
Named(String, String),
|
||||
}
|
||||
use crate::{
|
||||
resolver::{add_port_to_hostname, get_ip_with_port, CachedDest, CachedOverride, FedDest},
|
||||
services,
|
||||
};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ActualDest {
|
||||
|
@ -44,27 +21,10 @@ pub(crate) struct ActualDest {
|
|||
pub(crate) cached: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedDest {
|
||||
pub dest: FedDest,
|
||||
pub host: String,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CachedOverride {
|
||||
pub ips: Vec<IpAddr>,
|
||||
pub port: u16,
|
||||
pub expire: SystemTime,
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "resolve")]
|
||||
pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDest> {
|
||||
let cached;
|
||||
let cached_result = services()
|
||||
.globals
|
||||
.resolver
|
||||
.get_cached_destination(server_name);
|
||||
let cached_result = services().resolver.get_cached_destination(server_name);
|
||||
|
||||
let CachedDest {
|
||||
dest,
|
||||
|
@ -213,7 +173,7 @@ async fn actual_dest_5(dest: &ServerName, cache: bool) -> Result<FedDest> {
|
|||
#[tracing::instrument(skip_all, name = "well-known")]
|
||||
async fn request_well_known(dest: &str) -> Result<Option<String>> {
|
||||
trace!("Requesting well known for {dest}");
|
||||
if !services().globals.resolver.has_cached_override(dest) {
|
||||
if !services().resolver.has_cached_override(dest) {
|
||||
query_and_cache_override(dest, dest, 8448).await?;
|
||||
}
|
||||
|
||||
|
@ -273,7 +233,6 @@ async fn conditional_query_and_cache_override(overname: &str, hostname: &str, po
|
|||
#[tracing::instrument(skip_all, name = "ip")]
|
||||
async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
|
||||
match services()
|
||||
.globals
|
||||
.resolver
|
||||
.resolver
|
||||
.lookup_ip(hostname.to_owned())
|
||||
|
@ -285,7 +244,7 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1
|
|||
debug_info!("{overname:?} overriden by {hostname:?}");
|
||||
}
|
||||
|
||||
services().globals.resolver.set_cached_override(
|
||||
services().resolver.set_cached_override(
|
||||
overname.to_owned(),
|
||||
CachedOverride {
|
||||
ips: override_ip.iter().collect(),
|
||||
|
@ -314,7 +273,6 @@ async fn query_srv_record(hostname: &'_ str) -> Result<Option<FedDest>> {
|
|||
debug!("querying SRV for {:?}", hostname);
|
||||
let hostname = hostname.trim_end_matches('.');
|
||||
services()
|
||||
.globals
|
||||
.resolver
|
||||
.resolver
|
||||
.srv_lookup(hostname.to_owned())
|
||||
|
@ -384,166 +342,3 @@ pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_ip_with_port(dest_str: &str) -> Option<FedDest> {
|
||||
if let Ok(dest) = dest_str.parse::<SocketAddr>() {
|
||||
Some(FedDest::Literal(dest))
|
||||
} else if let Ok(ip_addr) = dest_str.parse::<IpAddr>() {
|
||||
Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn add_port_to_hostname(dest_str: &str) -> FedDest {
|
||||
let (host, port) = match dest_str.find(':') {
|
||||
None => (dest_str, ":8448"),
|
||||
Some(pos) => dest_str.split_at(pos),
|
||||
};
|
||||
|
||||
FedDest::Named(host.to_owned(), port.to_owned())
|
||||
}
|
||||
|
||||
impl crate::globals::resolver::Resolver {
|
||||
pub(crate) fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option<CachedDest> {
|
||||
trace!(?name, ?dest, "set cached destination");
|
||||
self.destinations
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, dest)
|
||||
}
|
||||
|
||||
pub(crate) fn get_cached_destination(&self, name: &ServerName) -> Option<CachedDest> {
|
||||
self.destinations
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.filter(|cached| cached.valid())
|
||||
.cloned()
|
||||
}
|
||||
|
||||
pub(crate) fn set_cached_override(&self, name: String, over: CachedOverride) -> Option<CachedOverride> {
|
||||
trace!(?name, ?over, "set cached override");
|
||||
self.overrides
|
||||
.write()
|
||||
.expect("locked for writing")
|
||||
.insert(name, over)
|
||||
}
|
||||
|
||||
pub(crate) fn has_cached_override(&self, name: &str) -> bool {
|
||||
self.overrides
|
||||
.read()
|
||||
.expect("locked for reading")
|
||||
.get(name)
|
||||
.filter(|cached| cached.valid())
|
||||
.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl CachedDest {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool { true }
|
||||
|
||||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) }
|
||||
}
|
||||
|
||||
impl CachedOverride {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
pub fn valid(&self) -> bool { true }
|
||||
|
||||
//pub fn valid(&self) -> bool { self.expire > SystemTime::now() }
|
||||
|
||||
#[must_use]
|
||||
pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) }
|
||||
}
|
||||
|
||||
impl FedDest {
|
||||
fn into_https_string(self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => format!("https://{addr}"),
|
||||
Self::Named(host, port) => format!("https://{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_uri_string(self) -> String {
|
||||
match self {
|
||||
Self::Literal(addr) => addr.to_string(),
|
||||
Self::Named(host, port) => format!("{host}{port}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn hostname(&self) -> String {
|
||||
match &self {
|
||||
Self::Literal(addr) => addr.ip().to_string(),
|
||||
Self::Named(host, _) => host.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
#[allow(clippy::string_slice)]
|
||||
fn port(&self) -> Option<u16> {
|
||||
match &self {
|
||||
Self::Literal(addr) => Some(addr.port()),
|
||||
Self::Named(_, port) => port[1..].parse().ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for FedDest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Named(host, port) => write!(f, "{host}{port}"),
|
||||
Self::Literal(addr) => write!(f, "{addr}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{add_port_to_hostname, get_ip_with_port, FedDest};
|
||||
|
||||
#[test]
|
||||
fn ips_get_default_ports() {
|
||||
assert_eq!(
|
||||
get_ip_with_port("1.1.1.1"),
|
||||
Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap()))
|
||||
);
|
||||
assert_eq!(
|
||||
get_ip_with_port("dead:beef::"),
|
||||
Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ips_keep_custom_ports() {
|
||||
assert_eq!(
|
||||
get_ip_with_port("1.1.1.1:1234"),
|
||||
Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap()))
|
||||
);
|
||||
assert_eq!(
|
||||
get_ip_with_port("[dead::beef]:8933"),
|
||||
Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hostnames_get_default_ports() {
|
||||
assert_eq!(
|
||||
add_port_to_hostname("example.com"),
|
||||
FedDest::Named(String::from("example.com"), String::from(":8448"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hostnames_keep_custom_ports() {
|
||||
assert_eq!(
|
||||
add_port_to_hostname("example.com:1337"),
|
||||
FedDest::Named(String::from("example.com"), String::from(":1337"))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,11 +15,8 @@ use ruma::{
|
|||
};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use super::{
|
||||
resolve,
|
||||
resolve::{ActualDest, CachedDest},
|
||||
};
|
||||
use crate::{debug_error, debug_warn, services, Error, Result};
|
||||
use super::{resolve, resolve::ActualDest};
|
||||
use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result};
|
||||
|
||||
#[tracing::instrument(skip_all, name = "send")]
|
||||
pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
|
||||
|
@ -109,7 +106,7 @@ where
|
|||
|
||||
let response = T::IncomingResponse::try_from_http_response(http_response);
|
||||
if response.is_ok() && !actual.cached {
|
||||
services().globals.resolver.set_cached_destination(
|
||||
services().resolver.set_cached_destination(
|
||||
dest.to_owned(),
|
||||
CachedDest {
|
||||
dest: actual.dest.clone(),
|
||||
|
|
|
@ -37,7 +37,7 @@ pub(crate) trait Service: Any + Send + Sync {
|
|||
pub(crate) struct Args<'a> {
|
||||
pub(crate) server: &'a Arc<Server>,
|
||||
pub(crate) db: &'a Arc<Database>,
|
||||
pub(crate) _service: &'a Map,
|
||||
pub(crate) service: &'a Map,
|
||||
}
|
||||
|
||||
pub(crate) type Map = BTreeMap<String, MapVal>;
|
||||
|
|
|
@ -7,12 +7,14 @@ use tokio::sync::Mutex;
|
|||
use crate::{
|
||||
account_data, admin, appservice, globals, key_backups,
|
||||
manager::Manager,
|
||||
media, presence, pusher, rooms, sending, service,
|
||||
media, presence, pusher, resolver, rooms, sending, service,
|
||||
service::{Args, Map, Service},
|
||||
transaction_ids, uiaa, updates, users,
|
||||
};
|
||||
|
||||
pub struct Services {
|
||||
pub resolver: Arc<resolver::Service>,
|
||||
pub globals: Arc<globals::Service>,
|
||||
pub rooms: rooms::Service,
|
||||
pub appservice: Arc<appservice::Service>,
|
||||
pub pusher: Arc<pusher::Service>,
|
||||
|
@ -26,7 +28,6 @@ pub struct Services {
|
|||
pub media: Arc<media::Service>,
|
||||
pub sending: Arc<sending::Service>,
|
||||
pub updates: Arc<updates::Service>,
|
||||
pub globals: Arc<globals::Service>,
|
||||
|
||||
manager: Mutex<Option<Arc<Manager>>>,
|
||||
pub(crate) service: Map,
|
||||
|
@ -42,7 +43,7 @@ impl Services {
|
|||
let built = <$tyname>::build(Args {
|
||||
server: &server,
|
||||
db: &db,
|
||||
_service: &service,
|
||||
service: &service,
|
||||
})?;
|
||||
service.insert(built.name().to_owned(), (built.clone(), built.clone()));
|
||||
built
|
||||
|
@ -50,6 +51,8 @@ impl Services {
|
|||
}
|
||||
|
||||
Ok(Self {
|
||||
resolver: build!(resolver::Service),
|
||||
globals: build!(globals::Service),
|
||||
rooms: rooms::Service {
|
||||
alias: build!(rooms::alias::Service),
|
||||
auth_chain: build!(rooms::auth_chain::Service),
|
||||
|
@ -84,7 +87,6 @@ impl Services {
|
|||
media: build!(media::Service),
|
||||
sending: build!(sending::Service),
|
||||
updates: build!(updates::Service),
|
||||
globals: build!(globals::Service),
|
||||
manager: Mutex::new(None),
|
||||
service,
|
||||
server,
|
||||
|
|
Loading…
Add table
Reference in a new issue