diff --git a/src/config/mod.rs b/src/config/mod.rs index f5e95c42..b5e9002b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -5,6 +5,8 @@ use std::{ path::PathBuf, }; +use either::Either; + use figment::Figment; use itertools::Itertools; @@ -17,15 +19,22 @@ mod proxy; use self::proxy::ProxyConfig; +#[derive(Deserialize, Clone, Debug)] +#[serde(transparent)] +pub struct ListeningPort { + #[serde(with = "either::serde_untagged")] + pub ports: Either>, +} + /// all the config options for conduwuit #[derive(Clone, Debug, Deserialize)] pub struct Config { /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) #[serde(default = "default_address")] pub address: IpAddr, - /// default TCP port conduwuit will listen on + /// default TCP port(s) conduwuit will listen on #[serde(default = "default_port")] - pub port: u16, + pub port: ListeningPort, pub tls: Option, pub unix_socket_path: Option, #[serde(default = "default_unix_socket_perms")] @@ -400,8 +409,10 @@ fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } -fn default_port() -> u16 { - 8000 +fn default_port() -> ListeningPort { + ListeningPort { + ports: Either::Left(8008), + } } fn default_unix_socket_perms() -> u32 { diff --git a/src/main.rs b/src/main.rs index fd8e4838..2026652a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use axum::{ }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use conduit::api::{client_server, server_server}; +use either::Either::{Left, Right}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -28,7 +29,7 @@ use ruma::api::{ }, IncomingRequest, }; -use tokio::{net::UnixListener, signal, sync::oneshot}; +use tokio::{net::UnixListener, signal, sync::oneshot, task::JoinSet}; use tower::ServiceBuilder; use tower_http::{ cors::{self, CorsLayer}, @@ -244,7 +245,24 @@ async fn main() { async fn run_server() -> io::Result<()> { let config = &services().globals.config; - let addr = SocketAddr::from((config.address, config.port)); + + let addrs = match &config.port.ports { + Left(port) => { + // Left is only 1 value, so make a vec with 1 value only + let port_vec = [port]; + + port_vec + .iter() + .copied() + .map(|port| SocketAddr::from((config.address, *port))) + .collect::>() + } + Right(ports) => ports + .iter() + .copied() + .map(|port| SocketAddr::from((config.address, port))) + .collect::>(), + }; let x_requested_with = HeaderName::from_static("x-requested-with"); @@ -342,22 +360,33 @@ async fn run_server() -> io::Result<()> { match &config.tls { Some(tls) => { let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - let server = bind_rustls(addr, conf).handle(handle).serve(app); + + let mut join_set = JoinSet::new(); + for addr in &addrs { + join_set.spawn( + bind_rustls(*addr, conf.clone()) + .handle(handle.clone()) + .serve(app.clone()), + ); + } #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - info!("Listening on {}", addr); - server.await? + info!("Listening on {:?}", addrs); + join_set.join_next().await; } None => { - let server = bind(addr).handle(handle).serve(app); + let mut join_set = JoinSet::new(); + for addr in &addrs { + join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); + } #[cfg(feature = "systemd")] let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - info!("Listening on {}", addr); - server.await? + info!("Listening on {:?}", addrs); + join_set.join_next().await; } } }