split http serving from main.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
5d76db8f19
commit
8ecf722abb
3 changed files with 232 additions and 222 deletions
228
src/main.rs
228
src/main.rs
|
@ -7,41 +7,22 @@ use std::os::unix::fs::PermissionsExt as _; /* not unix specific, just only for
|
|||
// are not stable as of writing This is the case for every other occurence of
|
||||
// sync Mutex/RwLock, except for database related ones
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::{any::Any, io, net::SocketAddr, sync::atomic, time::Duration};
|
||||
use std::{io, net::SocketAddr, time::Duration};
|
||||
|
||||
use api::ruma_wrapper::{Ruma, RumaResponse};
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, MatchedPath},
|
||||
response::IntoResponse,
|
||||
Router,
|
||||
};
|
||||
use axum::Router;
|
||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
use axum_server_dual_protocol::ServerExt;
|
||||
use config::Config;
|
||||
use database::KeyValueDatabase;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use ruma::api::client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
};
|
||||
use service::{pdu::PduEvent, Services};
|
||||
use tokio::{
|
||||
signal,
|
||||
sync::oneshot::{self, Sender},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
catch_panic::CatchPanicLayer,
|
||||
cors::{self, CorsLayer},
|
||||
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::{debug, error, info, trace, warn, Level};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing_subscriber::{prelude::*, reload, EnvFilter, Registry};
|
||||
use utils::{
|
||||
clap,
|
||||
|
@ -52,7 +33,7 @@ pub(crate) mod alloc;
|
|||
mod api;
|
||||
mod config;
|
||||
mod database;
|
||||
mod routes;
|
||||
mod router;
|
||||
mod service;
|
||||
mod utils;
|
||||
|
||||
|
@ -66,7 +47,7 @@ pub(crate) fn services() -> &'static Services<'static> {
|
|||
.expect("SERVICES should be initialized when this is called")
|
||||
}
|
||||
|
||||
struct Server {
|
||||
pub(crate) struct Server {
|
||||
config: Config,
|
||||
|
||||
runtime: tokio::runtime::Runtime,
|
||||
|
@ -108,7 +89,7 @@ async fn async_main(server: &Server) -> Result<(), Error> {
|
|||
}
|
||||
|
||||
async fn run(server: &Server) -> io::Result<()> {
|
||||
let app = build(server).await?;
|
||||
let app = router::build(server).await?;
|
||||
let (tx, rx) = oneshot::channel::<()>();
|
||||
let handle = ServerHandle::new();
|
||||
tokio::spawn(shutdown(handle.clone(), tx));
|
||||
|
@ -286,179 +267,6 @@ async fn start(server: &Server) -> Result<(), Error> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Router>> {
|
||||
let base_middlewares = ServiceBuilder::new();
|
||||
#[cfg(feature = "sentry_telemetry")]
|
||||
let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
|
||||
|
||||
let x_forwarded_for = HeaderName::from_static("x-forwarded-for");
|
||||
let middlewares = base_middlewares
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.sensitive_request_headers([x_forwarded_for].into())
|
||||
.layer(axum::middleware::from_fn(request_spawn))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(tracing_span::<_>)
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
|
||||
.on_request(DefaultOnRequest::new().level(Level::TRACE))
|
||||
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(request_handle))
|
||||
.layer(cors_layer(server))
|
||||
.layer(DefaultBodyLimit::max(
|
||||
server
|
||||
.config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
))
|
||||
.layer(CatchPanicLayer::custom(catch_panic_layer));
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
{
|
||||
Ok(routes::routes(&server.config)
|
||||
.layer(compression_layer(server))
|
||||
.layer(middlewares)
|
||||
.into_make_service())
|
||||
}
|
||||
#[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))]
|
||||
{
|
||||
Ok(routes::routes().layer(middlewares).into_make_service())
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "spawn")]
|
||||
async fn request_spawn(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
let fut = next.run(req);
|
||||
let task = tokio::spawn(fut);
|
||||
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "handle")]
|
||||
async fn request_handle(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let result = next.run(req).await;
|
||||
request_result(&method, &uri, result)
|
||||
}
|
||||
|
||||
fn request_result(
|
||||
method: &Method, uri: &Uri, result: axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
request_result_log(method, uri, &result);
|
||||
match result.status() {
|
||||
StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result),
|
||||
_ => Ok(result),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn request_result_403(
|
||||
_method: &Method, _uri: &Uri, result: &axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let error = UiaaResponse::MatrixError(RumaError {
|
||||
status_code: result.status(),
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(RumaResponse(error).into_response())
|
||||
}
|
||||
|
||||
fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
|
||||
let status = result.status();
|
||||
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
|
||||
let code = status.as_u16();
|
||||
if status.is_server_error() {
|
||||
error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_client_error() {
|
||||
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_redirection() {
|
||||
debug!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else {
|
||||
trace!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
}
|
||||
}
|
||||
|
||||
fn cors_layer(_server: &Server) -> CorsLayer {
|
||||
const METHODS: [Method; 6] = [
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let headers: [HeaderName; 5] = [
|
||||
header::ORIGIN,
|
||||
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
];
|
||||
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods(METHODS)
|
||||
.allow_headers(headers)
|
||||
.max_age(Duration::from_secs(86400))
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
|
||||
let mut compression_layer = tower_http::compression::CompressionLayer::new();
|
||||
|
||||
#[cfg(feature = "zstd_compression")]
|
||||
{
|
||||
if server.config.zstd_compression {
|
||||
compression_layer = compression_layer.zstd(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_zstd();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "gzip_compression")]
|
||||
{
|
||||
if server.config.gzip_compression {
|
||||
compression_layer = compression_layer.gzip(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_gzip();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "brotli_compression")]
|
||||
{
|
||||
if server.config.brotli_compression {
|
||||
compression_layer = compression_layer.br(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_br();
|
||||
};
|
||||
};
|
||||
|
||||
compression_layer
|
||||
}
|
||||
|
||||
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
|
||||
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||
path.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
tracing::info_span!("router:", %path)
|
||||
}
|
||||
|
||||
/// Non-async initializations
|
||||
fn init(args: clap::Args) -> Result<Server, Error> {
|
||||
let config = Config::new(args.config)?;
|
||||
|
@ -692,27 +500,3 @@ fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn catch_panic_layer(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
|
||||
let details = if let Some(s) = err.downcast_ref::<String>() {
|
||||
s.clone()
|
||||
} else if let Some(s) = err.downcast_ref::<&str>() {
|
||||
s.to_string()
|
||||
} else {
|
||||
"Unknown internal server error occurred.".to_owned()
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "M_UNKNOWN: Internal server error occurred",
|
||||
"details": details,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
http::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(http_body_util::Full::from(body))
|
||||
.expect("Failed to create response for our panic catcher?")
|
||||
}
|
||||
|
|
226
src/router/mod.rs
Normal file
226
src/router/mod.rs
Normal file
|
@ -0,0 +1,226 @@
|
|||
use std::{any::Any, io, sync::atomic, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{DefaultBodyLimit, MatchedPath},
|
||||
response::IntoResponse,
|
||||
Router,
|
||||
};
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
use axum_server_dual_protocol::ServerExt;
|
||||
use http::{
|
||||
header::{self, HeaderName},
|
||||
Method, StatusCode, Uri,
|
||||
};
|
||||
use ruma::api::client::{
|
||||
error::{Error as RumaError, ErrorBody, ErrorKind},
|
||||
uiaa::UiaaResponse,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
catch_panic::CatchPanicLayer,
|
||||
cors::{self, CorsLayer},
|
||||
trace::{DefaultOnFailure, DefaultOnRequest, DefaultOnResponse, TraceLayer},
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::{debug, error, trace, Level};
|
||||
|
||||
use super::{api::ruma_wrapper::RumaResponse, debug_error, services, utils::error::Result, Server};
|
||||
|
||||
mod routes;
|
||||
|
||||
pub(crate) async fn build(server: &Server) -> io::Result<axum::routing::IntoMakeService<Router>> {
|
||||
let base_middlewares = ServiceBuilder::new();
|
||||
#[cfg(feature = "sentry_telemetry")]
|
||||
let base_middlewares = base_middlewares.layer(sentry_tower::NewSentryLayer::<http::Request<_>>::new_from_top());
|
||||
|
||||
let x_forwarded_for = HeaderName::from_static("x-forwarded-for");
|
||||
let middlewares = base_middlewares
|
||||
.sensitive_headers([header::AUTHORIZATION])
|
||||
.sensitive_request_headers([x_forwarded_for].into())
|
||||
.layer(axum::middleware::from_fn(request_spawn))
|
||||
.layer(
|
||||
TraceLayer::new_for_http()
|
||||
.make_span_with(tracing_span::<_>)
|
||||
.on_failure(DefaultOnFailure::new().level(Level::ERROR))
|
||||
.on_request(DefaultOnRequest::new().level(Level::TRACE))
|
||||
.on_response(DefaultOnResponse::new().level(Level::DEBUG)),
|
||||
)
|
||||
.layer(axum::middleware::from_fn(request_handle))
|
||||
.layer(cors_layer(server))
|
||||
.layer(DefaultBodyLimit::max(
|
||||
server
|
||||
.config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
))
|
||||
.layer(CatchPanicLayer::custom(catch_panic_layer));
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
{
|
||||
Ok(routes::routes(&server.config)
|
||||
.layer(compression_layer(server))
|
||||
.layer(middlewares)
|
||||
.into_make_service())
|
||||
}
|
||||
#[cfg(not(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression")))]
|
||||
{
|
||||
Ok(routes::routes().layer(middlewares).into_make_service())
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "spawn")]
|
||||
async fn request_spawn(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||
}
|
||||
|
||||
let fut = next.run(req);
|
||||
let task = tokio::spawn(fut);
|
||||
task.await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all, name = "handle")]
|
||||
async fn request_handle(
|
||||
req: http::Request<axum::body::Body>, next: axum::middleware::Next,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let result = next.run(req).await;
|
||||
request_result(&method, &uri, result)
|
||||
}
|
||||
|
||||
fn request_result(
|
||||
method: &Method, uri: &Uri, result: axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
request_result_log(method, uri, &result);
|
||||
match result.status() {
|
||||
StatusCode::METHOD_NOT_ALLOWED => request_result_403(method, uri, &result),
|
||||
_ => Ok(result),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn request_result_403(
|
||||
_method: &Method, _uri: &Uri, result: &axum::response::Response,
|
||||
) -> Result<axum::response::Response, StatusCode> {
|
||||
let error = UiaaResponse::MatrixError(RumaError {
|
||||
status_code: result.status(),
|
||||
body: ErrorBody::Standard {
|
||||
kind: ErrorKind::Unrecognized,
|
||||
message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(),
|
||||
},
|
||||
});
|
||||
|
||||
Ok(RumaResponse(error).into_response())
|
||||
}
|
||||
|
||||
fn request_result_log(method: &Method, uri: &Uri, result: &axum::response::Response) {
|
||||
let status = result.status();
|
||||
let reason = status.canonical_reason().unwrap_or("Unknown Reason");
|
||||
let code = status.as_u16();
|
||||
if status.is_server_error() {
|
||||
error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_client_error() {
|
||||
debug_error!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else if status.is_redirection() {
|
||||
debug!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
} else {
|
||||
trace!(method = ?method, uri = ?uri, "{code} {reason}");
|
||||
}
|
||||
}
|
||||
|
||||
fn cors_layer(_server: &Server) -> CorsLayer {
|
||||
const METHODS: [Method; 6] = [
|
||||
Method::GET,
|
||||
Method::HEAD,
|
||||
Method::POST,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
];
|
||||
|
||||
let headers: [HeaderName; 5] = [
|
||||
header::ORIGIN,
|
||||
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
|
||||
header::CONTENT_TYPE,
|
||||
header::ACCEPT,
|
||||
header::AUTHORIZATION,
|
||||
];
|
||||
|
||||
CorsLayer::new()
|
||||
.allow_origin(cors::Any)
|
||||
.allow_methods(METHODS)
|
||||
.allow_headers(headers)
|
||||
.max_age(Duration::from_secs(86400))
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]
|
||||
fn compression_layer(server: &Server) -> tower_http::compression::CompressionLayer {
|
||||
let mut compression_layer = tower_http::compression::CompressionLayer::new();
|
||||
|
||||
#[cfg(feature = "zstd_compression")]
|
||||
{
|
||||
if server.config.zstd_compression {
|
||||
compression_layer = compression_layer.zstd(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_zstd();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "gzip_compression")]
|
||||
{
|
||||
if server.config.gzip_compression {
|
||||
compression_layer = compression_layer.gzip(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_gzip();
|
||||
};
|
||||
};
|
||||
|
||||
#[cfg(feature = "brotli_compression")]
|
||||
{
|
||||
if server.config.brotli_compression {
|
||||
compression_layer = compression_layer.br(true);
|
||||
} else {
|
||||
compression_layer = compression_layer.no_br();
|
||||
};
|
||||
};
|
||||
|
||||
compression_layer
|
||||
}
|
||||
|
||||
fn tracing_span<T>(request: &http::Request<T>) -> tracing::Span {
|
||||
let path = if let Some(path) = request.extensions().get::<MatchedPath>() {
|
||||
path.as_str()
|
||||
} else {
|
||||
request.uri().path()
|
||||
};
|
||||
|
||||
tracing::info_span!("router:", %path)
|
||||
}
|
||||
|
||||
#[allow(clippy::needless_pass_by_value)]
|
||||
fn catch_panic_layer(err: Box<dyn Any + Send + 'static>) -> http::Response<http_body_util::Full<bytes::Bytes>> {
|
||||
let details = if let Some(s) = err.downcast_ref::<String>() {
|
||||
s.clone()
|
||||
} else if let Some(s) = err.downcast_ref::<&str>() {
|
||||
s.to_string()
|
||||
} else {
|
||||
"Unknown internal server error occurred.".to_owned()
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "M_UNKNOWN: Internal server error occurred",
|
||||
"details": details,
|
||||
})
|
||||
.to_string();
|
||||
|
||||
http::Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.header(header::CONTENT_TYPE, "application/json")
|
||||
.body(http_body_util::Full::from(body))
|
||||
.expect("Failed to create response for our panic catcher?")
|
||||
}
|
Loading…
Add table
Reference in a new issue