diff --git a/Cargo.lock b/Cargo.lock
index 83c443cc..a8da7d7e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -653,7 +653,6 @@ dependencies = [
  "http 1.1.0",
  "http-body-util",
  "hyper 1.4.1",
- "image",
  "ipaddress",
  "jsonwebtoken",
  "log",
@@ -666,7 +665,6 @@ dependencies = [
  "sha-1",
  "tokio",
  "tracing",
- "webpage",
 ]
 
 [[package]]
@@ -802,6 +800,7 @@ dependencies = [
  "tokio",
  "tracing",
  "url",
+ "webpage",
 ]
 
 [[package]]
diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml
index 356adc1f..f537fd5f 100644
--- a/src/api/Cargo.toml
+++ b/src/api/Cargo.toml
@@ -47,7 +47,6 @@ hmac.workspace = true
 http.workspace = true
 http-body-util.workspace = true
 hyper.workspace = true
-image.workspace = true
 ipaddress.workspace = true
 jsonwebtoken.workspace = true
 log.workspace = true
@@ -60,7 +59,6 @@ serde.workspace = true
 sha-1.workspace = true
 tokio.workspace = true
 tracing.workspace = true
-webpage.workspace = true
 
 [lints]
 workspace = true
diff --git a/src/api/client/media.rs b/src/api/client/media.rs
index f0afa290..5c83c17f 100644
--- a/src/api/client/media.rs
+++ b/src/api/client/media.rs
@@ -1,11 +1,11 @@
 #![allow(deprecated)]
 
-use std::{io::Cursor, time::Duration};
+use std::time::Duration;
 
 use axum::extract::State;
 use axum_client_ip::InsecureClientIp;
 use conduit::{
-	debug, debug_warn, error,
+	debug_warn, error,
 	utils::{
 		self,
 		content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename},
@@ -13,9 +13,6 @@ use conduit::{
 	},
 	warn, Error, Result,
 };
-use image::io::Reader as ImgReader;
-use ipaddress::IPAddress;
-use reqwest::Url;
 use ruma::api::client::{
 	error::{ErrorKind, RetryAfter},
 	media::{
@@ -24,16 +21,12 @@ use ruma::api::client::{
 	},
 };
 use service::{
-	media::{FileMeta, UrlPreviewData},
+	media::{FileMeta, MXC_LENGTH},
 	Services,
 };
-use webpage::HTML;
 
 use crate::{Ruma, RumaResponse};
 
-/// generated MXC ID (`media-id`) length
-const MXC_LENGTH: usize = 32;
-
 /// Cache control for immutable objects
 const CACHE_CONTROL_IMMUTABLE: &str = "public,max-age=31536000,immutable";
 
@@ -76,12 +69,12 @@ pub(crate) async fn get_media_preview_route(
 	let sender_user = body.sender_user.as_ref().expect("user is authenticated");
 
 	let url = &body.url;
-	if !url_preview_allowed(&services, url) {
+	if !services.media.url_preview_allowed(url) {
 		warn!(%sender_user, "URL is not allowed to be previewed: {url}");
 		return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed"));
 	}
 
-	match get_url_preview(&services, url).await {
+	match services.media.get_url_preview(url).await {
 		Ok(preview) => {
 			let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
 				error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}");
@@ -553,222 +546,3 @@ async fn get_remote_content(
 		cache_control: Some(CACHE_CONTROL_IMMUTABLE.to_owned()),
 	})
 }
-
-async fn download_image(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
-	let image = client.get(url).send().await?.bytes().await?;
-	let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
-
-	services
-		.media
-		.create(None, &mxc, None, None, &image)
-		.await?;
-
-	let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
-		Err(_) => (None, None),
-		Ok(reader) => match reader.into_dimensions() {
-			Err(_) => (None, None),
-			Ok((width, height)) => (Some(width), Some(height)),
-		},
-	};
-
-	Ok(UrlPreviewData {
-		image: Some(mxc),
-		image_size: Some(image.len()),
-		image_width: width,
-		image_height: height,
-		..Default::default()
-	})
-}
-
-async fn download_html(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
-	let mut response = client.get(url).send().await?;
-
-	let mut bytes: Vec<u8> = Vec::new();
-	while let Some(chunk) = response.chunk().await? {
-		bytes.extend_from_slice(&chunk);
-		if bytes.len() > services.globals.url_preview_max_spider_size() {
-			debug!(
-				"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
-				 response body and assuming our necessary data is in this range.",
-				url,
-				services.globals.url_preview_max_spider_size()
-			);
-			break;
-		}
-	}
-	let body = String::from_utf8_lossy(&bytes);
-	let Ok(html) = HTML::from_string(body.to_string(), Some(url.to_owned())) else {
-		return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML"));
-	};
-
-	let mut data = match html.opengraph.images.first() {
-		None => UrlPreviewData::default(),
-		Some(obj) => download_image(services, client, &obj.url).await?,
-	};
-
-	let props = html.opengraph.properties;
-
-	/* use OpenGraph title/description, but fall back to HTML if not available */
-	data.title = props.get("title").cloned().or(html.title);
-	data.description = props.get("description").cloned().or(html.description);
-
-	Ok(data)
-}
-
-async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
-	if let Ok(ip) = IPAddress::parse(url) {
-		if !services.globals.valid_cidr_range(&ip) {
-			return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
-		}
-	}
-
-	let client = &services.client.url_preview;
-	let response = client.head(url).send().await?;
-
-	if let Some(remote_addr) = response.remote_addr() {
-		if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
-			if !services.globals.valid_cidr_range(&ip) {
-				return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
-			}
-		}
-	}
-
-	let Some(content_type) = response
-		.headers()
-		.get(reqwest::header::CONTENT_TYPE)
-		.and_then(|x| x.to_str().ok())
-	else {
-		return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type"));
-	};
-	let data = match content_type {
-		html if html.starts_with("text/html") => download_html(services, client, url).await?,
-		img if img.starts_with("image/") => download_image(services, client, url).await?,
-		_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
-	};
-
-	services.media.set_url_preview(url, &data).await?;
-
-	Ok(data)
-}
-
-async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
-	if let Some(preview) = services.media.get_url_preview(url).await {
-		return Ok(preview);
-	}
-
-	// ensure that only one request is made per URL
-	let _request_lock = services.media.url_preview_mutex.lock(url).await;
-
-	match services.media.get_url_preview(url).await {
-		Some(preview) => Ok(preview),
-		None => request_url_preview(services, url).await,
-	}
-}
-
-fn url_preview_allowed(services: &Services, url_str: &str) -> bool {
-	let url: Url = match Url::parse(url_str) {
-		Ok(u) => u,
-		Err(e) => {
-			warn!("Failed to parse URL from a str: {}", e);
-			return false;
-		},
-	};
-
-	if ["http", "https"]
-		.iter()
-		.all(|&scheme| scheme != url.scheme().to_lowercase())
-	{
-		debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
-		return false;
-	}
-
-	let host = match url.host_str() {
-		None => {
-			debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
-			return false;
-		},
-		Some(h) => h.to_owned(),
-	};
-
-	let allowlist_domain_contains = services.globals.url_preview_domain_contains_allowlist();
-	let allowlist_domain_explicit = services.globals.url_preview_domain_explicit_allowlist();
-	let denylist_domain_explicit = services.globals.url_preview_domain_explicit_denylist();
-	let allowlist_url_contains = services.globals.url_preview_url_contains_allowlist();
-
-	if allowlist_domain_contains.contains(&"*".to_owned())
-		|| allowlist_domain_explicit.contains(&"*".to_owned())
-		|| allowlist_url_contains.contains(&"*".to_owned())
-	{
-		debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
-		return true;
-	}
-
-	if !host.is_empty() {
-		if denylist_domain_explicit.contains(&host) {
-			debug!(
-				"Host {} is not allowed by url_preview_domain_explicit_denylist (check 1/4)",
-				&host
-			);
-			return false;
-		}
-
-		if allowlist_domain_explicit.contains(&host) {
-			debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)", &host);
-			return true;
-		}
-
-		if allowlist_domain_contains
-			.iter()
-			.any(|domain_s| domain_s.contains(&host.clone()))
-		{
-			debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)", &host);
-			return true;
-		}
-
-		if allowlist_url_contains
-			.iter()
-			.any(|url_s| url.to_string().contains(&url_s.to_string()))
-		{
-			debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 4/4)", &host);
-			return true;
-		}
-
-		// check root domain if available and if user has root domain checks
-		if services.globals.url_preview_check_root_domain() {
-			debug!("Checking root domain");
-			match host.split_once('.') {
-				None => return false,
-				Some((_, root_domain)) => {
-					if denylist_domain_explicit.contains(&root_domain.to_owned()) {
-						debug!(
-							"Root domain {} is not allowed by url_preview_domain_explicit_denylist (check 1/3)",
-							&root_domain
-						);
-						return true;
-					}
-
-					if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
-						debug!(
-							"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 2/3)",
-							&root_domain
-						);
-						return true;
-					}
-
-					if allowlist_domain_contains
-						.iter()
-						.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
-					{
-						debug!(
-							"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 3/3)",
-							&root_domain
-						);
-						return true;
-					}
-				},
-			}
-		}
-	}
-
-	false
-}
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 5ed5ea15..9898243b 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -16,7 +16,7 @@ pub use error::Error;
 pub use info::{rustc_flags_capture, version, version::version};
 pub use pdu::{PduBuilder, PduCount, PduEvent};
 pub use server::Server;
-pub use utils::{ctor, dtor};
+pub use utils::{ctor, dtor, implement};
 
 pub use crate as conduit_core;
 
diff --git a/src/core/utils/mod.rs b/src/core/utils/mod.rs
index 767b65a9..1556646e 100644
--- a/src/core/utils/mod.rs
+++ b/src/core/utils/mod.rs
@@ -17,6 +17,7 @@ use std::cmp::{self, Ordering};
 
 pub use ::ctor::{ctor, dtor};
 pub use bytes::{increment, u64_from_bytes, u64_from_u8, u64_from_u8x8};
+pub use conduit_macros::implement;
 pub use debug::slice_truncated as debug_slice_truncated;
 pub use hash::calculate_hash;
 pub use html::Escape as HtmlEscape;
diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml
index c444e5f5..c1d9889e 100644
--- a/src/service/Cargo.toml
+++ b/src/service/Cargo.toml
@@ -69,6 +69,7 @@ termimad.optional = true
 tokio.workspace = true
 tracing.workspace = true
 url.workspace = true
+webpage.workspace = true
 
 [lints]
 workspace = true
diff --git a/src/service/media/data.rs b/src/service/media/data.rs
index 617ec526..70e010c2 100644
--- a/src/service/media/data.rs
+++ b/src/service/media/data.rs
@@ -4,7 +4,7 @@ use conduit::{debug, debug_info, utils::string_from_bytes, Error, Result};
 use database::{Database, Map};
 use ruma::api::client::error::ErrorKind;
 
-use crate::media::UrlPreviewData;
+use super::preview::UrlPreviewData;
 
 pub(crate) struct Data {
 	mediaid_file: Arc<Map>,
diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs
index 1faa9768..ff3f3dc4 100644
--- a/src/service/media/mod.rs
+++ b/src/service/media/mod.rs
@@ -1,4 +1,5 @@
 mod data;
+mod preview;
 mod tests;
 mod thumbnail;
 
@@ -9,13 +10,12 @@ use base64::{engine::general_purpose, Engine as _};
 use conduit::{debug, debug_error, err, error, trace, utils, utils::MutexMap, Err, Result, Server};
 use data::{Data, Metadata};
 use ruma::{OwnedMxcUri, OwnedUserId};
-use serde::Serialize;
 use tokio::{
 	fs,
 	io::{AsyncReadExt, AsyncWriteExt, BufReader},
 };
 
-use crate::{globals, Dep};
+use crate::{client, globals, Dep};
 
 #[derive(Debug)]
 pub struct FileMeta {
@@ -24,43 +24,32 @@ pub struct FileMeta {
 	pub content_disposition: Option<String>,
 }
 
-#[derive(Serialize, Default)]
-pub struct UrlPreviewData {
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
-	pub title: Option<String>,
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
-	pub description: Option<String>,
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
-	pub image: Option<String>,
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
-	pub image_size: Option<usize>,
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
-	pub image_width: Option<u32>,
-	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
-	pub image_height: Option<u32>,
-}
-
 pub struct Service {
-	services: Services,
+	url_preview_mutex: MutexMap<String, ()>,
 	pub(crate) db: Data,
-	pub url_preview_mutex: MutexMap<String, ()>,
+	services: Services,
 }
 
 struct Services {
 	server: Arc<Server>,
+	client: Dep<client::Service>,
 	globals: Dep<globals::Service>,
 }
 
+/// generated MXC ID (`media-id`) length
+pub const MXC_LENGTH: usize = 32;
+
 #[async_trait]
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
 		Ok(Arc::new(Self {
+			url_preview_mutex: MutexMap::new(),
+			db: Data::new(args.db),
 			services: Services {
 				server: args.server.clone(),
+				client: args.depend::<client::Service>("client"),
 				globals: args.depend::<globals::Service>("globals"),
 			},
-			db: Data::new(args.db),
-			url_preview_mutex: MutexMap::new(),
 		}))
 	}
 
@@ -229,22 +218,6 @@ impl Service {
 		Ok(deletion_count)
 	}
 
-	pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
-
-	/// TODO: use this?
-	#[allow(dead_code)]
-	pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
-		// TODO: also remove the downloaded image
-		self.db.remove_url_preview(url)
-	}
-
-	pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
-		let now = SystemTime::now()
-			.duration_since(SystemTime::UNIX_EPOCH)
-			.expect("valid system time");
-		self.db.set_url_preview(url, data, now)
-	}
-
 	pub async fn create_media_dir(&self) -> Result<()> {
 		let dir = self.get_media_dir();
 		Ok(fs::create_dir_all(dir).await?)
diff --git a/src/service/media/preview.rs b/src/service/media/preview.rs
new file mode 100644
index 00000000..ebfd22e7
--- /dev/null
+++ b/src/service/media/preview.rs
@@ -0,0 +1,287 @@
+use std::{io::Cursor, time::SystemTime};
+
+use conduit::{debug, utils, warn, Error, Result};
+use conduit_core::implement;
+use image::ImageReader as ImgReader;
+use ipaddress::IPAddress;
+use ruma::api::client::error::ErrorKind;
+use serde::Serialize;
+use url::Url;
+use webpage::HTML;
+
+use super::{Service, MXC_LENGTH};
+
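+/// OpenGraph metadata for a previewed URL; the serde renames serialize these
+/// fields as the `og:*` / `matrix:*` keys of the URL preview response.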
+#[derive(Serialize, Default)]
+pub struct UrlPreviewData {
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
+	pub title: Option<String>,
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
+	pub description: Option<String>,
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
+	pub image: Option<String>,
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
+	pub image_size: Option<usize>,
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
+	pub image_width: Option<u32>,
+	#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
+	pub image_height: Option<u32>,
+}
+
+#[implement(Service)]
+pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
+	// TODO: also remove the downloaded image
+	self.db.remove_url_preview(url)
+}
+
+#[implement(Service)]
+pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
+	let now = SystemTime::now()
+		.duration_since(SystemTime::UNIX_EPOCH)
+		.expect("valid system time");
+	self.db.set_url_preview(url, data, now)
+}
+
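+/// Downloads the image at `url`, stores it as local media under a freshly
+/// generated MXC ID, and records its dimensions when the format is readable.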
+#[implement(Service)]
+pub async fn download_image(&self, url: &str) -> Result<UrlPreviewData> {
+	let client = &self.services.client.url_preview;
+	let image = client.get(url).send().await?.bytes().await?;
+	let mxc = format!(
+		"mxc://{}/{}",
+		self.services.globals.server_name(),
+		utils::random_string(MXC_LENGTH)
+	);
+
+	self.create(None, &mxc, None, None, &image).await?;
+
+	let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
+		Err(_) => (None, None),
+		Ok(reader) => match reader.into_dimensions() {
+			Err(_) => (None, None),
+			Ok((width, height)) => (Some(width), Some(height)),
+		},
+	};
+
+	Ok(UrlPreviewData {
+		image: Some(mxc),
+		image_size: Some(image.len()),
+		image_width: width,
+		image_height: height,
+		..Default::default()
+	})
+}
+
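+/// Returns the cached preview for `url` if one exists; otherwise spiders it,
+/// holding a per-URL mutex so concurrent callers make only a single request.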
+#[implement(Service)]
+pub async fn get_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
+	if let Some(preview) = self.db.get_url_preview(url) {
+		return Ok(preview);
+	}
+
+	// ensure that only one request is made per URL
+	let _request_lock = self.url_preview_mutex.lock(url).await;
+
+	match self.db.get_url_preview(url) {
+		Some(preview) => Ok(preview),
+		None => self.request_url_preview(url).await,
+	}
+}
+
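+/// Fetches a fresh preview, refusing URLs and resolved remote addresses that
+/// fall outside the allowed CIDR ranges, then dispatches on the Content-Type.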
+#[implement(Service)]
+async fn request_url_preview(&self, url: &str) -> Result<UrlPreviewData> {
+	if let Ok(ip) = IPAddress::parse(url) {
+		if !self.services.globals.valid_cidr_range(&ip) {
+			return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
+		}
+	}
+
+	let client = &self.services.client.url_preview;
+	let response = client.head(url).send().await?;
+
+	if let Some(remote_addr) = response.remote_addr() {
+		if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
+			if !self.services.globals.valid_cidr_range(&ip) {
+				return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
+			}
+		}
+	}
+
+	let Some(content_type) = response
+		.headers()
+		.get(reqwest::header::CONTENT_TYPE)
+		.and_then(|x| x.to_str().ok())
+	else {
+		return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type"));
+	};
+	let data = match content_type {
+		html if html.starts_with("text/html") => self.download_html(url).await?,
+		img if img.starts_with("image/") => self.download_image(url).await?,
+		_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
+	};
+
+	self.set_url_preview(url, &data).await?;
+
+	Ok(data)
+}
+
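+/// Spiders an HTML page, reading at most `url_preview_max_spider_size` bytes,
+/// and prefers OpenGraph title/description/image over the plain HTML ones.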
+#[implement(Service)]
+async fn download_html(&self, url: &str) -> Result<UrlPreviewData> {
+	let client = &self.services.client.url_preview;
+	let mut response = client.get(url).send().await?;
+
+	let mut bytes: Vec<u8> = Vec::new();
+	while let Some(chunk) = response.chunk().await? {
+		bytes.extend_from_slice(&chunk);
+		if bytes.len() > self.services.globals.url_preview_max_spider_size() {
+			debug!(
+				"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
+				 response body and assuming our necessary data is in this range.",
+				url,
+				self.services.globals.url_preview_max_spider_size()
+			);
+			break;
+		}
+	}
+	let body = String::from_utf8_lossy(&bytes);
+	let Ok(html) = HTML::from_string(body.to_string(), Some(url.to_owned())) else {
+		return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML"));
+	};
+
+	let mut data = match html.opengraph.images.first() {
+		None => UrlPreviewData::default(),
+		Some(obj) => self.download_image(&obj.url).await?,
+	};
+
+	let props = html.opengraph.properties;
+
+	/* use OpenGraph title/description, but fall back to HTML if not available */
+	data.title = props.get("title").cloned().or(html.title);
+	data.description = props.get("description").cloned().or(html.description);
+
+	Ok(data)
+}
+
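+/// Checks `url_str` against the configured allow/deny lists; an explicit deny
+/// always wins, then explicit and substring allows, then root-domain checks.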
+#[implement(Service)]
+pub fn url_preview_allowed(&self, url_str: &str) -> bool {
+	let url: Url = match Url::parse(url_str) {
+		Ok(u) => u,
+		Err(e) => {
+			warn!("Failed to parse URL from a str: {}", e);
+			return false;
+		},
+	};
+
+	if ["http", "https"]
+		.iter()
+		.all(|&scheme| scheme != url.scheme().to_lowercase())
+	{
+		debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
+		return false;
+	}
+
+	let host = match url.host_str() {
+		None => {
+			debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
+			return false;
+		},
+		Some(h) => h.to_owned(),
+	};
+
+	let allowlist_domain_contains = self
+		.services
+		.globals
+		.url_preview_domain_contains_allowlist();
+	let allowlist_domain_explicit = self
+		.services
+		.globals
+		.url_preview_domain_explicit_allowlist();
+	let denylist_domain_explicit = self.services.globals.url_preview_domain_explicit_denylist();
+	let allowlist_url_contains = self.services.globals.url_preview_url_contains_allowlist();
+
+	if allowlist_domain_contains.contains(&"*".to_owned())
+		|| allowlist_domain_explicit.contains(&"*".to_owned())
+		|| allowlist_url_contains.contains(&"*".to_owned())
+	{
+		debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
+		return true;
+	}
+
+	if !host.is_empty() {
+		if denylist_domain_explicit.contains(&host) {
+			debug!(
+				"Host {} is not allowed by url_preview_domain_explicit_denylist (check 1/4)",
+				&host
+			);
+			return false;
+		}
+
+		if allowlist_domain_explicit.contains(&host) {
+			debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 2/4)", &host);
+			return true;
+		}
+
+		if allowlist_domain_contains
+			.iter()
+			.any(|domain_s| domain_s.contains(&host.clone()))
+		{
+			debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 3/4)", &host);
+			return true;
+		}
+
+		if allowlist_url_contains
+			.iter()
+			.any(|url_s| url.to_string().contains(&url_s.to_string()))
+		{
+			debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 4/4)", url);
+			return true;
+		}
+
+		// check root domain if available and if user has root domain checks
+		if self.services.globals.url_preview_check_root_domain() {
+			debug!("Checking root domain");
+			match host.split_once('.') {
+				None => return false,
+				Some((_, root_domain)) => {
+					if denylist_domain_explicit.contains(&root_domain.to_owned()) {
+						debug!(
+							"Root domain {} is not allowed by url_preview_domain_explicit_denylist (check 1/3)",
+							&root_domain
+						);
+						return false;
+					}
+
+					if allowlist_domain_explicit.contains(&root_domain.to_owned()) {
+						debug!(
+							"Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 2/3)",
+							&root_domain
+						);
+						return true;
+					}
+
+					if allowlist_domain_contains
+						.iter()
+						.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
+					{
+						debug!(
+							"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 3/3)",
+							&root_domain
+						);
+						return true;
+					}
+				},
+			}
+		}
+	}
+
+	false
+}