Upgrade axum to 0.6
This commit is contained in:
parent
6a6f8e80f1
commit
0ded637b4a
4 changed files with 130 additions and 72 deletions
54
Cargo.lock
generated
54
Cargo.lock
generated
|
@ -89,9 +89,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum"
|
name = "axum"
|
||||||
version = "0.5.17"
|
version = "0.6.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
|
checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
@ -108,22 +108,22 @@ dependencies = [
|
||||||
"mime",
|
"mime",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"rustversion",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http 0.3.5",
|
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "axum-core"
|
name = "axum-core"
|
||||||
version = "0.2.9"
|
version = "0.3.4"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
|
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
|
@ -131,6 +131,7 @@ dependencies = [
|
||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"mime",
|
"mime",
|
||||||
|
"rustversion",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
@ -407,7 +408,7 @@ dependencies = [
|
||||||
"tikv-jemallocator",
|
"tikv-jemallocator",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http 0.4.1",
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-flame",
|
"tracing-flame",
|
||||||
"tracing-opentelemetry",
|
"tracing-opentelemetry",
|
||||||
|
@ -1449,9 +1450,9 @@ checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matchit"
|
name = "matchit"
|
||||||
version = "0.5.0"
|
version = "0.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
|
checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
|
@ -2363,6 +2364,12 @@ dependencies = [
|
||||||
"untrusted",
|
"untrusted",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rustversion"
|
||||||
|
version = "1.0.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ryu"
|
name = "ryu"
|
||||||
version = "1.0.13"
|
version = "1.0.13"
|
||||||
|
@ -2467,6 +2474,15 @@ dependencies = [
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.11"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_spanned"
|
name = "serde_spanned"
|
||||||
version = "0.6.3"
|
version = "0.6.3"
|
||||||
|
@ -2954,31 +2970,11 @@ dependencies = [
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"pin-project",
|
"pin-project",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tokio",
|
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "tower-http"
|
|
||||||
version = "0.3.5"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858"
|
|
||||||
dependencies = [
|
|
||||||
"bitflags 1.3.2",
|
|
||||||
"bytes",
|
|
||||||
"futures-core",
|
|
||||||
"futures-util",
|
|
||||||
"http",
|
|
||||||
"http-body",
|
|
||||||
"http-range-header",
|
|
||||||
"pin-project-lite",
|
|
||||||
"tower",
|
|
||||||
"tower-layer",
|
|
||||||
"tower-service",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-http"
|
name = "tower-http"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
|
|
|
@ -19,7 +19,7 @@ rust-version = "1.70.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
# Web framework
|
# Web framework
|
||||||
axum = { version = "0.5.16", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true }
|
axum = { version = "0.6.18", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true }
|
||||||
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
|
||||||
tower = { version = "0.4.13", features = ["util"] }
|
tower = { version = "0.4.13", features = ["util"] }
|
||||||
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
|
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
|
||||||
|
|
|
@ -3,18 +3,16 @@ use std::{collections::BTreeMap, iter::FromIterator, str};
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
body::{Full, HttpBody},
|
body::{Full, HttpBody},
|
||||||
extract::{
|
extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
|
||||||
rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader,
|
|
||||||
},
|
|
||||||
headers::{
|
headers::{
|
||||||
authorization::{Bearer, Credentials},
|
authorization::{Bearer, Credentials},
|
||||||
Authorization,
|
Authorization,
|
||||||
},
|
},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
BoxError,
|
BoxError, RequestExt, RequestPartsExt,
|
||||||
};
|
};
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||||
use http::StatusCode;
|
use http::{Request, StatusCode};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
|
||||||
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
|
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
|
||||||
|
@ -26,27 +24,44 @@ use super::{Ruma, RumaResponse};
|
||||||
use crate::{services, Error, Result};
|
use crate::{services, Error, Result};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl<T, B> FromRequest<B> for Ruma<T>
|
impl<T, S, B> FromRequest<S, B> for Ruma<T>
|
||||||
where
|
where
|
||||||
T: IncomingRequest,
|
T: IncomingRequest,
|
||||||
B: HttpBody + Send,
|
B: HttpBody + Send + 'static,
|
||||||
B::Data: Send,
|
B::Data: Send,
|
||||||
B::Error: Into<BoxError>,
|
B::Error: Into<BoxError>,
|
||||||
{
|
{
|
||||||
type Rejection = Error;
|
type Rejection = Error;
|
||||||
|
|
||||||
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct QueryParams {
|
struct QueryParams {
|
||||||
access_token: Option<String>,
|
access_token: Option<String>,
|
||||||
user_id: Option<String>,
|
user_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
let metadata = T::METADATA;
|
let (mut parts, mut body) = match req.with_limited_body() {
|
||||||
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
|
Ok(limited_req) => {
|
||||||
let path_params = Path::<Vec<String>>::from_request(req).await?;
|
let (parts, body) = limited_req.into_parts();
|
||||||
|
let body = to_bytes(body)
|
||||||
|
.await
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||||
|
(parts, body)
|
||||||
|
}
|
||||||
|
Err(original_req) => {
|
||||||
|
let (parts, body) = original_req.into_parts();
|
||||||
|
let body = to_bytes(body)
|
||||||
|
.await
|
||||||
|
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||||
|
(parts, body)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let query = req.uri().query().unwrap_or_default();
|
let metadata = T::METADATA;
|
||||||
|
let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
|
||||||
|
let path_params: Path<Vec<String>> = parts.extract().await?;
|
||||||
|
|
||||||
|
let query = parts.uri.query().unwrap_or_default();
|
||||||
let query_params: QueryParams = match serde_html_form::from_str(query) {
|
let query_params: QueryParams = match serde_html_form::from_str(query) {
|
||||||
Ok(params) => params,
|
Ok(params) => params,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -63,10 +78,6 @@ where
|
||||||
None => query_params.access_token.as_deref(),
|
None => query_params.access_token.as_deref(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut body = Bytes::from_request(req)
|
|
||||||
.await
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
|
||||||
|
|
||||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||||
|
|
||||||
let appservices = services().appservice.all().unwrap();
|
let appservices = services().appservice.all().unwrap();
|
||||||
|
@ -138,24 +149,24 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
AuthScheme::ServerSignatures => {
|
AuthScheme::ServerSignatures => {
|
||||||
let TypedHeader(Authorization(x_matrix)) =
|
let TypedHeader(Authorization(x_matrix)) = parts
|
||||||
TypedHeader::<Authorization<XMatrix>>::from_request(req)
|
.extract::<TypedHeader<Authorization<XMatrix>>>()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!("Missing or invalid Authorization header: {}", e);
|
warn!("Missing or invalid Authorization header: {}", e);
|
||||||
|
|
||||||
let msg = match e.reason() {
|
let msg = match e.reason() {
|
||||||
TypedHeaderRejectionReason::Missing => {
|
TypedHeaderRejectionReason::Missing => {
|
||||||
"Missing Authorization header."
|
"Missing Authorization header."
|
||||||
}
|
}
|
||||||
TypedHeaderRejectionReason::Error(_) => {
|
TypedHeaderRejectionReason::Error(_) => {
|
||||||
"Invalid X-Matrix signatures."
|
"Invalid X-Matrix signatures."
|
||||||
}
|
}
|
||||||
_ => "Unknown header-related error",
|
_ => "Unknown header-related error",
|
||||||
};
|
};
|
||||||
|
|
||||||
Error::BadRequest(ErrorKind::Forbidden, msg)
|
Error::BadRequest(ErrorKind::Forbidden, msg)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let origin_signatures = BTreeMap::from_iter([(
|
let origin_signatures = BTreeMap::from_iter([(
|
||||||
x_matrix.key.clone(),
|
x_matrix.key.clone(),
|
||||||
|
@ -170,11 +181,11 @@ where
|
||||||
let mut request_map = BTreeMap::from_iter([
|
let mut request_map = BTreeMap::from_iter([
|
||||||
(
|
(
|
||||||
"method".to_owned(),
|
"method".to_owned(),
|
||||||
CanonicalJsonValue::String(req.method().to_string()),
|
CanonicalJsonValue::String(parts.method.to_string()),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"uri".to_owned(),
|
"uri".to_owned(),
|
||||||
CanonicalJsonValue::String(req.uri().to_string()),
|
CanonicalJsonValue::String(parts.uri.to_string()),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"origin".to_owned(),
|
"origin".to_owned(),
|
||||||
|
@ -224,7 +235,7 @@ where
|
||||||
x_matrix.origin, e, request_map
|
x_matrix.origin, e, request_map
|
||||||
);
|
);
|
||||||
|
|
||||||
if req.uri().to_string().contains('@') {
|
if parts.uri.to_string().contains('@') {
|
||||||
warn!(
|
warn!(
|
||||||
"Request uri contained '@' character. Make sure your \
|
"Request uri contained '@' character. Make sure your \
|
||||||
reverse proxy gives Conduit the raw uri (apache: use \
|
reverse proxy gives Conduit the raw uri (apache: use \
|
||||||
|
@ -243,8 +254,8 @@ where
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut http_request = http::Request::builder().uri(req.uri()).method(req.method());
|
let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
|
||||||
*http_request.headers_mut().unwrap() = req.headers().clone();
|
*http_request.headers_mut().unwrap() = parts.headers;
|
||||||
|
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||||
|
@ -364,3 +375,55 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copied from hyper under the following license:
|
||||||
|
// Copyright (c) 2014-2021 Sean McArthur
|
||||||
|
|
||||||
|
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
// of this software and associated documentation files (the "Software"), to deal
|
||||||
|
// in the Software without restriction, including without limitation the rights
|
||||||
|
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
// copies of the Software, and to permit persons to whom the Software is
|
||||||
|
// furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
// The above copyright notice and this permission notice shall be included in
|
||||||
|
// all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
// THE SOFTWARE.
|
||||||
|
pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
|
||||||
|
where
|
||||||
|
T: HttpBody,
|
||||||
|
{
|
||||||
|
futures_util::pin_mut!(body);
|
||||||
|
|
||||||
|
// If there's only 1 chunk, we can just return Buf::to_bytes()
|
||||||
|
let mut first = if let Some(buf) = body.data().await {
|
||||||
|
buf?
|
||||||
|
} else {
|
||||||
|
return Ok(Bytes::new());
|
||||||
|
};
|
||||||
|
|
||||||
|
let second = if let Some(buf) = body.data().await {
|
||||||
|
buf?
|
||||||
|
} else {
|
||||||
|
return Ok(first.copy_to_bytes(first.remaining()));
|
||||||
|
};
|
||||||
|
|
||||||
|
// With more than 1 buf, we gotta flatten into a Vec first.
|
||||||
|
let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
|
||||||
|
let mut vec = Vec::with_capacity(cap);
|
||||||
|
vec.put(first);
|
||||||
|
vec.put(second);
|
||||||
|
|
||||||
|
while let Some(buf) = body.data().await {
|
||||||
|
vec.put(buf?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(vec.into())
|
||||||
|
}
|
||||||
|
|
|
@ -10,8 +10,7 @@
|
||||||
use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration};
|
use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{DefaultBodyLimit, FromRequest, MatchedPath},
|
extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
|
||||||
handler::Handler,
|
|
||||||
response::IntoResponse,
|
response::IntoResponse,
|
||||||
routing::{get, on, MethodFilter},
|
routing::{get, on, MethodFilter},
|
||||||
Router,
|
Router,
|
||||||
|
@ -421,7 +420,7 @@ fn routes() -> Router {
|
||||||
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
||||||
get(initial_sync),
|
get(initial_sync),
|
||||||
)
|
)
|
||||||
.fallback(not_found.into_service())
|
.fallback(not_found)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown_signal(handle: ServerHandle) {
|
async fn shutdown_signal(handle: ServerHandle) {
|
||||||
|
@ -505,7 +504,7 @@ macro_rules! impl_ruma_handler {
|
||||||
Fut: Future<Output = Result<Req::OutgoingResponse, E>>
|
Fut: Future<Output = Result<Req::OutgoingResponse, E>>
|
||||||
+ Send,
|
+ Send,
|
||||||
E: IntoResponse,
|
E: IntoResponse,
|
||||||
$( $ty: FromRequest<axum::body::Body> + Send + 'static, )*
|
$( $ty: FromRequestParts<()> + Send + 'static, )*
|
||||||
{
|
{
|
||||||
fn add_to_router(self, mut router: Router) -> Router {
|
fn add_to_router(self, mut router: Router) -> Router {
|
||||||
let meta = Req::METADATA;
|
let meta = Req::METADATA;
|
||||||
|
|
Loading…
Add table
Reference in a new issue