optimize api state extractor

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-30 01:25:07 +00:00
parent ccef1a4c8b
commit 1f88866612
6 changed files with 81 additions and 23 deletions

View file

@ -10,8 +10,7 @@ extern crate conduit_service as service;
pub(crate) use conduit::{debug_info, pdu::PduEvent, utils, Error, Result};
pub(crate) use service::services;
pub use self::router::State;
pub(crate) use self::router::{Ruma, RumaResponse};
pub(crate) use self::router::{Ruma, RumaResponse, State};
conduit::mod_ctor! {}
conduit::mod_dtor! {}

View file

@ -3,7 +3,7 @@ mod auth;
mod handler;
mod request;
mod response;
mod state;
pub mod state;
use axum::{
response::IntoResponse,
@ -14,8 +14,7 @@ use conduit::{err, Server};
use http::Uri;
use self::handler::RouterExt;
pub use self::state::State;
pub(super) use self::{args::Args as Ruma, response::RumaResponse};
pub(super) use self::{args::Args as Ruma, response::RumaResponse, state::State};
use crate::{client, server};
pub fn build(router: Router<State>, server: &Server) -> Router<State> {

View file

@ -2,21 +2,78 @@ use std::{ops::Deref, sync::Arc};
use conduit_service::Services;
#[derive(Clone)]
#[derive(Clone, Copy)]
pub struct State {
services: *const Services,
}
pub struct Guard {
services: Arc<Services>,
}
impl State {
pub fn new(services: Arc<Services>) -> Self {
Self {
services,
}
pub fn create(services: Arc<Services>) -> (State, Guard) {
let state = State {
services: Arc::into_raw(services.clone()),
};
let guard = Guard {
services,
};
(state, guard)
}
impl Drop for Guard {
fn drop(&mut self) {
let ptr = Arc::as_ptr(&self.services);
// SAFETY: Parity with Arc::into_raw() called in create(). This revivifies the
// Arc lost to State so it can be dropped, otherwise Services will leak.
let arc = unsafe { Arc::from_raw(ptr) };
debug_assert!(
Arc::strong_count(&arc) > 1,
"Services usually has more than one reference and is not dropped here"
);
}
}
impl Deref for State {
type Target = Arc<Services>;
type Target = Services;
fn deref(&self) -> &Self::Target { &self.services }
fn deref(&self) -> &Self::Target {
deref(&self.services).expect("dereferenced Services pointer in State must not be null")
}
}
/// SAFETY: State is a thin wrapper containing a raw const pointer to Services
/// in lieu of an Arc. Services is internally threadsafe. If State contains
/// additional fields this notice should be reevaluated.
unsafe impl Send for State {}
/// SAFETY: State is a thin wrapper containing a raw const pointer to Services
/// in lieu of an Arc. Services is internally threadsafe. If State contains
/// additional fields this notice should be reevaluated.
unsafe impl Sync for State {}
fn deref(services: &*const Services) -> Option<&Services> {
// SAFETY: We replaced Arc<Services> with *const Services in State. This is
// worth about 10 clones (20 reference count updates) for each request handled.
// Though this is not an incredibly large quantity, it's woefully unnecessary
// given the context as explained below; though it is not currently known to be
// a performance bottleneck, the front-line position justifies preempting it.
//
// Services is created prior to the axum/tower stack and Router, and prior
// to serving any requests through the handlers in this crate. It is then
// dropped only after all requests have completed, the listening sockets
// have been closed, axum/tower has been dropped. Thus Services is
// expected to live at least as long as any instance of State, making the
// constant updates to the prior Arc unnecessary to keep Services alive.
//
// Nevertheless if it is possible to accomplish this by annotating State
// with a lifetime to hold a reference (and be aware I have made a
// significant effort trying to make this work) this unsafety may not be
// necessary. It is either very difficult or impossible to get a
// lifetime'ed reference through Router / RumaHandler; though it is
// possible to pass a reference through axum's `with_state()` in trivial
// configurations as the only requirement of a State is Clone.
unsafe { services.as_ref() }
}

View file

@ -6,6 +6,7 @@ use axum::{
};
use axum_client_ip::SecureClientIpSource;
use conduit::{error, Result, Server};
use conduit_api::router::state::Guard;
use conduit_service::Services;
use http::{
header::{self, HeaderName},
@ -35,7 +36,7 @@ const CONDUWUIT_CSP: &[&str] = &[
const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"];
pub(crate) fn build(services: &Arc<Services>) -> Result<Router> {
pub(crate) fn build(services: &Arc<Services>) -> Result<(Router, Guard)> {
let server = &services.server;
let layers = ServiceBuilder::new();
@ -85,7 +86,8 @@ pub(crate) fn build(services: &Arc<Services>) -> Result<Router> {
.layer(body_limit_layer(server))
.layer(CatchPanicLayer::custom(catch_panic));
Ok(router::build(services).layer(layers))
let (router, guard) = router::build(services);
Ok((router.layer(layers), guard))
}
#[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))]

View file

@ -2,19 +2,20 @@ use std::sync::Arc;
use axum::{response::IntoResponse, routing::get, Router};
use conduit::Error;
use conduit_api::State;
use conduit_api::router::{state, state::Guard};
use conduit_service::Services;
use http::{StatusCode, Uri};
use ruma::api::client::error::ErrorKind;
pub(crate) fn build(services: &Arc<Services>) -> Router {
let router = Router::<State>::new();
let state = State::new(services.clone());
conduit_api::router::build(router, &services.server)
pub(crate) fn build(services: &Arc<Services>) -> (Router, Guard) {
let router = Router::<state::State>::new();
let (state, guard) = state::create(services.clone());
let router = conduit_api::router::build(router, &services.server)
.route("/", get(it_works))
.fallback(not_found)
.with_state(state)
.with_state(state);
(router, guard)
}
async fn not_found(_uri: Uri) -> impl IntoResponse {

View file

@ -18,7 +18,7 @@ pub(super) async fn serve(
let server = &services.server;
let config = &server.config;
let addrs = config.get_bind_addrs();
let app = layers::build(&services)?;
let (app, _guard) = layers::build(&services)?;
if cfg!(unix) && config.unix_socket_path.is_some() {
unix::serve(server, app, shutdown).await