Rewrite query parameter parsing

This commit is contained in:
Jonas Platte 2022-02-09 12:32:18 +01:00
parent c8951a1d9c
commit 21ae63d46b
No known key found for this signature in database
GPG key ID: 7D261D771D915378

View file

@ -18,7 +18,8 @@ use ruma::{
signatures::CanonicalJsonValue, signatures::CanonicalJsonValue,
DeviceId, Outgoing, ServerName, UserId, DeviceId, Outgoing, ServerName, UserId,
}; };
use tracing::{debug, warn}; use serde::Deserialize;
use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse}; use super::{Ruma, RumaResponse};
use crate::{database::DatabaseGuard, server_server, Error, Result}; use crate::{database::DatabaseGuard, server_server, Error, Result};
@ -35,18 +36,31 @@ where
type Rejection = Error; type Rejection = Error;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
#[derive(Deserialize)]
struct QueryParams {
access_token: Option<String>,
user_id: Option<String>,
}
let metadata = T::Incoming::METADATA; let metadata = T::Incoming::METADATA;
let db = DatabaseGuard::from_request(req).await?; let db = DatabaseGuard::from_request(req).await?;
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
// FIXME: Do this more efficiently let query = req.uri().query().unwrap_or_default();
let query: BTreeMap<String, String> = let query_params: QueryParams = match ruma::serde::urlencoded::from_str(query) {
ruma::serde::urlencoded::from_str(req.uri().query().unwrap_or_default()) Ok(params) => params,
.expect("Query to string map deserialization should be fine"); Err(e) => {
error!(%query, "Failed to deserialize query parameters: {}", e);
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Failed to read query parameters",
));
}
};
let token = match &auth_header { let token = match &auth_header {
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
None => query.get("access_token").map(|tok| tok.as_str()), None => query_params.access_token.as_deref(),
}; };
let mut body = Bytes::from_request(req) let mut body = Bytes::from_request(req)
@ -67,7 +81,7 @@ where
if let Some((_id, registration)) = appservice_registration { if let Some((_id, registration)) = appservice_registration {
match metadata.authentication { match metadata.authentication {
AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => { AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => {
let user_id = query.get("user_id").map_or_else( let user_id = query_params.user_id.map_or_else(
|| { || {
UserId::parse_with_server_name( UserId::parse_with_server_name(
registration registration
@ -79,7 +93,7 @@ where
) )
.unwrap() .unwrap()
}, },
|s| UserId::parse(s.as_str()).unwrap(), |s| UserId::parse(s).unwrap(),
); );
if !db.users.exists(&user_id).unwrap() { if !db.users.exists(&user_id).unwrap() {