Rewrite query parameter parsing
This commit is contained in:
parent
c8951a1d9c
commit
21ae63d46b
1 changed files with 22 additions and 8 deletions
|
@ -18,7 +18,8 @@ use ruma::{
|
||||||
signatures::CanonicalJsonValue,
|
signatures::CanonicalJsonValue,
|
||||||
DeviceId, Outgoing, ServerName, UserId,
|
DeviceId, Outgoing, ServerName, UserId,
|
||||||
};
|
};
|
||||||
use tracing::{debug, warn};
|
use serde::Deserialize;
|
||||||
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
use super::{Ruma, RumaResponse};
|
use super::{Ruma, RumaResponse};
|
||||||
use crate::{database::DatabaseGuard, server_server, Error, Result};
|
use crate::{database::DatabaseGuard, server_server, Error, Result};
|
||||||
|
@ -35,18 +36,31 @@ where
|
||||||
type Rejection = Error;
|
type Rejection = Error;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct QueryParams {
|
||||||
|
access_token: Option<String>,
|
||||||
|
user_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
let metadata = T::Incoming::METADATA;
|
let metadata = T::Incoming::METADATA;
|
||||||
let db = DatabaseGuard::from_request(req).await?;
|
let db = DatabaseGuard::from_request(req).await?;
|
||||||
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
|
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
|
||||||
|
|
||||||
// FIXME: Do this more efficiently
|
let query = req.uri().query().unwrap_or_default();
|
||||||
let query: BTreeMap<String, String> =
|
let query_params: QueryParams = match ruma::serde::urlencoded::from_str(query) {
|
||||||
ruma::serde::urlencoded::from_str(req.uri().query().unwrap_or_default())
|
Ok(params) => params,
|
||||||
.expect("Query to string map deserialization should be fine");
|
Err(e) => {
|
||||||
|
error!(%query, "Failed to deserialize query parameters: {}", e);
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
ErrorKind::Unknown,
|
||||||
|
"Failed to read query parameters",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let token = match &auth_header {
|
let token = match &auth_header {
|
||||||
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
|
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()),
|
||||||
None => query.get("access_token").map(|tok| tok.as_str()),
|
None => query_params.access_token.as_deref(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut body = Bytes::from_request(req)
|
let mut body = Bytes::from_request(req)
|
||||||
|
@ -67,7 +81,7 @@ where
|
||||||
if let Some((_id, registration)) = appservice_registration {
|
if let Some((_id, registration)) = appservice_registration {
|
||||||
match metadata.authentication {
|
match metadata.authentication {
|
||||||
AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => {
|
AuthScheme::AccessToken | AuthScheme::QueryOnlyAccessToken => {
|
||||||
let user_id = query.get("user_id").map_or_else(
|
let user_id = query_params.user_id.map_or_else(
|
||||||
|| {
|
|| {
|
||||||
UserId::parse_with_server_name(
|
UserId::parse_with_server_name(
|
||||||
registration
|
registration
|
||||||
|
@ -79,7 +93,7 @@ where
|
||||||
)
|
)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
},
|
},
|
||||||
|s| UserId::parse(s.as_str()).unwrap(),
|
|s| UserId::parse(s).unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if !db.users.exists(&user_id).unwrap() {
|
if !db.users.exists(&user_id).unwrap() {
|
||||||
|
|
Loading…
Add table
Reference in a new issue