From 4dfd5a7c15c009ca7857ad3c9566b841873ef311 Mon Sep 17 00:00:00 2001 From: strawberry Date: Fri, 8 Mar 2024 00:33:24 -0500 Subject: [PATCH] add AuthScheme AccessTokenOptional in ruma_wrapper Signed-off-by: strawberry --- src/api/ruma_wrapper/axum.rs | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 36a7ba87..07a35fe5 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -83,7 +83,7 @@ where appservice_registration { match metadata.authentication { - AuthScheme::AccessToken => { + AuthScheme::AccessToken | AuthScheme::AccessTokenOptional => { let user_id = query_params.user_id.map_or_else( || { UserId::parse_with_server_name( @@ -95,7 +95,7 @@ where |s| UserId::parse(s).unwrap(), ); - if !services().users.exists(&user_id).unwrap() { + if !services().users.exists(&user_id)? { return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist.")); } @@ -113,7 +113,7 @@ where _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), }; - match services().users.find_from_token(token).unwrap() { + match services().users.find_from_token(token)? { None => { return Err(Error::BadRequest( ErrorKind::UnknownToken { @@ -127,6 +127,27 @@ where }, } }, + AuthScheme::AccessTokenOptional => { + let token = token.unwrap_or(""); + + if token.is_empty() { + (None, None, None, false) + } else { + match services().users.find_from_token(token)? { + None => { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )) + }, + Some((user_id, device_id)) => { + (Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false) + }, + } + } + }, AuthScheme::ServerSignatures => { let TypedHeader(Authorization(x_matrix)) = parts.extract::>>().await.map_err(|e| { @@ -219,7 +240,7 @@ where _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), }; - match services().users.find_from_token(token).unwrap() { + match services().users.find_from_token(token)? { None => { return Err(Error::BadRequest( ErrorKind::UnknownToken {