diff --git a/Cargo.lock b/Cargo.lock index da06c5ad..e0841e44 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -536,6 +536,7 @@ dependencies = [ "reqwest", "ring", "ruma", + "ruma-identifiers-validation", "rusqlite", "rust-rocksdb", "sd-notify", diff --git a/Cargo.toml b/Cargo.toml index f818d56f..84976146 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -275,6 +275,10 @@ features = [ "unstable-extensible-events", ] +[dependencies.ruma-identifiers-validation] +git = "https://github.com/girlbossceo/ruma" +branch = "conduwuit-changes" + [dependencies.hickory-resolver] version = "0.24.1" default-features = false diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index f5cd0649..726ceffe 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -4,7 +4,6 @@ use std::{ net::{IpAddr, SocketAddr}, }; -use futures_util::TryFutureExt; use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use http::{header::AUTHORIZATION, HeaderValue}; use ipaddress::IPAddress; @@ -15,9 +14,9 @@ use ruma::{ }, OwnedServerName, ServerName, }; -use tracing::{debug, trace, warn}; +use tracing::{debug, error, trace, warn}; -use crate::{services, Error, Result}; +use crate::{debug_error, debug_warn, services, Error, Result}; /// Wraps either an literal IP address plus port, or a hostname plus complement /// (colon-plus-port if it was specified). @@ -63,11 +62,11 @@ where trace!("Preparing to send request"); validate_destination(destination)?; - let actual = get_actual_destination(destination).await; + let actual = get_actual_destination(destination).await?; let mut http_request = req .try_into_http_request::>(&actual.string, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_5]) .map_err(|e| { - warn!("Failed to find destination {}: {}", actual.string, e); + debug_warn!("Failed to find destination {}: {}", actual.string, e); Error::BadServerResponse("Invalid destination") })?; @@ -96,8 +95,6 @@ where T: OutgoingRequest + Debug, { trace!("Received response from {} for {} with {}", actual.string, url, response.url()); - validate_response(&response)?; - let status = response.status(); let mut http_response_builder = http::Response::builder() .status(status) @@ -111,7 +108,7 @@ where trace!("Waiting for response body"); let body = response.bytes().await.unwrap_or_else(|e| { - debug!("server error {}", e); + debug_error!("server error {}", e); Vec::new().into() }); // TODO: handle timeout @@ -153,27 +150,27 @@ where // we do not need to log that servers in a room are dead, this is normal in // public rooms and just spams the logs. if e.is_timeout() { - debug!("Timed out sending request to {}: {}", actual.string, e,); + debug_error!("timeout {}: {}", actual.host, e); } else if e.is_connect() { - debug!("Failed to connect to {}: {}", actual.string, e); + debug_error!("connect {}: {}", actual.host, e); } else if e.is_redirect() { - debug!( + debug_error!( method = ?method, url = ?url, final_url = ?e.url(), - "Redirect loop sending request to {}: {}", - actual.string, + "Redirect loop {}: {}", + actual.host, e, ); } else { - debug!("Could not send request to {}: {}", actual.string, e); + debug_error!("{}: {}", actual.host, e); } Err(e.into()) } #[tracing::instrument(skip_all, name = "resolve")] -async fn get_actual_destination(server_name: &ServerName) -> ActualDestination { +async fn get_actual_destination(server_name: &ServerName) -> Result { let cached; let cached_result = services() .globals @@ -188,23 +185,23 @@ async fn get_actual_destination(server_name: &ServerName) -> ActualDestination { result } else { cached = false; - resolve_actual_destination(server_name).await + resolve_actual_destination(server_name).await? }; let string = destination.clone().into_https_string(); - ActualDestination { + Ok(ActualDestination { destination, host, string, cached, - } + }) } /// Returns: `actual_destination`, host header /// Implemented according to the specification at /// Numbers in comments below refer to bullet points in linked section of /// specification -async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, String) { +async fn resolve_actual_destination(destination: &'_ ServerName) -> Result<(FedDest, String)> { trace!("Finding actual destination for {destination}"); let destination_str = destination.as_str().to_owned(); let mut hostname = destination_str.clone(); @@ -218,12 +215,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St debug!("2: Hostname with included port"); let (host, port) = destination_str.split_at(pos); - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; FedDest::Named(host.to_owned(), port.to_owned()) } else { trace!("Requesting well known for {destination}"); - if let Some(delegated_hostname) = request_well_known(destination.as_str()).await { + if let Some(delegated_hostname) = request_well_known(destination.as_str()).await? { debug!("3: A .well-known file is available"); hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); match get_ip_with_port(&delegated_hostname) { @@ -233,12 +230,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St debug!("3.2: Hostname with port in .well-known file"); let (host, port) = delegated_hostname.split_at(pos); - query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await; + query_and_cache_override(host, host, port.parse::().unwrap_or(8448)).await?; FedDest::Named(host.to_owned(), port.to_owned()) } else { trace!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = query_srv_record(&delegated_hostname).await { + if let Some(hostname_override) = query_srv_record(&delegated_hostname).await? { debug!("3.3: SRV lookup successful"); let force_port = hostname_override.port(); @@ -247,7 +244,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St &hostname_override.hostname(), force_port.unwrap_or(8448), ) - .await; + .await?; if let Some(port) = force_port { FedDest::Named(delegated_hostname, format!(":{port}")) @@ -256,7 +253,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St } } else { debug!("3.4: No SRV records, just use the hostname from .well-known"); - query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await; + query_and_cache_override(&delegated_hostname, &delegated_hostname, 8448).await?; add_port_to_hostname(&delegated_hostname) } } @@ -264,12 +261,12 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St } } else { trace!("4: No .well-known or an error occured"); - if let Some(hostname_override) = query_srv_record(&destination_str).await { + if let Some(hostname_override) = query_srv_record(&destination_str).await? { debug!("4: No .well-known; SRV record found"); let force_port = hostname_override.port(); query_and_cache_override(&hostname, &hostname_override.hostname(), force_port.unwrap_or(8448)) - .await; + .await?; if let Some(port) = force_port { FedDest::Named(hostname.clone(), format!(":{port}")) @@ -278,7 +275,7 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St } } else { debug!("4: No .well-known; 5: No SRV record found"); - query_and_cache_override(&destination_str, &destination_str, 8448).await; + query_and_cache_override(&destination_str, &destination_str, 8448).await?; add_port_to_hostname(&destination_str) } } @@ -300,10 +297,65 @@ async fn resolve_actual_destination(destination: &'_ ServerName) -> (FedDest, St }; debug!("Actual destination: {actual_destination:?} hostname: {hostname:?}"); - (actual_destination, hostname.into_uri_string()) + Ok((actual_destination, hostname.into_uri_string())) } -async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) { +async fn request_well_known(destination: &str) -> Result> { + if !services() + .globals + .resolver + .overrides + .read() + .unwrap() + .contains_key(destination) + { + query_and_cache_override(destination, destination, 8448).await?; + } + + let response = services() + .globals + .client + .well_known + .get(&format!("https://{destination}/.well-known/matrix/server")) + .send() + .await; + + trace!("Well known response: {:?}", response); + if let Err(e) = &response { + debug!("Well known error: {e:?}"); + return Ok(None); + } + + let response = response?; + if !response.status().is_success() { + debug!("Well known response not 2XX"); + return Ok(None); + } + + let text = response.text().await?; + trace!("Well known response text: {:?}", text); + if text.len() >= 12288 { + debug!("Well known response contains junk"); + return Ok(None); + } + + let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); + + let m_server = body + .get("m.server") + .unwrap_or(&serde_json::Value::Null) + .as_str() + .unwrap_or_default(); + + if ruma_identifiers_validation::server_name::validate(m_server).is_err() { + debug!("Well known response content missing or invalid"); + return Ok(None); + } + + Ok(Some(m_server.to_owned())) +} + +async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { match services() .globals .dns_resolver() @@ -319,14 +371,17 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 .write() .unwrap() .insert(overname.to_owned(), (override_ip.iter().collect(), port)); + + Ok(()) }, Err(e) => { debug!("Got {:?} for {:?} to override {:?}", e.kind(), hostname, overname); + handle_resolve_error(&e) }, } } -async fn query_srv_record(hostname: &'_ str) -> Option { +async fn query_srv_record(hostname: &'_ str) -> Result> { fn handle_successful_srv(srv: &SrvLookup) -> Option { srv.iter().next().map(|result| { FedDest::Named( @@ -346,61 +401,34 @@ async fn query_srv_record(hostname: &'_ str) -> Option { .await } - let first_hostname = format!("_matrix-fed._tcp.{hostname}."); - let second_hostname = format!("_matrix._tcp.{hostname}."); + let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; - lookup_srv(&first_hostname) - .or_else(|_| { - trace!("Querying deprecated _matrix SRV record for host {:?}", hostname); - lookup_srv(&second_hostname) - }) - .and_then(|srv_lookup| async move { Ok(handle_successful_srv(&srv_lookup)) }) - .await - .ok() - .flatten() + for hostname in hostnames { + match lookup_srv(&hostname).await { + Ok(result) => return Ok(handle_successful_srv(&result)), + Err(e) => handle_resolve_error(&e)?, + } + } + + Ok(None) } -async fn request_well_known(destination: &str) -> Option { - if !services() - .globals - .resolver - .overrides - .read() - .unwrap() - .contains_key(destination) - { - query_and_cache_override(destination, destination, 8448).await; +#[allow(clippy::single_match_else)] +fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; + + match *e.kind() { + ResolveErrorKind::Io { + .. + } => { + debug_error!("DNS IO error: {e}"); + Err(Error::Error(e.to_string())) + }, + _ => { + debug!("DNS protocol error: {e}"); + Ok(()) + }, } - - let response = services() - .globals - .client - .well_known - .get(&format!("https://{destination}/.well-known/matrix/server")) - .send() - .await; - - trace!("Well known response: {:?}", response); - if let Err(e) = &response { - debug!("Well known error: {e:?}"); - return None; - } - - let text = response.ok()?.text().await; - trace!("Well known response text: {:?}", text); - - if text.as_ref().ok()?.len() > 10000 { - debug!( - "Well known response for destination '{destination}' exceeded past 10000 characters, assuming no \ - well-known." - ); - return None; - } - - let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; - trace!("serde_json body of well known text: {}", body); - - Some(body.get("m.server")?.as_str()?.to_owned()) } fn sign_request(destination: &ServerName, http_request: &mut http::Request>) @@ -466,17 +494,6 @@ where } } -fn validate_response(response: &reqwest::Response) -> Result<()> { - if let Some(remote_addr) = response.remote_addr() { - if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { - trace!("Checking response destination's IP"); - validate_ip(&ip)?; - } - } - - Ok(()) -} - fn validate_url(url: &reqwest::Url) -> Result<()> { if let Some(url_host) = url.host_str() { if let Ok(ip) = IPAddress::parse(url_host) { @@ -509,7 +526,7 @@ fn validate_destination_ip_literal(destination: &ServerName) -> Result<()> { debug!("Destination is an IP literal, checking against IP range denylist.",); let ip = IPAddress::parse(destination.host()).map_err(|e| { - warn!("Failed to parse IP literal from string: {}", e); + debug_warn!("Failed to parse IP literal from string: {}", e); Error::BadServerResponse("Invalid IP address") })?;