diff --git a/src/core/error/err.rs b/src/core/error/err.rs index ea596644..66634f2b 100644 --- a/src/core/error/err.rs +++ b/src/core/error/err.rs @@ -49,14 +49,16 @@ macro_rules! err { $crate::$level!($($args),*); $crate::error::Error::Request( ::ruma::api::client::error::ErrorKind::forbidden(), - $crate::format_maybe!($($args),*) + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST ) }}; (Request(Forbidden($($args:expr),*))) => { $crate::error::Error::Request( ::ruma::api::client::error::ErrorKind::forbidden(), - $crate::format_maybe!($($args),*) + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST ) }; @@ -64,14 +66,16 @@ macro_rules! err { $crate::$level!($($args),*); $crate::error::Error::Request( ::ruma::api::client::error::ErrorKind::$variant, - $crate::format_maybe!($($args),*) + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST ) }}; (Request($variant:ident($($args:expr),*))) => { $crate::error::Error::Request( ::ruma::api::client::error::ErrorKind::$variant, - $crate::format_maybe!($($args),*) + $crate::format_maybe!($($args),*), + ::http::StatusCode::BAD_REQUEST ) }; diff --git a/src/core/error/mod.rs b/src/core/error/mod.rs index 069fe60e..1f0d56c1 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -68,7 +68,7 @@ pub enum Error { #[error("{0}: {1}")] BadRequest(ruma::api::client::error::ErrorKind, &'static str), //TODO: remove #[error("{0}: {1}")] - Request(ruma::api::client::error::ErrorKind, Cow<'static, str>), + Request(ruma::api::client::error::ErrorKind, Cow<'static, str>, http::StatusCode), #[error("from {0}: {1}")] Redaction(ruma::OwnedServerName, ruma::canonical_json::RedactionError), #[error("Remote server {0} responded with: {1}")] @@ -120,7 +120,7 @@ impl Error { match self { Self::Federation(_, error) => response::ruma_error_kind(error).clone(), - Self::BadRequest(kind, _) | Self::Request(kind, _) => kind.clone(), + Self::BadRequest(kind, ..) | Self::Request(kind, ..) => kind.clone(), _ => Unknown, } } @@ -128,7 +128,8 @@ impl Error { pub fn status_code(&self) -> http::StatusCode { match self { Self::Federation(_, ref error) | Self::RumaError(ref error) => error.status_code, - Self::BadRequest(ref kind, _) | Self::Request(ref kind, _) => response::bad_request_code(kind), + Self::Request(ref kind, _, code) => response::status_code(kind, *code), + Self::BadRequest(ref kind, ..) => response::bad_request_code(kind), Self::Conflict(_) => http::StatusCode::CONFLICT, _ => http::StatusCode::INTERNAL_SERVER_ERROR, } diff --git a/src/core/error/response.rs b/src/core/error/response.rs index 4ea76e26..7568a1c0 100644 --- a/src/core/error/response.rs +++ b/src/core/error/response.rs @@ -1,7 +1,13 @@ use bytes::BytesMut; use http::StatusCode; use http_body_util::Full; -use ruma::api::{client::uiaa::UiaaResponse, OutgoingResponse}; +use ruma::api::{ + client::{ + error::{ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + OutgoingResponse, +}; use super::Error; use crate::error; @@ -25,7 +31,7 @@ impl From for UiaaResponse { return Self::AuthResponse(uiaainfo); } - let body = ruma::api::client::error::ErrorBody::Standard { + let body = ErrorBody::Standard { kind: error.kind(), message: error.message(), }; @@ -37,10 +43,33 @@ impl From for UiaaResponse { } } -pub(super) fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> StatusCode { - use ruma::api::client::error::ErrorKind::*; +pub(super) fn status_code(kind: &ErrorKind, hint: StatusCode) -> StatusCode { + if hint == StatusCode::BAD_REQUEST { + bad_request_code(kind) + } else { + hint + } +} + +pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode { + use ErrorKind::*; match kind { + // 429 + LimitExceeded { + .. + } => StatusCode::TOO_MANY_REQUESTS, + + // 413 + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + + // 405 + Unrecognized => StatusCode::METHOD_NOT_ALLOWED, + + // 404 + NotFound => StatusCode::NOT_FOUND, + + // 403 GuestAccessForbidden | ThreepidAuthFailed | UserDeactivated @@ -52,26 +81,20 @@ pub(super) fn bad_request_code(kind: &ruma::api::client::error::ErrorKind) -> St .. } => StatusCode::FORBIDDEN, + // 401 UnknownToken { .. } | MissingToken | Unauthorized => StatusCode::UNAUTHORIZED, - LimitExceeded { - .. - } => StatusCode::TOO_MANY_REQUESTS, - - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - - NotFound | Unrecognized => StatusCode::NOT_FOUND, - + // 400 _ => StatusCode::BAD_REQUEST, } } pub(super) fn ruma_error_message(error: &ruma::api::client::error::Error) -> String { - if let ruma::api::client::error::ErrorBody::Standard { + if let ErrorBody::Standard { message, .. } = &error.body @@ -82,7 +105,6 @@ pub(super) fn ruma_error_message(error: &ruma::api::client::error::Error) -> Str format!("{error}") } -pub(super) fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ruma::api::client::error::ErrorKind { - e.error_kind() - .unwrap_or(&ruma::api::client::error::ErrorKind::Unknown) +pub(super) fn ruma_error_kind(e: &ruma::api::client::error::Error) -> &ErrorKind { + e.error_kind().unwrap_or(&ErrorKind::Unknown) } diff --git a/src/router/request.rs b/src/router/request.rs index ae739984..4d5c61e7 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -4,9 +4,8 @@ use axum::{ extract::State, response::{IntoResponse, Response}, }; -use conduit::{debug, debug_error, debug_warn, defer, error, trace, Error, Result, Server}; +use conduit::{debug, debug_error, debug_warn, defer, err, error, trace, Result, Server}; use http::{Method, StatusCode, Uri}; -use ruma::api::client::error::{Error as RumaError, ErrorBody, ErrorKind}; #[tracing::instrument(skip_all, level = "debug")] pub(crate) async fn spawn( @@ -58,33 +57,13 @@ pub(crate) async fn handle( trace!(active, finished, "leave"); }}; - let method = req.method().clone(); let uri = req.uri().clone(); + let method = req.method().clone(); let result = next.run(req).await; handle_result(&method, &uri, result) } fn handle_result(method: &Method, uri: &Uri, result: Response) -> Result { - handle_result_log(method, uri, &result); - match result.status() { - StatusCode::METHOD_NOT_ALLOWED => handle_result_405(method, uri, &result), - _ => Ok(result), - } -} - -fn handle_result_405(_method: &Method, _uri: &Uri, result: &Response) -> Result { - let error = Error::RumaError(RumaError { - status_code: result.status(), - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Method not allowed for endpoint".to_owned(), - }, - }); - - Ok(error.into_response()) -} - -fn handle_result_log(method: &Method, uri: &Uri, result: &Response) { let status = result.status(); let reason = status.canonical_reason().unwrap_or("Unknown Reason"); let code = status.as_u16(); @@ -97,4 +76,10 @@ fn handle_result_log(method: &Method, uri: &Uri, result: &Response) { } else { trace!(method = ?method, uri = ?uri, "{code} {reason}"); } + + if status == StatusCode::METHOD_NOT_ALLOWED { + return Ok(err!(Request(Unrecognized("Method Not Allowed"))).into_response()); + } + + Ok(result) } diff --git a/src/router/router.rs b/src/router/router.rs index da31ffea..2f6d5c46 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use axum::{response::IntoResponse, routing::get, Router}; use conduit::{Error, Server}; use conduit_service as service; -use http::Uri; +use http::{StatusCode, Uri}; use ruma::api::client::error::ErrorKind; extern crate conduit_api as api; @@ -19,7 +19,7 @@ pub(crate) fn build(server: &Arc) -> Router { } async fn not_found(_uri: Uri) -> impl IntoResponse { - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") + Error::Request(ErrorKind::Unrecognized, "Not Found".into(), StatusCode::NOT_FOUND) } async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }