diff --git a/src/api/mod.rs b/src/api/mod.rs index 00ae35cd..82b857db 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -10,8 +10,7 @@ extern crate conduit_service as service; pub(crate) use conduit::{debug_info, pdu::PduEvent, utils, Error, Result}; pub(crate) use service::services; -pub use self::router::State; -pub(crate) use self::router::{Ruma, RumaResponse}; +pub(crate) use self::router::{Ruma, RumaResponse, State}; conduit::mod_ctor! {} conduit::mod_dtor! {} diff --git a/src/api/router.rs b/src/api/router.rs index cff58c32..7d6df16c 100644 --- a/src/api/router.rs +++ b/src/api/router.rs @@ -3,7 +3,7 @@ mod auth; mod handler; mod request; mod response; -mod state; +pub mod state; use axum::{ response::IntoResponse, @@ -14,8 +14,7 @@ use conduit::{err, Server}; use http::Uri; use self::handler::RouterExt; -pub use self::state::State; -pub(super) use self::{args::Args as Ruma, response::RumaResponse}; +pub(super) use self::{args::Args as Ruma, response::RumaResponse, state::State}; use crate::{client, server}; pub fn build(router: Router, server: &Server) -> Router { diff --git a/src/api/router/state.rs b/src/api/router/state.rs index b1149d89..24163b3c 100644 --- a/src/api/router/state.rs +++ b/src/api/router/state.rs @@ -2,21 +2,78 @@ use std::{ops::Deref, sync::Arc}; use conduit_service::Services; -#[derive(Clone)] +#[derive(Clone, Copy)] pub struct State { + services: *const Services, +} + +pub struct Guard { services: Arc, } -impl State { - pub fn new(services: Arc) -> Self { - Self { - services, - } +pub fn create(services: Arc) -> (State, Guard) { + let state = State { + services: Arc::into_raw(services.clone()), + }; + + let guard = Guard { + services, + }; + + (state, guard) +} + +impl Drop for Guard { + fn drop(&mut self) { + let ptr = Arc::as_ptr(&self.services); + // SAFETY: Parity with Arc::into_raw() called in create(). This revivifies the + // Arc lost to State so it can be dropped, otherwise Services will leak. + let arc = unsafe { Arc::from_raw(ptr) }; + debug_assert!( + Arc::strong_count(&arc) > 1, + "Services usually has more than one reference and is not dropped here" + ); } } impl Deref for State { - type Target = Arc; + type Target = Services; - fn deref(&self) -> &Self::Target { &self.services } + fn deref(&self) -> &Self::Target { + deref(&self.services).expect("dereferenced Services pointer in State must not be null") + } +} + +/// SAFETY: State is a thin wrapper containing a raw const pointer to Services +/// in lieu of an Arc. Services is internally threadsafe. If State contains +/// additional fields this notice should be reevaluated. +unsafe impl Send for State {} + +/// SAFETY: State is a thin wrapper containing a raw const pointer to Services +/// in lieu of an Arc. Services is internally threadsafe. If State contains +/// additional fields this notice should be reevaluated. +unsafe impl Sync for State {} + +fn deref(services: &*const Services) -> Option<&Services> { + // SAFETY: We replaced Arc with *const Services in State. This is + // worth about 10 clones (20 reference count updates) for each request handled. + // Though this is not an incredibly large quantity, it's woefully unnecessary + // given the context as explained below; though it is not currently known to be + // a performance bottleneck, the front-line position justifies preempting it. + // + // Services is created prior to the axum/tower stack and Router, and prior + // to serving any requests through the handlers in this crate. It is then + // dropped only after all requests have completed, the listening sockets + // have been closed, axum/tower has been dropped. Thus Services is + // expected to live at least as long as any instance of State, making the + // constant updates to the prior Arc unnecessary to keep Services alive. + // + // Nevertheless if it is possible to accomplish this by annotating State + // with a lifetime to hold a reference (and be aware I have made a + // significant effort trying to make this work) this unsafety may not be + // necessary. It is either very difficult or impossible to get a + // lifetime'ed reference through Router / RumaHandler; though it is + // possible to pass a reference through axum's `with_state()` in trivial + // configurations as the only requirement of a State is Clone. + unsafe { services.as_ref() } } diff --git a/src/router/layers.rs b/src/router/layers.rs index 2b143666..7779260d 100644 --- a/src/router/layers.rs +++ b/src/router/layers.rs @@ -6,6 +6,7 @@ use axum::{ }; use axum_client_ip::SecureClientIpSource; use conduit::{error, Result, Server}; +use conduit_api::router::state::Guard; use conduit_service::Services; use http::{ header::{self, HeaderName}, @@ -35,7 +36,7 @@ const CONDUWUIT_CSP: &[&str] = &[ const CONDUWUIT_PERMISSIONS_POLICY: &[&str] = &["interest-cohort=()", "browsing-topics=()"]; -pub(crate) fn build(services: &Arc) -> Result { +pub(crate) fn build(services: &Arc) -> Result<(Router, Guard)> { let server = &services.server; let layers = ServiceBuilder::new(); @@ -85,7 +86,8 @@ pub(crate) fn build(services: &Arc) -> Result { .layer(body_limit_layer(server)) .layer(CatchPanicLayer::custom(catch_panic)); - Ok(router::build(services).layer(layers)) + let (router, guard) = router::build(services); + Ok((router.layer(layers), guard)) } #[cfg(any(feature = "zstd_compression", feature = "gzip_compression", feature = "brotli_compression"))] diff --git a/src/router/router.rs b/src/router/router.rs index e59b088b..31b3f3e4 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -2,19 +2,20 @@ use std::sync::Arc; use axum::{response::IntoResponse, routing::get, Router}; use conduit::Error; -use conduit_api::State; +use conduit_api::router::{state, state::Guard}; use conduit_service::Services; use http::{StatusCode, Uri}; use ruma::api::client::error::ErrorKind; -pub(crate) fn build(services: &Arc) -> Router { - let router = Router::::new(); - let state = State::new(services.clone()); - - conduit_api::router::build(router, &services.server) +pub(crate) fn build(services: &Arc) -> (Router, Guard) { + let router = Router::::new(); + let (state, guard) = state::create(services.clone()); + let router = conduit_api::router::build(router, &services.server) .route("/", get(it_works)) .fallback(not_found) - .with_state(state) + .with_state(state); + + (router, guard) } async fn not_found(_uri: Uri) -> impl IntoResponse { diff --git a/src/router/serve/mod.rs b/src/router/serve/mod.rs index 58bf2de8..9e171008 100644 --- a/src/router/serve/mod.rs +++ b/src/router/serve/mod.rs @@ -18,7 +18,7 @@ pub(super) async fn serve( let server = &services.server; let config = &server.config; let addrs = config.get_bind_addrs(); - let app = layers::build(&services)?; + let (app, _guard) = layers::build(&services)?; if cfg!(unix) && config.unix_socket_path.is_some() { unix::serve(server, app, shutdown).await