diff --git a/src/core/server.rs b/src/core/server.rs index 95eb7e01..ca62b64c 100644 --- a/src/core/server.rs +++ b/src/core/server.rs @@ -5,7 +5,7 @@ use std::{ use tokio::{runtime, sync::broadcast}; -use crate::{config::Config, log}; +use crate::{config::Config, log, Error, Result}; /// Server runtime state; public portion pub struct Server { @@ -59,6 +59,38 @@ impl Server { } } + pub fn reload(&self) -> Result<()> { + if cfg!(not(conduit_mods)) { + return Err(Error::Err("Reloading not enabled".into())); + } + + if self.reloading.swap(true, Ordering::AcqRel) { + return Err(Error::Err("Reloading already in progress".into())); + } + + if self.stopping.swap(true, Ordering::AcqRel) { + return Err(Error::Err("Shutdown already in progress".into())); + } + + self.signal("SIGINT") + } + + pub fn shutdown(&self) -> Result<()> { + if self.stopping.swap(true, Ordering::AcqRel) { + return Err(Error::Err("Shutdown already in progress".into())); + } + + self.signal("SIGTERM") + } + + pub fn signal(&self, sig: &'static str) -> Result<()> { + if let Err(e) = self.signal.send(sig) { + return Err(Error::Err(format!("Failed to send signal: {e}"))); + } + + Ok(()) + } + #[inline] pub fn runtime(&self) -> &runtime::Handle { self.runtime diff --git a/src/main/main.rs b/src/main/main.rs index 6e1bfe38..566b0d67 100644 --- a/src/main/main.rs +++ b/src/main/main.rs @@ -4,11 +4,7 @@ mod server; extern crate conduit_core as conduit; -use std::{ - cmp, - sync::{atomic::Ordering, Arc}, - time::Duration, -}; +use std::{cmp, sync::Arc, time::Duration}; use conduit::{debug_error, debug_info, error, trace, utils::available_parallelism, warn, Error, Result}; use server::Server; @@ -107,7 +103,7 @@ async fn signal(server: Arc) { use unix::SignalKind; const CONSOLE: bool = cfg!(feature = "console"); - const RELOADING: bool = cfg!(all(unix, debug_assertions, not(CONSOLE))); + const RELOADING: bool = cfg!(all(conduit_mods, not(CONSOLE))); let mut quit = unix::signal(SignalKind::quit()).expect("SIGQUIT handler"); let mut term = unix::signal(SignalKind::terminate()).expect("SIGTERM handler"); @@ -120,19 +116,17 @@ async fn signal(server: Arc) { _ = term.recv() => { sig = "SIGTERM"; }, } - // Indicate the SIGINT is requesting a hot-reload. - if RELOADING && sig == "SIGINT" { - server.server.reloading.store(true, Ordering::Release); - } - - // Indicate the signal is requesting a shutdown - if matches!(sig, "SIGQUIT" | "SIGTERM") || (!CONSOLE && sig == "SIGINT") { - server.server.stopping.store(true, Ordering::Release); - } - warn!("Received {sig}"); - if let Err(e) = server.server.signal.send(sig) { - debug_error!("signal channel: {e}"); + let result = if RELOADING && sig == "SIGINT" { + server.server.reload() + } else if matches!(sig, "SIGQUIT" | "SIGTERM") || (!CONSOLE && sig == "SIGINT") { + server.server.shutdown() + } else { + server.server.signal(sig) + }; + + if let Err(e) = result { + debug_error!(?sig, "signal: {e}"); } } } diff --git a/src/main/mods.rs b/src/main/mods.rs index 125cd54a..f39527f7 100644 --- a/src/main/mods.rs +++ b/src/main/mods.rs @@ -39,6 +39,7 @@ pub(crate) async fn run(server: &Arc, starts: bool) -> Result<(bool, boo return Err(error); } } + server.server.stopping.store(false, Ordering::Release); let run = main_mod.get::("run")?; if let Err(error) = run(&server.server).await { error!("Running server: {error}"); diff --git a/src/router/request.rs b/src/router/request.rs index 0693e671..2bdd06f8 100644 --- a/src/router/request.rs +++ b/src/router/request.rs @@ -13,7 +13,7 @@ use tracing::{debug, error, trace}; pub(crate) async fn spawn( State(server): State>, req: http::Request, next: axum::middleware::Next, ) -> Result { - if server.stopping.load(Ordering::Relaxed) { + if !server.running() { debug_warn!("unavailable pending shutdown"); return Err(StatusCode::SERVICE_UNAVAILABLE); } @@ -35,7 +35,7 @@ pub(crate) async fn spawn( pub(crate) async fn handle( State(server): State>, req: http::Request, next: axum::middleware::Next, ) -> Result { - if server.stopping.load(Ordering::Relaxed) { + if !server.running() { debug_warn!( method = %req.method(), uri = %req.uri(), diff --git a/src/router/run.rs b/src/router/run.rs index e9876ef3..fb59c797 100644 --- a/src/router/run.rs +++ b/src/router/run.rs @@ -25,7 +25,6 @@ pub(crate) async fn run(server: Arc) -> Result<(), Error> { _ = services().admin.handle.lock().await.insert(admin::handle); // Setup shutdown/signal handling - server.stopping.store(false, Ordering::Release); let handle = ServerHandle::new(); let (tx, _) = broadcast::channel::<()>(1); let sigs = server diff --git a/src/service/services.rs b/src/service/services.rs index d9805010..e1c27d52 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -1,6 +1,6 @@ use std::{ collections::{BTreeMap, HashMap}, - sync::{atomic, Arc, Mutex as StdMutex}, + sync::{Arc, Mutex as StdMutex}, }; use conduit::{debug_info, Result, Server}; @@ -301,8 +301,6 @@ bad_signature_ratelimiter: {bad_signature_ratelimiter} pub async fn interrupt(&self) { trace!("Interrupting services..."); - self.server.stopping.store(true, atomic::Ordering::Release); - self.sending.interrupt(); self.presence.interrupt(); self.admin.interrupt(); @@ -310,7 +308,6 @@ bad_signature_ratelimiter: {bad_signature_ratelimiter} trace!("Services interrupt complete."); } - #[tracing::instrument(skip_all)] pub async fn stop(&self) { info!("Shutting down services"); self.interrupt().await;