From 9df5265c00c5083a3567c8be176ca149adf2c833 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 31 May 2024 00:14:22 +0000 Subject: [PATCH] split sending resolver into unit. Signed-off-by: Jason Volk --- src/admin/debug/debug_commands.rs | 2 +- src/service/sending/mod.rs | 13 +- src/service/sending/resolve.rs | 460 +++++++++++++++++++++++++++++ src/service/sending/send.rs | 465 +----------------------------- 4 files changed, 474 insertions(+), 466 deletions(-) create mode 100644 src/service/sending/resolve.rs diff --git a/src/admin/debug/debug_commands.rs b/src/admin/debug/debug_commands.rs index 74313089..873a23e7 100644 --- a/src/admin/debug/debug_commands.rs +++ b/src/admin/debug/debug_commands.rs @@ -5,7 +5,7 @@ use ruma::{ api::client::error::ErrorKind, events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, sending::send::resolve_actual_dest, services, PduEvent}; +use service::{rooms::event_handler::parse_incoming_pdu, sending::resolve::resolve_actual_dest, services, PduEvent}; use tokio::sync::RwLock; use tracing::{debug, info, warn}; use tracing_subscriber::EnvFilter; diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 654ec523..4ad40f67 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,6 +1,13 @@ +mod appservice; +mod data; +pub mod resolve; +mod send; +mod sender; + use std::{fmt::Debug, sync::Arc}; use data::Data; +pub use resolve::FedDest; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, @@ -10,12 +17,6 @@ use tracing::{error, warn}; use crate::{server_is_ours, services, Config, Error, Result}; -mod appservice; -mod data; -pub mod send; -pub mod sender; -pub use send::FedDest; - pub struct Service { pub db: Arc, diff --git a/src/service/sending/resolve.rs b/src/service/sending/resolve.rs new file mode 100644 index 00000000..d452a555 --- /dev/null +++ b/src/service/sending/resolve.rs @@ -0,0 +1,460 @@ +use std::{ + fmt::Debug, + net::{IpAddr, SocketAddr}, +}; + +use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; +use ipaddress::IPAddress; +use ruma::ServerName; +use tracing::{debug, error, trace}; + +use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; + +/// Wraps either an literal IP address plus port, or a hostname plus complement +/// (colon-plus-port if it was specified). +/// +/// Note: A `FedDest::Named` might contain an IP address in string form if there +/// was no port specified to construct a `SocketAddr` with. +/// +/// # Examples: +/// ```rust +/// # use conduit_service::sending::FedDest; +/// # fn main() -> Result<(), std::net::AddrParseError> { +/// FedDest::Literal("198.51.100.3:8448".parse()?); +/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); +/// FedDest::Named("matrix.example.org".to_owned(), String::new()); +/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); +/// FedDest::Named("198.51.100.5".to_owned(), String::new()); +/// # Ok(()) +/// # } +/// ``` +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +pub(crate) struct ActualDest { + pub(crate) dest: FedDest, + pub(crate) host: String, + pub(crate) string: String, + pub(crate) cached: bool, +} + +#[tracing::instrument(skip_all, name = "resolve")] +pub(crate) async fn get_actual_dest(server_name: &ServerName) -> Result { + let cached; + let cached_result = services() + .globals + .actual_destinations() + .read() + .await + .get(server_name) + .cloned(); + + let (dest, host) = if let Some(result) = cached_result { + cached = true; + result + } else { + cached = false; + validate_dest(server_name)?; + resolve_actual_dest(server_name, false).await? + }; + + let string = dest.clone().into_https_string(); + Ok(ActualDest { + dest, + host, + string, + cached, + }) +} + +/// Returns: `actual_destination`, host header +/// Implemented according to the specification at +/// Numbers in comments below refer to bullet points in linked section of +/// specification +pub async fn resolve_actual_dest(dest: &ServerName, no_cache_dest: bool) -> Result<(FedDest, String)> { + trace!("Finding actual destination for {dest}"); + let dest_str = dest.as_str().to_owned(); + let mut hostname = dest_str.clone(); + + #[allow(clippy::single_match_else)] + let actual_dest = match get_ip_with_port(&dest_str) { + Some(host_port) => { + debug!("1: IP literal with provided or default port"); + host_port + }, + None => { + if let Some(pos) = dest_str.find(':') { + debug!("2: Hostname with included port"); + let (host, port) = dest_str.split_at(pos); + if !no_cache_dest { + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; + } + + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + trace!("Requesting well known for {dest}"); + if let Some(delegated_hostname) = request_well_known(dest.as_str()).await? { + debug!("3: A .well-known file is available"); + hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); + match get_ip_with_port(&delegated_hostname) { + Some(host_and_port) => { + debug!("3.1: IP literal in .well-known file"); + host_and_port + }, + None => { + if let Some(pos) = delegated_hostname.find(':') { + debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated_hostname.split_at(pos); + if !no_cache_dest { + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; + } + + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + trace!("Delegated hostname has no port in this branch"); + if let Some(hostname_override) = query_srv_record(&delegated_hostname).await? { + debug!("3.3: SRV lookup successful"); + let force_port = hostname_override.port(); + if !no_cache_dest { + query_and_cache_override( + &delegated_hostname, + &hostname_override.hostname(), + force_port.unwrap_or(8448), + ) + .await?; + } + + if let Some(port) = force_port { + FedDest::Named(delegated_hostname, format!(":{port}")) + } else { + add_port_to_hostname(&delegated_hostname) + } + } else { + debug!("3.4: No SRV records, just use the hostname from .well-known"); + if !no_cache_dest { + query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448) + .await?; + } + + add_port_to_hostname(&delegated_hostname) + } + } + }, + } + } else { + trace!("4: No .well-known or an error occured"); + if let Some(hostname_override) = query_srv_record(&dest_str).await? { + debug!("4: No .well-known; SRV record found"); + let force_port = hostname_override.port(); + + if !no_cache_dest { + query_and_cache_override( + &hostname, + &hostname_override.hostname(), + force_port.unwrap_or(8448), + ) + .await?; + } + + if let Some(port) = force_port { + FedDest::Named(hostname.clone(), format!(":{port}")) + } else { + add_port_to_hostname(&hostname) + } + } else { + debug!("4: No .well-known; 5: No SRV record found"); + if !no_cache_dest { + query_and_cache_override(&dest_str, &dest_str, 8448).await?; + } + + add_port_to_hostname(&dest_str) + } + } + } + }, + }; + + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let hostname = if let Ok(addr) = hostname.parse::() { + FedDest::Literal(addr) + } else if let Ok(addr) = hostname.parse::() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = hostname.find(':') { + let (host, port) = hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + FedDest::Named(hostname, ":8448".to_owned()) + }; + + debug!("Actual destination: {actual_dest:?} hostname: {hostname:?}"); + Ok((actual_dest, hostname.into_uri_string())) +} + +#[tracing::instrument(skip_all, name = "well-known")] +async fn request_well_known(dest: &str) -> Result> { + if !services() + .globals + .resolver + .overrides + .read() + .unwrap() + .contains_key(dest) + { + query_and_cache_override(dest, dest, 8448).await?; + } + + let response = services() + .globals + .client + .well_known + .get(&format!("https://{dest}/.well-known/matrix/server")) + .send() + .await; + + trace!("response: {:?}", response); + if let Err(e) = &response { + debug!("error: {e:?}"); + return Ok(None); + } + + let response = response?; + if !response.status().is_success() { + debug!("response not 2XX"); + return Ok(None); + } + + let text = response.text().await?; + trace!("response text: {:?}", text); + if text.len() >= 12288 { + debug_warn!("response contains junk"); + return Ok(None); + } + + let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); + + let m_server = body + .get("m.server") + .unwrap_or(&serde_json::Value::Null) + .as_str() + .unwrap_or_default(); + + if ruma_identifiers_validation::server_name::validate(m_server).is_err() { + debug_error!("response content missing or invalid"); + return Ok(None); + } + + debug_info!("{:?} found at {:?}", dest, m_server); + Ok(Some(m_server.to_owned())) +} + +#[tracing::instrument(skip_all, name = "ip")] +async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { + match services() + .globals + .dns_resolver() + .lookup_ip(hostname.to_owned()) + .await + { + Err(e) => handle_resolve_error(&e), + Ok(override_ip) => { + if hostname != overname { + debug_info!("{:?} overriden by {:?}", overname, hostname); + } + services() + .globals + .resolver + .overrides + .write() + .unwrap() + .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + + Ok(()) + }, + } +} + +#[tracing::instrument(skip_all, name = "srv")] +async fn query_srv_record(hostname: &'_ str) -> Result> { + fn handle_successful_srv(srv: &SrvLookup) -> Option { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + } + + async fn lookup_srv(hostname: &str) -> Result { + debug!("querying SRV for {:?}", hostname); + let hostname = hostname.trim_end_matches('.'); + services() + .globals + .dns_resolver() + .srv_lookup(hostname.to_owned()) + .await + } + + let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; + + for hostname in hostnames { + match lookup_srv(&hostname).await { + Ok(result) => return Ok(handle_successful_srv(&result)), + Err(e) => handle_resolve_error(&e)?, + } + } + + Ok(None) +} + +#[allow(clippy::single_match_else)] +fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; + + match *e.kind() { + ResolveErrorKind::NoRecordsFound { + .. + } => { + // Raise to debug_warn if we can find out the result wasn't from cache + debug!("{e}"); + Ok(()) + }, + _ => { + error!("DNS {e}"); + Err(Error::Err(e.to_string())) + }, + } +} + +fn validate_dest(dest: &ServerName) -> Result<()> { + if dest == services().globals.server_name() { + return Err(Error::bad_config("Won't send federation request to ourselves")); + } + + if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { + validate_dest_ip_literal(dest)?; + } + + Ok(()) +} + +fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { + trace!("Destination is an IP literal, checking against IP range denylist.",); + debug_assert!( + dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), + "Destination is not an IP literal." + ); + let ip = IPAddress::parse(dest.host()).map_err(|e| { + debug_error!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + validate_ip(&ip)?; + + Ok(()) +} + +pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { + if !services().globals.valid_cidr_range(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + + Ok(()) +} + +fn get_ip_with_port(dest_str: &str) -> Option { + if let Ok(dest) = dest_str.parse::() { + Some(FedDest::Literal(dest)) + } else if let Ok(ip_addr) = dest_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +fn add_port_to_hostname(dest_str: &str) -> FedDest { + let (host, port) = match dest_str.find(':') { + None => (dest_str, ":8448"), + Some(pos) => dest_str.split_at(pos), + }; + + FedDest::Named(host.to_owned(), port.to_owned()) +} + +impl FedDest { + fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => format!("{host}{port}"), + } + } + + fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + fn port(&self) -> Option { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +#[cfg(test)] +mod tests { + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } + + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } + + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ); + } + + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); + } +} diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index a835e438..e8432d12 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -1,10 +1,5 @@ -use std::{ - fmt::Debug, - mem, - net::{IpAddr, SocketAddr}, -}; +use std::{fmt::Debug, mem}; -use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; use reqwest::{Client, Method, Request, Response, Url}; @@ -15,40 +10,10 @@ use ruma::{ }, OwnedServerName, ServerName, }; -use tracing::{debug, error, trace}; +use tracing::{debug, trace}; -use crate::{debug_error, debug_info, debug_warn, services, Error, Result}; - -/// Wraps either an literal IP address plus port, or a hostname plus complement -/// (colon-plus-port if it was specified). -/// -/// Note: A `FedDest::Named` might contain an IP address in string form if there -/// was no port specified to construct a `SocketAddr` with. -/// -/// # Examples: -/// ```rust -/// # use conduit_service::sending::FedDest; -/// # fn main() -> Result<(), std::net::AddrParseError> { -/// FedDest::Literal("198.51.100.3:8448".parse()?); -/// FedDest::Literal("[2001:db8::4:5]:443".parse()?); -/// FedDest::Named("matrix.example.org".to_owned(), String::new()); -/// FedDest::Named("matrix.example.org".to_owned(), ":8448".to_owned()); -/// FedDest::Named("198.51.100.5".to_owned(), String::new()); -/// # Ok(()) -/// # } -/// ``` -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -struct ActualDest { - dest: FedDest, - host: String, - string: String, - cached: bool, -} +use super::{resolve, resolve::ActualDest}; +use crate::{debug_error, debug_warn, services, Error, Result}; #[tracing::instrument(skip_all, name = "send")] pub async fn send(client: &Client, dest: &ServerName, req: T) -> Result @@ -59,7 +24,7 @@ where return Err(Error::bad_config("Federation is disabled.")); } - let actual = get_actual_dest(dest).await?; + let actual = resolve::get_actual_dest(dest).await?; let request = prepare::(dest, &actual, req).await?; execute::(client, dest, &actual, request).await } @@ -177,294 +142,6 @@ where Err(e.into()) } -#[tracing::instrument(skip_all, name = "resolve")] -async fn get_actual_dest(server_name: &ServerName) -> Result { - let cached; - let cached_result = services() - .globals - .actual_destinations() - .read() - .await - .get(server_name) - .cloned(); - - let (dest, host) = if let Some(result) = cached_result { - cached = true; - result - } else { - cached = false; - validate_dest(server_name)?; - resolve_actual_dest(server_name, false).await? - }; - - let string = dest.clone().into_https_string(); - Ok(ActualDest { - dest, - host, - string, - cached, - }) -} - -/// Returns: `actual_destination`, host header -/// Implemented according to the specification at -/// Numbers in comments below refer to bullet points in linked section of -/// specification -pub async fn resolve_actual_dest(dest: &ServerName, no_cache_dest: bool) -> Result<(FedDest, String)> { - trace!("Finding actual destination for {dest}"); - let dest_str = dest.as_str().to_owned(); - let mut hostname = dest_str.clone(); - - #[allow(clippy::single_match_else)] - let actual_dest = match get_ip_with_port(&dest_str) { - Some(host_port) => { - debug!("1: IP literal with provided or default port"); - host_port - }, - None => { - if let Some(pos) = dest_str.find(':') { - debug!("2: Hostname with included port"); - let (host, port) = dest_str.split_at(pos); - if !no_cache_dest { - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; - } - - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - trace!("Requesting well known for {dest}"); - if let Some(delegated_hostname) = request_well_known(dest.as_str()).await? { - debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - match get_ip_with_port(&delegated_hostname) { - Some(host_and_port) => { - debug!("3.1: IP literal in .well-known file"); - host_and_port - }, - None => { - if let Some(pos) = delegated_hostname.find(':') { - debug!("3.2: Hostname with port in .well-known file"); - let (host, port) = delegated_hostname.split_at(pos); - if !no_cache_dest { - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; - } - - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - trace!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = query_srv_record(&delegated_hostname).await? { - debug!("3.3: SRV lookup successful"); - let force_port = hostname_override.port(); - if !no_cache_dest { - query_and_cache_override( - &delegated_hostname, - &hostname_override.hostname(), - force_port.unwrap_or(8448), - ) - .await?; - } - - if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{port}")) - } else { - add_port_to_hostname(&delegated_hostname) - } - } else { - debug!("3.4: No SRV records, just use the hostname from .well-known"); - if !no_cache_dest { - query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448) - .await?; - } - - add_port_to_hostname(&delegated_hostname) - } - } - }, - } - } else { - trace!("4: No .well-known or an error occured"); - if let Some(hostname_override) = query_srv_record(&dest_str).await? { - debug!("4: No .well-known; SRV record found"); - let force_port = hostname_override.port(); - - if !no_cache_dest { - query_and_cache_override( - &hostname, - &hostname_override.hostname(), - force_port.unwrap_or(8448), - ) - .await?; - } - - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{port}")) - } else { - add_port_to_hostname(&hostname) - } - } else { - debug!("4: No .well-known; 5: No SRV record found"); - if !no_cache_dest { - query_and_cache_override(&dest_str, &dest_str, 8448).await?; - } - - add_port_to_hostname(&dest_str) - } - } - } - }, - }; - - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) - }; - - debug!("Actual destination: {actual_dest:?} hostname: {hostname:?}"); - Ok((actual_dest, hostname.into_uri_string())) -} - -#[tracing::instrument(skip_all, name = "well-known")] -async fn request_well_known(dest: &str) -> Result> { - if !services() - .globals - .resolver - .overrides - .read() - .unwrap() - .contains_key(dest) - { - query_and_cache_override(dest, dest, 8448).await?; - } - - let response = services() - .globals - .client - .well_known - .get(&format!("https://{dest}/.well-known/matrix/server")) - .send() - .await; - - trace!("response: {:?}", response); - if let Err(e) = &response { - debug!("error: {e:?}"); - return Ok(None); - } - - let response = response?; - if !response.status().is_success() { - debug!("response not 2XX"); - return Ok(None); - } - - let text = response.text().await?; - trace!("response text: {:?}", text); - if text.len() >= 12288 { - debug_warn!("response contains junk"); - return Ok(None); - } - - let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); - - let m_server = body - .get("m.server") - .unwrap_or(&serde_json::Value::Null) - .as_str() - .unwrap_or_default(); - - if ruma_identifiers_validation::server_name::validate(m_server).is_err() { - debug_error!("response content missing or invalid"); - return Ok(None); - } - - debug_info!("{:?} found at {:?}", dest, m_server); - Ok(Some(m_server.to_owned())) -} - -#[tracing::instrument(skip_all, name = "ip")] -async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { - match services() - .globals - .dns_resolver() - .lookup_ip(hostname.to_owned()) - .await - { - Err(e) => handle_resolve_error(&e), - Ok(override_ip) => { - if hostname != overname { - debug_info!("{:?} overriden by {:?}", overname, hostname); - } - services() - .globals - .resolver - .overrides - .write() - .unwrap() - .insert(overname.to_owned(), (override_ip.iter().collect(), port)); - - Ok(()) - }, - } -} - -#[tracing::instrument(skip_all, name = "srv")] -async fn query_srv_record(hostname: &'_ str) -> Result> { - fn handle_successful_srv(srv: &SrvLookup) -> Option { - srv.iter().next().map(|result| { - FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), - ) - }) - } - - async fn lookup_srv(hostname: &str) -> Result { - debug!("querying SRV for {:?}", hostname); - let hostname = hostname.trim_end_matches('.'); - services() - .globals - .dns_resolver() - .srv_lookup(hostname.to_owned()) - .await - } - - let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; - - for hostname in hostnames { - match lookup_srv(&hostname).await { - Ok(result) => return Ok(handle_successful_srv(&result)), - Err(e) => handle_resolve_error(&e)?, - } - } - - Ok(None) -} - -#[allow(clippy::single_match_else)] -fn handle_resolve_error(e: &ResolveError) -> Result<()> { - use hickory_resolver::error::ResolveErrorKind; - - match *e.kind() { - ResolveErrorKind::NoRecordsFound { - .. - } => { - // Raise to debug_warn if we can find out the result wasn't from cache - debug!("{e}"); - Ok(()) - }, - _ => { - error!("DNS {e}"); - Err(Error::Err(e.to_string())) - }, - } -} - fn sign_request(dest: &ServerName, http_request: &mut http::Request>) where T: OutgoingRequest + Debug, @@ -533,139 +210,9 @@ fn validate_url(url: &Url) -> Result<()> { if let Some(url_host) = url.host_str() { if let Ok(ip) = IPAddress::parse(url_host) { trace!("Checking request URL IP {ip:?}"); - validate_ip(&ip)?; + resolve::validate_ip(&ip)?; } } Ok(()) } - -fn validate_dest(dest: &ServerName) -> Result<()> { - if dest == services().globals.server_name() { - return Err(Error::bad_config("Won't send federation request to ourselves")); - } - - if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { - validate_dest_ip_literal(dest)?; - } - - Ok(()) -} - -fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { - trace!("Destination is an IP literal, checking against IP range denylist.",); - debug_assert!( - dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), - "Destination is not an IP literal." - ); - let ip = IPAddress::parse(dest.host()).map_err(|e| { - debug_error!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; - - validate_ip(&ip)?; - - Ok(()) -} - -fn validate_ip(ip: &IPAddress) -> Result<()> { - if !services().globals.valid_cidr_range(ip) { - return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); - } - - Ok(()) -} - -fn get_ip_with_port(dest_str: &str) -> Option { - if let Ok(dest) = dest_str.parse::() { - Some(FedDest::Literal(dest)) - } else if let Ok(ip_addr) = dest_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), - }; - - FedDest::Named(host.to_owned(), port.to_owned()) -} - -impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } -}