diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 979f671f..7e1ccb5e 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -196,8 +196,6 @@ impl Service { pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } - pub fn actual_destinations(&self) -> &Arc> { &self.resolver.destinations } - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } pub fn turn_password(&self) -> &String { &self.config.turn_password } diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs index 7cf88cc0..86fa6700 100644 --- a/src/service/globals/resolver.rs +++ b/src/service/globals/resolver.rs @@ -2,7 +2,7 @@ use std::{ collections::HashMap, future, iter, net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock as StdRwLock}, + sync::{Arc, RwLock}, time::Duration, }; @@ -10,22 +10,21 @@ use conduit::{error, Config, Error}; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::OwnedServerName; -use tokio::sync::RwLock; use crate::sending::FedDest; -pub(crate) type WellKnownMap = HashMap; +type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; pub struct Resolver { pub destinations: Arc>, // actual_destination, host - pub overrides: Arc>, + pub overrides: Arc>, pub resolver: Arc, pub hooked: Arc, } pub struct Hooked { - pub overrides: Arc>, + pub overrides: Arc>, pub resolver: Arc, } @@ -34,10 +33,10 @@ impl Resolver { pub fn new(config: &Config) -> Self { let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() .map_err(|e| { - error!("Failed to set up hickory dns resolver with system config: {}", e); + error!("Failed to set up hickory dns resolver with system config: {e}"); Error::bad_config("Failed to set up hickory dns resolver with system config.") }) - .unwrap(); + .expect("DNS system config must be valid"); let mut conf = hickory_resolver::config::ResolverConfig::new(); @@ -82,7 +81,7 @@ impl Resolver { opts.authentic_data = false; let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); - let overrides = Arc::new(StdRwLock::new(TlsNameMap::new())); + let overrides = Arc::new(RwLock::new(TlsNameMap::new())); Self { destinations: Arc::new(RwLock::new(WellKnownMap::new())), overrides: overrides.clone(), @@ -101,7 +100,13 @@ impl Resolve for Resolver { impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let addr_port = self.overrides.read().unwrap().get(name.as_str()).cloned(); + let addr_port = self + .overrides + .read() + .expect("locked for reading") + .get(name.as_str()) + .cloned(); + if let Some((addr, port)) = addr_port { cached_to_reqwest(&addr, port) } else { @@ -118,7 +123,7 @@ fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { let result: Box + Send> = Box::new(iter::once(saddr)); Box::pin(future::ready(Ok(result))) }) - .unwrap() + .expect("must provide at least one override name") } fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs index 23080bed..294ac09f 100644 --- a/src/service/sending/resolve.rs +++ b/src/service/sending/resolve.rs @@ -47,9 +47,10 @@ pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result 6 { - self.globals.resolver.overrides.write().unwrap().clear(); - self.globals.resolver.destinations.write().await.clear(); + self.globals + .resolver + .overrides + .write() + .expect("locked for writing") + .clear(); + self.globals + .resolver + .destinations + .write() + .expect("locked for writing") + .clear(); } if amount > 7 { self.globals.resolver.resolver.clear_cache();