diff --git a/src/api/ruma_wrapper/auth.rs b/src/api/ruma_wrapper/auth.rs index 43737d20..35fd443d 100644 --- a/src/api/ruma_wrapper/auth.rs +++ b/src/api/ruma_wrapper/auth.rs @@ -1,10 +1,14 @@ use std::collections::BTreeMap; use axum::RequestPartsExt; -use axum_extra::{headers::Authorization, typed_header::TypedHeaderRejectionReason, TypedHeader}; +use axum_extra::{ + headers::{authorization::Bearer, Authorization}, + typed_header::TypedHeaderRejectionReason, + TypedHeader, +}; use http::uri::PathAndQuery; use ruma::{ - api::{client::error::ErrorKind, AuthScheme}, + api::{client::error::ErrorKind, AuthScheme, Metadata}, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, }; use tracing::warn; @@ -20,14 +24,17 @@ enum Token { } pub(super) struct Auth { + pub(super) origin: Option, pub(super) sender_user: Option, pub(super) sender_device: Option, - pub(super) origin: Option, pub(super) appservice_info: Option, } -pub(super) async fn auth(request: &mut Request, metadata: &ruma::api::Metadata) -> Result { - let token = match &request.auth { +pub(super) async fn auth( + request: &mut Request, json_body: &Option, metadata: &Metadata, +) -> Result { + let bearer: Option>> = request.parts.extract().await?; + let token = match &bearer { Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), None => request.query.access_token.as_deref(), }; @@ -78,9 +85,9 @@ pub(super) async fn auth(request: &mut Request, metadata: &ruma::api::Metadata) (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(request, info)?), (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { Ok(Auth { + origin: None, sender_user: None, sender_device: None, - origin: None, appservice_info: Some(*info), }) }, @@ -91,12 +98,12 @@ pub(super) async fn auth(request: &mut Request, metadata: &ruma::api::Metadata) AuthScheme::AccessToken | AuthScheme::AccessTokenOptional | AuthScheme::None, Token::User((user_id, device_id)), ) => Ok(Auth { + origin: None, sender_user: Some(user_id), sender_device: Some(device_id), - origin: None, appservice_info: None, }), - (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(request).await?), + (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(request, json_body).await?), (AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => Ok(Auth { sender_user: None, sender_device: None, @@ -114,7 +121,7 @@ pub(super) async fn auth(request: &mut Request, metadata: &ruma::api::Metadata) } } -fn auth_appservice(request: &mut Request, info: Box) -> Result { +fn auth_appservice(request: &Request, info: Box) -> Result { let user_id = request .query .user_id @@ -139,14 +146,14 @@ fn auth_appservice(request: &mut Request, info: Box) -> Result } Ok(Auth { + origin: None, sender_user: Some(user_id), sender_device: None, - origin: None, appservice_info: Some(*info), }) } -async fn auth_server(request: &mut Request) -> Result { +async fn auth_server(request: &mut Request, json_body: &Option) -> Result { if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -167,15 +174,11 @@ async fn auth_server(request: &mut Request) -> Result { Error::BadRequest(ErrorKind::forbidden(), msg) })?; - let origin_signatures = BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]); - - let signatures = BTreeMap::from_iter([( - x_matrix.origin.as_str().to_owned(), - CanonicalJsonValue::Object(origin_signatures), - )]); + let origin = &x_matrix.origin; + let signatures = BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]); + let signatures = BTreeMap::from_iter([(origin.as_str().to_owned(), CanonicalJsonValue::Object(signatures))]); let server_destination = services().globals.server_name().as_str().to_owned(); - if let Some(destination) = x_matrix.destination.as_ref() { if destination != &server_destination { return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); @@ -197,22 +200,19 @@ async fn auth_server(request: &mut Request) -> Result { CanonicalJsonValue::String(request.parts.method.to_string()), ), ("uri".to_owned(), signature_uri), - ( - "origin".to_owned(), - CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), - ), + ("origin".to_owned(), CanonicalJsonValue::String(origin.as_str().to_owned())), ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), ]); - if let Some(json_body) = &request.json { + if let Some(json_body) = json_body { request_map.insert("content".to_owned(), json_body.clone()); }; let keys_result = services() .rooms .event_handler - .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()]) + .fetch_signing_keys_for_server(origin, vec![x_matrix.key.clone()]) .await; let keys = keys_result.map_err(|e| { @@ -220,17 +220,17 @@ async fn auth_server(request: &mut Request) -> Result { Error::BadRequest(ErrorKind::forbidden(), "Failed to fetch signing keys.") })?; - let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); + let pub_key_map = BTreeMap::from_iter([(origin.as_str().to_owned(), keys)]); match ruma::signatures::verify_json(&pub_key_map, &request_map) { Ok(()) => Ok(Auth { + origin: Some(origin.clone()), sender_user: None, sender_device: None, - origin: Some(x_matrix.origin), appservice_info: None, }), Err(e) => { - warn!("Failed to verify json request from {}: {e}\n{request_map:?}", x_matrix.origin); + warn!("Failed to verify json request from {origin}: {e}\n{request_map:?}"); if request.parts.uri.to_string().contains('@') { warn!( diff --git a/src/api/ruma_wrapper/router.rs b/src/api/ruma_wrapper/handler.rs similarity index 100% rename from src/api/ruma_wrapper/router.rs rename to src/api/ruma_wrapper/handler.rs diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index a130ddd9..fb633b91 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,26 +1,69 @@ mod auth; +mod handler; mod request; -mod router; mod xmatrix; -use std::ops::Deref; +use std::{mem, ops::Deref}; +use axum::{async_trait, body::Body, extract::FromRequest}; +use bytes::{BufMut, BytesMut}; pub(super) use conduit::error::RumaResponse; -use ruma::{CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId}; +use conduit::{debug, debug_warn, trace, warn}; +use ruma::{ + api::{client::error::ErrorKind, IncomingRequest}, + CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, +}; -pub(super) use self::router::RouterExt; -use crate::service::appservice::RegistrationInfo; +pub(super) use self::handler::RouterExt; +use self::{auth::Auth, request::Request}; +use crate::{service::appservice::RegistrationInfo, services, Error, Result}; /// Extractor for Ruma request structs pub(crate) struct Ruma { /// Request struct body pub(crate) body: T, - pub(crate) sender_user: Option, - pub(crate) sender_device: Option, - /// X-Matrix origin/server + + /// Federation server authentication: X-Matrix origin + /// None when not a federation server. pub(crate) origin: Option, - pub(crate) json_body: Option, // This is None when body is not a valid string + + /// Local user authentication: user_id. + /// None when not an authenticated local user. + pub(crate) sender_user: Option, + + /// Local user authentication: device_id. + /// None when not an authenticated local user or no device. + pub(crate) sender_device: Option, + + /// Appservice authentication; registration info. + /// None when not an appservice. pub(crate) appservice_info: Option, + + /// Parsed JSON content. + /// None when body is not a valid string + pub(crate) json_body: Option, +} + +#[async_trait] +impl FromRequest for Ruma +where + T: IncomingRequest, +{ + type Rejection = Error; + + async fn from_request(request: hyper::Request, _: &S) -> Result { + let mut request = request::from(request).await?; + let mut json_body = serde_json::from_slice::(&request.body).ok(); + let auth = auth::auth(&mut request, &json_body, &T::METADATA).await?; + Ok(Ruma { + body: make_body::(&mut request, &mut json_body, &auth)?, + origin: auth.origin, + sender_user: auth.sender_user, + sender_device: auth.sender_device, + appservice_info: auth.appservice_info, + json_body, + }) + } } impl Deref for Ruma { @@ -28,3 +71,60 @@ impl Deref for Ruma { fn deref(&self) -> &Self::Target { &self.body } } + +fn make_body(request: &mut Request, json_body: &mut Option, auth: &Auth) -> Result +where + T: IncomingRequest, +{ + let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { + let user_id = auth.sender_user.clone().unwrap_or_else(|| { + UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(|auth| auth.as_object()) + .and_then(|auth| auth.get("session")) + .and_then(|session| session.as_str()) + .and_then(|session| { + services().uiaa.get_uiaa_request( + &user_id, + &auth.sender_device.clone().unwrap_or_else(|| "".into()), + session, + ) + }); + + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); + } + } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, &json_body).expect("value serialization can't fail"); + buf.into_inner().freeze() + } else { + mem::take(&mut request.body) + }; + + let mut http_request = hyper::Request::builder() + .uri(request.parts.uri.clone()) + .method(request.parts.method.clone()); + *http_request.headers_mut().unwrap() = request.parts.headers.clone(); + let http_request = http_request.body(body).unwrap(); + debug!( + "{:?} {:?} {:?}", + http_request.method(), + http_request.uri(), + http_request.headers() + ); + + trace!("{:?} {:?} {:?}", http_request.method(), http_request.uri(), json_body); + let body = T::try_from_http_request(http_request, &request.path).map_err(|e| { + warn!("try_from_http_request failed: {e:?}",); + debug_warn!("JSON body: {:?}", json_body); + Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") + })?; + + Ok(body) +} diff --git a/src/api/ruma_wrapper/request.rs b/src/api/ruma_wrapper/request.rs index b74fbd3e..59639eaa 100644 --- a/src/api/ruma_wrapper/request.rs +++ b/src/api/ruma_wrapper/request.rs @@ -1,25 +1,11 @@ -use std::{mem, str}; +use std::str; -use axum::{ - async_trait, - extract::{FromRequest, Path}, - RequestExt, RequestPartsExt, -}; -use axum_extra::{ - headers::{authorization::Bearer, Authorization}, - TypedHeader, -}; -use bytes::{BufMut, Bytes, BytesMut}; -use conduit::debug_warn; +use axum::{extract::Path, RequestExt, RequestPartsExt}; +use bytes::Bytes; use http::request::Parts; -use ruma::{ - api::{client::error::ErrorKind, IncomingRequest}, - CanonicalJsonValue, UserId, -}; +use ruma::api::client::error::ErrorKind; use serde::Deserialize; -use tracing::{debug, trace, warn}; -use super::{auth, auth::Auth, Ruma}; use crate::{services, Error, Result}; #[derive(Deserialize)] @@ -29,100 +15,17 @@ pub(super) struct QueryParams { } pub(super) struct Request { - pub(super) auth: Option>>, pub(super) path: Path>, pub(super) query: QueryParams, - pub(super) json: Option, pub(super) body: Bytes, pub(super) parts: Parts, } -#[async_trait] -impl FromRequest for Ruma -where - T: IncomingRequest, -{ - type Rejection = Error; - - async fn from_request(request: hyper::Request, _state: &S) -> Result { - let meta = T::METADATA; - let mut request: Request = extract(request).await?; - let auth: Auth = auth::auth(&mut request, &meta).await?; - let body = make_body::(&mut request, &auth)?; - Ok(Ruma { - body, - sender_user: auth.sender_user, - sender_device: auth.sender_device, - origin: auth.origin, - json_body: request.json, - appservice_info: auth.appservice_info, - }) - } -} - -fn make_body(request: &mut Request, auth: &Auth) -> Result -where - T: IncomingRequest, -{ - let body = if let Some(CanonicalJsonValue::Object(json_body)) = &mut request.json { - let user_id = auth.sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") - }); - - let uiaa_request = json_body - .get("auth") - .and_then(|auth| auth.as_object()) - .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) - .and_then(|session| { - services().uiaa.get_uiaa_request( - &user_id, - &auth.sender_device.clone().unwrap_or_else(|| "".into()), - session, - ) - }); - - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } - } - - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut buf, &request.json).expect("value serialization can't fail"); - buf.into_inner().freeze() - } else { - mem::take(&mut request.body) - }; - - let mut http_request = hyper::Request::builder() - .uri(request.parts.uri.clone()) - .method(request.parts.method.clone()); - *http_request.headers_mut().unwrap() = request.parts.headers.clone(); - let http_request = http_request.body(body).unwrap(); - debug!( - "{:?} {:?} {:?}", - http_request.method(), - http_request.uri(), - http_request.headers() - ); - - trace!("{:?} {:?} {:?}", http_request.method(), http_request.uri(), request.json); - let body = T::try_from_http_request(http_request, &request.path).map_err(|e| { - warn!("try_from_http_request failed: {e:?}",); - debug_warn!("JSON body: {:?}", request.json); - Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") - })?; - - Ok(body) -} - -async fn extract(request: hyper::Request) -> Result { +pub(super) async fn from(request: hyper::Request) -> Result { let limited = request.with_limited_body(); let (mut parts, body) = limited.into_parts(); - let auth = parts.extract().await?; - let path = parts.extract().await?; + let path: Path> = parts.extract().await?; let query = serde_html_form::from_str(parts.uri.query().unwrap_or_default()) .map_err(|_| Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"))?; @@ -137,13 +40,9 @@ async fn extract(request: hyper::Request) -> Result { .await .map_err(|_| Error::BadRequest(ErrorKind::TooLarge, "Request body too large"))?; - let json = serde_json::from_slice::(&body).ok(); - Ok(Request { - auth, path, query, - json, body, parts, })