remove resolver wrapper; use std mutex

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-03 01:24:16 +00:00
parent be2d1c722b
commit 7658387a74
5 changed files with 47 additions and 20 deletions

View file

@ -196,8 +196,6 @@ impl Service {
pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver }
pub fn actual_destinations(&self) -> &Arc<RwLock<resolver::WellKnownMap>> { &self.resolver.destinations }
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
pub fn turn_password(&self) -> &String { &self.config.turn_password } pub fn turn_password(&self) -> &String { &self.config.turn_password }

View file

@ -2,7 +2,7 @@ use std::{
collections::HashMap, collections::HashMap,
future, iter, future, iter,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::{Arc, RwLock as StdRwLock}, sync::{Arc, RwLock},
time::Duration, time::Duration,
}; };
@ -10,22 +10,21 @@ use conduit::{error, Config, Error};
use hickory_resolver::TokioAsyncResolver; use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use ruma::OwnedServerName; use ruma::OwnedServerName;
use tokio::sync::RwLock;
use crate::sending::FedDest; use crate::sending::FedDest;
pub(crate) type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>; type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>; type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
pub struct Resolver { pub struct Resolver {
pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host pub destinations: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub overrides: Arc<StdRwLock<TlsNameMap>>, pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>, pub resolver: Arc<TokioAsyncResolver>,
pub hooked: Arc<Hooked>, pub hooked: Arc<Hooked>,
} }
pub struct Hooked { pub struct Hooked {
pub overrides: Arc<StdRwLock<TlsNameMap>>, pub overrides: Arc<RwLock<TlsNameMap>>,
pub resolver: Arc<TokioAsyncResolver>, pub resolver: Arc<TokioAsyncResolver>,
} }
@ -34,10 +33,10 @@ impl Resolver {
pub fn new(config: &Config) -> Self { pub fn new(config: &Config) -> Self {
let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf()
.map_err(|e| { .map_err(|e| {
error!("Failed to set up hickory dns resolver with system config: {}", e); error!("Failed to set up hickory dns resolver with system config: {e}");
Error::bad_config("Failed to set up hickory dns resolver with system config.") Error::bad_config("Failed to set up hickory dns resolver with system config.")
}) })
.unwrap(); .expect("DNS system config must be valid");
let mut conf = hickory_resolver::config::ResolverConfig::new(); let mut conf = hickory_resolver::config::ResolverConfig::new();
@ -82,7 +81,7 @@ impl Resolver {
opts.authentic_data = false; opts.authentic_data = false;
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
let overrides = Arc::new(StdRwLock::new(TlsNameMap::new())); let overrides = Arc::new(RwLock::new(TlsNameMap::new()));
Self { Self {
destinations: Arc::new(RwLock::new(WellKnownMap::new())), destinations: Arc::new(RwLock::new(WellKnownMap::new())),
overrides: overrides.clone(), overrides: overrides.clone(),
@ -101,7 +100,13 @@ impl Resolve for Resolver {
impl Resolve for Hooked { impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving { fn resolve(&self, name: Name) -> Resolving {
let addr_port = self.overrides.read().unwrap().get(name.as_str()).cloned(); let addr_port = self
.overrides
.read()
.expect("locked for reading")
.get(name.as_str())
.cloned();
if let Some((addr, port)) = addr_port { if let Some((addr, port)) = addr_port {
cached_to_reqwest(&addr, port) cached_to_reqwest(&addr, port)
} else { } else {
@ -118,7 +123,7 @@ fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr)); let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
Box::pin(future::ready(Ok(result))) Box::pin(future::ready(Ok(result)))
}) })
.unwrap() .expect("must provide at least one override name")
} }
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving { fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {

View file

@ -47,9 +47,10 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result<ActualDe
let cached; let cached;
let cached_result = services() let cached_result = services()
.globals .globals
.actual_destinations() .resolver
.destinations
.read() .read()
.await .expect("locked for reading")
.get(server_name) .get(server_name)
.cloned(); .cloned();

View file

@ -105,9 +105,10 @@ where
if response.is_ok() && !actual.cached { if response.is_ok() && !actual.cached {
services() services()
.globals .globals
.actual_destinations() .resolver
.destinations
.write() .write()
.await .expect("locked for writing")
.insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone())); .insert(OwnedServerName::from(dest), (actual.dest.clone(), actual.host.clone()));
} }

View file

@ -106,8 +106,20 @@ impl Services {
.lock() .lock()
.await .await
.len(); .len();
let resolver_overrides_cache = self.globals.resolver.overrides.read().unwrap().len(); let resolver_overrides_cache = self
let resolver_destinations_cache = self.globals.resolver.destinations.read().await.len(); .globals
.resolver
.overrides
.read()
.expect("locked for reading")
.len();
let resolver_destinations_cache = self
.globals
.resolver
.destinations
.read()
.expect("locked for reading")
.len();
let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len(); let bad_event_ratelimiter = self.globals.bad_event_ratelimiter.read().await.len();
let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len(); let bad_query_ratelimiter = self.globals.bad_query_ratelimiter.read().await.len();
let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len(); let bad_signature_ratelimiter = self.globals.bad_signature_ratelimiter.read().await.len();
@ -179,8 +191,18 @@ bad_signature_ratelimiter: {bad_signature_ratelimiter}
.clear(); .clear();
} }
if amount > 6 { if amount > 6 {
self.globals.resolver.overrides.write().unwrap().clear(); self.globals
self.globals.resolver.destinations.write().await.clear(); .resolver
.overrides
.write()
.expect("locked for writing")
.clear();
self.globals
.resolver
.destinations
.write()
.expect("locked for writing")
.clear();
} }
if amount > 7 { if amount > 7 {
self.globals.resolver.resolver.clear_cache(); self.globals.resolver.resolver.clear_cache();