From 8ecf722abb81442ff097141d8dfbc2d318c4056d Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 29 Apr 2024 17:41:15 -0700 Subject: [PATCH] split http serving from main. Signed-off-by: Jason Volk --- src/main.rs | 228 +------------------------------------ src/router/mod.rs | 226 ++++++++++++++++++++++++++++++++++++ src/{ => router}/routes.rs | 0 3 files changed, 232 insertions(+), 222 deletions(-) create mode 100644 src/router/mod.rs rename src/{ => router}/routes.rs (100%) diff --git a/src/main.rs b/src/main.rs index 47126628..81b71435 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,41 +7,22 @@ use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for // are not stable as of writing This is the case for every other occurence of // sync Mutex/RwLock, except for database related ones use std::sync::{Arc, RwLock}; -use std::{any::Any, io, net::SocketAddr, sync::atomic, time::Duration}; +use std::{io, net::SocketAddr, time::Duration}; use api::ruma_wrapper::{Ruma, RumaResponse}; -use axum::{ - extract::{DefaultBodyLimit, MatchedPath}, - response::IntoResponse, - Router, -}; +use axum::Router; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; #[cfg(feature = "axum_dual_protocol")] use axum_server_dual_protocol::ServerExt; use config::Config; use database::KeyValueDatabase; -use http::{ - header::{self, HeaderName}, - Method, StatusCode, Uri, -}; -use ruma::api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, -}; use service::{pdu::PduEvent, Services}; use tokio::{ signal, sync::oneshot::{self, Sender}, task::JoinSet, }; -use tower::ServiceBuilder; -use tower_http::{ - catch_panic::CatchPanicLayer, - cors::{self, CorsLayer}, - trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer}, - ServiceBuilderExt as _, -}; -use tracing::{debug, error, info, trace, warn, Level}; +use tracing::{debug, error, info, warn}; use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry}; use utils::{ clap, @@ -52,7 +33,7 @@ pub(crate) mod alloc; mod api; mod config; mod database; -mod routes; +mod router; mod service; mod utils; @@ -66,7 +47,7 @@ pub(crate) fn services() -> &'static Services<'static> { .expect("SERVICES should be initialized when this is called") } -struct Server { +pub(crate) struct Server { config: Config, runtime: tokio::runtime::Runtime, @@ -108,7 +89,7 @@ async fn async_main(server: &Server) -> Result<(), Error> { } async fn run(server: &Server) -> io::Result<()> { - let app = build(server).await?; + let app = router::build(server).await?; let (tx, rx) = oneshot::channel::<()>(); let handle = ServerHandle::new(); tokio::spawn(shutdown(handle.clone(), tx)); @@ -286,179 +267,6 @@ async fn start(server: &Server) -> Result<(), Error> { Ok(()) } -async fn build(server: &Server) -> io::Result> { - let base_middlewares = ServiceBuilder::new(); - #[cfg(feature = "sentry_telemetry")] - let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); - - let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); - let middlewares = base_middlewares - .sensitive_headers([header::AUTHORIZATION]) - .sensitive_request_headers([x_forwarded_for].into()) - .layer(axum::middleware::from_fn(request_spawn)) - .layer( - TraceLayer::new_for_http() - .make_span_with(tracing_span::<_>) - .on_failure(DefaultOnFailure::new().level(Level::ERROR)) - .on_request(DefaultOnRequest::new().level(Level::TRACE)) - .on_response(DefaultOnResponse::new().level(Level::DEBUG)), - ) - .layer(axum::middleware::from_fn(request_handle)) - .layer(cors_layer(server)) - .layer(DefaultBodyLimit::max( - server - .config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - )) - .layer(CatchPanicLayer::custom(catch_panic_layer)); - - #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] - { - Ok(routes::routes(&server.config) - .layer(compression_layer(server)) - .layer(middlewares) - .into_make_service()) - } - #[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))] - { - Ok(routes::routes().layer(middlewares).into_make_service()) - } -} - -#[tracing::instrument(skip_all, name = "spawn")] -async fn request_spawn( - req: http::Request, next: axum::middleware::Next, -) -> Result { - if services().globals.shutdown.load(atomic::Ordering::Relaxed) { - return Err(StatusCode::SERVICE_UNAVAILABLE); - } - - let fut = next.run(req); - let task = tokio::spawn(fut); - task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) -} - -#[tracing::instrument(skip_all, name = "handle")] -async fn request_handle( - req: http::Request, next: axum::middleware::Next, -) -> Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let result = next.run(req).await; - request_result(&method, &uri, result) -} - -fn request_result( - method: &Method, uri: &Uri, result: axum::response::Response, -) -> Result { - request_result_log(method, uri, &result); - match result.status() { - StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result), - _ => Ok(result), - } -} - -#[allow(clippy::unnecessary_wraps)] -fn request_result_403( - _method: &Method, _uri: &Uri, result: &axum::response::Response, -) -> Result { - let error = UiaaResponse::MatrixError(RumaError { - status_code: result.status(), - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), - }, - }); - - Ok(RumaResponse(error).into_response()) -} - -fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { - let status = result.status(); - let reason = status.canonical_reason().unwrap_or("Unknown Reason"); - let code = status.as_u16(); - if status.is_server_error() { - error!(method = ?method, uri = ?uri, "{code} {reason}"); - } else if status.is_client_error() { - debug_error!(method = ?method, uri = ?uri, "{code} {reason}"); - } else if status.is_redirection() { - debug!(method = ?method, uri = ?uri, "{code} {reason}"); - } else { - trace!(method = ?method, uri = ?uri, "{code} {reason}"); - } -} - -fn cors_layer(_server: &Server) -> CorsLayer { - const METHODS: [Method; 6] = [ - Method::GET, - Method::HEAD, - Method::POST, - Method::PUT, - Method::DELETE, - Method::OPTIONS, - ]; - - let headers: [HeaderName; 5] = [ - header::ORIGIN, - HeaderName::from_lowercase(b"x-requested-with").unwrap(), - header::CONTENT_TYPE, - header::ACCEPT, - header::AUTHORIZATION, - ]; - - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods(METHODS) - .allow_headers(headers) - .max_age(Duration::from_secs(86400)) -} - -#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] -fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { - let mut compression_layer = tower_http::compression::CompressionLayer::new(); - - #[cfg(feature = "zstd_compression")] - { - if server.config.zstd_compression { - compression_layer = compression_layer.zstd(true); - } else { - compression_layer = compression_layer.no_zstd(); - }; - }; - - #[cfg(feature = "gzip_compression")] - { - if server.config.gzip_compression { - compression_layer = compression_layer.gzip(true); - } else { - compression_layer = compression_layer.no_gzip(); - }; - }; - - #[cfg(feature = "brotli_compression")] - { - if server.config.brotli_compression { - compression_layer = compression_layer.br(true); - } else { - compression_layer = compression_layer.no_br(); - }; - }; - - compression_layer -} - -fn tracing_span(request: &http::Request) -> tracing::Span { - let path = if let Some(path) = request.extensions().get::() { - path.as_str() - } else { - request.uri().path() - }; - - tracing::info_span!("router:", %path) -} - /// Non-async initializations fn init(args: clap::Args) -> Result { let config = Config::new(args.config)?; @@ -692,27 +500,3 @@ fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { Ok(()) } - -#[allow(clippy::needless_pass_by_value)] -fn catch_panic_layer(err: Box) -> http::Response> { - let details = if let Some(s) = err.downcast_ref::() { - s.clone() - } else if let Some(s) = err.downcast_ref::<&str>() { - s.to_string() - } else { - "Unknown internal server error occurred.".to_owned() - }; - - let body = serde_json::json!({ - "errcode": "M_UNKNOWN", - "error": "M_UNKNOWN: Internal server error occurred", - "details": details, - }) - .to_string(); - - http::Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(header::CONTENT_TYPE, "application/json") - .body(http_body_util::Full::from(body)) - .expect("Failed to create response for our panic catcher?") -} diff --git a/src/router/mod.rs b/src/router/mod.rs new file mode 100644 index 00000000..a601e76b --- /dev/null +++ b/src/router/mod.rs @@ -0,0 +1,226 @@ +use std::{any::Any, io, sync::atomic, time::Duration}; + +use axum::{ + extract::{DefaultBodyLimit, MatchedPath}, + response::IntoResponse, + Router, +}; +#[cfg(feature = "axum_dual_protocol")] +use axum_server_dual_protocol::ServerExt; +use http::{ + header::{self, HeaderName}, + Method, StatusCode, Uri, +}; +use ruma::api::client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, +}; +use tower::ServiceBuilder; +use tower_http::{ + catch_panic::CatchPanicLayer, + cors::{self, CorsLayer}, + trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer}, + ServiceBuilderExt as _, +}; +use tracing::{debug, error, trace, Level}; + +use super::{api::ruma_wrapper::RumaResponse, debug_error, services, utils::error::Result, Server}; + +mod routes; + +pub(crate) async fn build(server: &Server) -> io::Result> { + let base_middlewares = ServiceBuilder::new(); + #[cfg(feature = "sentry_telemetry")] + let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::>::new_from_top()); + + let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); + let middlewares = base_middlewares + .sensitive_headers([header::AUTHORIZATION]) + .sensitive_request_headers([x_forwarded_for].into()) + .layer(axum::middleware::from_fn(request_spawn)) + .layer( + TraceLayer::new_for_http() + .make_span_with(tracing_span::<_>) + .on_failure(DefaultOnFailure::new().level(Level::ERROR)) + .on_request(DefaultOnRequest::new().level(Level::TRACE)) + .on_response(DefaultOnResponse::new().level(Level::DEBUG)), + ) + .layer(axum::middleware::from_fn(request_handle)) + .layer(cors_layer(server)) + .layer(DefaultBodyLimit::max( + server + .config + .max_request_size + .try_into() + .expect("failed to convert max request size"), + )) + .layer(CatchPanicLayer::custom(catch_panic_layer)); + + #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] + { + Ok(routes::routes(&server.config) + .layer(compression_layer(server)) + .layer(middlewares) + .into_make_service()) + } + #[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))] + { + Ok(routes::routes().layer(middlewares).into_make_service()) + } +} + +#[tracing::instrument(skip_all, name = "spawn")] +async fn request_spawn( + req: http::Request, next: axum::middleware::Next, +) -> Result { + if services().globals.shutdown.load(atomic::Ordering::Relaxed) { + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + + let fut = next.run(req); + let task = tokio::spawn(fut); + task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) +} + +#[tracing::instrument(skip_all, name = "handle")] +async fn request_handle( + req: http::Request, next: axum::middleware::Next, +) -> Result { + let method = req.method().clone(); + let uri = req.uri().clone(); + let result = next.run(req).await; + request_result(&method, &uri, result) +} + +fn request_result( + method: &Method, uri: &Uri, result: axum::response::Response, +) -> Result { + request_result_log(method, uri, &result); + match result.status() { + StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result), + _ => Ok(result), + } +} + +#[allow(clippy::unnecessary_wraps)] +fn request_result_403( + _method: &Method, _uri: &Uri, result: &axum::response::Response, +) -> Result { + let error = UiaaResponse::MatrixError(RumaError { + status_code: result.status(), + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), + }, + }); + + Ok(RumaResponse(error).into_response()) +} + +fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) { + let status = result.status(); + let reason = status.canonical_reason().unwrap_or("Unknown Reason"); + let code = status.as_u16(); + if status.is_server_error() { + error!(method = ?method, uri = ?uri, "{code} {reason}"); + } else if status.is_client_error() { + debug_error!(method = ?method, uri = ?uri, "{code} {reason}"); + } else if status.is_redirection() { + debug!(method = ?method, uri = ?uri, "{code} {reason}"); + } else { + trace!(method = ?method, uri = ?uri, "{code} {reason}"); + } +} + +fn cors_layer(_server: &Server) -> CorsLayer { + const METHODS: [Method; 6] = [ + Method::GET, + Method::HEAD, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]; + + let headers: [HeaderName; 5] = [ + header::ORIGIN, + HeaderName::from_lowercase(b"x-requested-with").unwrap(), + header::CONTENT_TYPE, + header::ACCEPT, + header::AUTHORIZATION, + ]; + + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(METHODS) + .allow_headers(headers) + .max_age(Duration::from_secs(86400)) +} + +#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] +fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer { + let mut compression_layer = tower_http::compression::CompressionLayer::new(); + + #[cfg(feature = "zstd_compression")] + { + if server.config.zstd_compression { + compression_layer = compression_layer.zstd(true); + } else { + compression_layer = compression_layer.no_zstd(); + }; + }; + + #[cfg(feature = "gzip_compression")] + { + if server.config.gzip_compression { + compression_layer = compression_layer.gzip(true); + } else { + compression_layer = compression_layer.no_gzip(); + }; + }; + + #[cfg(feature = "brotli_compression")] + { + if server.config.brotli_compression { + compression_layer = compression_layer.br(true); + } else { + compression_layer = compression_layer.no_br(); + }; + }; + + compression_layer +} + +fn tracing_span(request: &http::Request) -> tracing::Span { + let path = if let Some(path) = request.extensions().get::() { + path.as_str() + } else { + request.uri().path() + }; + + tracing::info_span!("router:", %path) +} + +#[allow(clippy::needless_pass_by_value)] +fn catch_panic_layer(err: Box) -> http::Response> { + let details = if let Some(s) = err.downcast_ref::() { + s.clone() + } else if let Some(s) = err.downcast_ref::<&str>() { + s.to_string() + } else { + "Unknown internal server error occurred.".to_owned() + }; + + let body = serde_json::json!({ + "errcode": "M_UNKNOWN", + "error": "M_UNKNOWN: Internal server error occurred", + "details": details, + }) + .to_string(); + + http::Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header(header::CONTENT_TYPE, "application/json") + .body(http_body_util::Full::from(body)) + .expect("Failed to create response for our panic catcher?") +} diff --git a/src/routes.rs b/src/router/routes.rs similarity index 100% rename from src/routes.rs rename to src/router/routes.rs