add support for listening on multiple ports

retains existing config compatibility using either:
`port = 6167`
`port = [80, 443, 8448]`

Co-authored-by: Charles Hall <charles@computer.surgery>
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-02-24 15:34:38 -05:00 committed by June
parent 99f7dad939
commit 67b307c75b
2 changed files with 52 additions and 12 deletions

View file

@ -5,6 +5,8 @@ use std::{
path::PathBuf, path::PathBuf,
}; };
use either::Either;
use figment::Figment; use figment::Figment;
use itertools::Itertools; use itertools::Itertools;
@ -17,15 +19,22 @@ mod proxy;
use self::proxy::ProxyConfig; use self::proxy::ProxyConfig;
#[derive(Deserialize, Clone, Debug)]
#[serde(transparent)]
pub struct ListeningPort {
#[serde(with = "either::serde_untagged")]
pub ports: Either<u16, Vec<u16>>,
}
/// all the config options for conduwuit /// all the config options for conduwuit
#[derive(Clone, Debug, Deserialize)] #[derive(Clone, Debug, Deserialize)]
pub struct Config { pub struct Config {
/// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6)
#[serde(default = "default_address")] #[serde(default = "default_address")]
pub address: IpAddr, pub address: IpAddr,
/// default TCP port conduwuit will listen on /// default TCP port(s) conduwuit will listen on
#[serde(default = "default_port")] #[serde(default = "default_port")]
pub port: u16, pub port: ListeningPort,
pub tls: Option<TlsConfig>, pub tls: Option<TlsConfig>,
pub unix_socket_path: Option<PathBuf>, pub unix_socket_path: Option<PathBuf>,
#[serde(default = "default_unix_socket_perms")] #[serde(default = "default_unix_socket_perms")]
@ -400,8 +409,10 @@ fn default_address() -> IpAddr {
Ipv4Addr::LOCALHOST.into() Ipv4Addr::LOCALHOST.into()
} }
fn default_port() -> u16 { fn default_port() -> ListeningPort {
8000 ListeningPort {
ports: Either::Left(8008),
}
} }
fn default_unix_socket_perms() -> u32 { fn default_unix_socket_perms() -> u32 {

View file

@ -11,6 +11,7 @@ use axum::{
}; };
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use conduit::api::{client_server, server_server}; use conduit::api::{client_server, server_server};
use either::Either::{Left, Right};
use figment::{ use figment::{
providers::{Env, Format, Toml}, providers::{Env, Format, Toml},
Figment, Figment,
@ -28,7 +29,7 @@ use ruma::api::{
}, },
IncomingRequest, IncomingRequest,
}; };
use tokio::{net::UnixListener, signal, sync::oneshot}; use tokio::{net::UnixListener, signal, sync::oneshot, task::JoinSet};
use tower::ServiceBuilder; use tower::ServiceBuilder;
use tower_http::{ use tower_http::{
cors::{self, CorsLayer}, cors::{self, CorsLayer},
@ -244,7 +245,24 @@ async fn main() {
async fn run_server() -> io::Result<()> { async fn run_server() -> io::Result<()> {
let config = &services().globals.config; let config = &services().globals.config;
let addr = SocketAddr::from((config.address, config.port));
let addrs = match &config.port.ports {
Left(port) => {
// Left is only 1 value, so make a vec with 1 value only
let port_vec = [port];
port_vec
.iter()
.copied()
.map(|port| SocketAddr::from((config.address, *port)))
.collect::<Vec<_>>()
}
Right(ports) => ports
.iter()
.copied()
.map(|port| SocketAddr::from((config.address, port)))
.collect::<Vec<_>>(),
};
let x_requested_with = HeaderName::from_static("x-requested-with"); let x_requested_with = HeaderName::from_static("x-requested-with");
@ -342,22 +360,33 @@ async fn run_server() -> io::Result<()> {
match &config.tls { match &config.tls {
Some(tls) => { Some(tls) => {
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
let mut join_set = JoinSet::new();
for addr in &addrs {
join_set.spawn(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
);
}
#[cfg(feature = "systemd")] #[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
info!("Listening on {}", addr); info!("Listening on {:?}", addrs);
server.await? join_set.join_next().await;
} }
None => { None => {
let server = bind(addr).handle(handle).serve(app); let mut join_set = JoinSet::new();
for addr in &addrs {
join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone()));
}
#[cfg(feature = "systemd")] #[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]);
info!("Listening on {}", addr); info!("Listening on {:?}", addrs);
server.await? join_set.join_next().await;
} }
} }
} }