diff --git a/Cargo.lock b/Cargo.lock index 8453335a..e956b4ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -426,6 +426,7 @@ dependencies = [ "threadpool", "tikv-jemallocator", "tokio", + "toml_edit 0.22.12", "tower", "tower-http", "tracing", @@ -1818,12 +1819,11 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro-crate" -version = "2.0.2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b00f26d3400549137f92511a46ac1cd8ce37cb5598a96d382381458b992a5d24" +checksum = "7e8366a6159044a37876a2b9817124296703c586a5c92e2c53751fa06d8d43e8" dependencies = [ - "toml_datetime", - "toml_edit", + "toml_edit 0.20.2", ] [[package]] @@ -2871,14 +2871,14 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit", + "toml_edit 0.20.2", ] [[package]] name = "toml_datetime" -version = "0.6.3" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" dependencies = [ "serde", ] @@ -2893,7 +2893,18 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow", + "winnow 0.5.40", +] + +[[package]] +name = "toml_edit" +version = "0.22.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" +dependencies = [ + "indexmap 2.2.5", + "toml_datetime", + "winnow 0.6.6", ] [[package]] @@ -3437,6 +3448,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.50.0" diff --git a/Cargo.toml b/Cargo.toml index 1ab0798f..99d66006 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,8 @@ clap = { version = "4.3.0", default-features = false, features = ["std", "derive futures-util = { version = "0.3.28", default-features = false } # Used for reading the configuration from conduit.toml & environment variables figment = { version = "0.10.8", features = ["env", "toml"] } +# Used to generate the default config file +toml_edit = "0.22" # Validating urls in config url = { version = "2", features = ["serde"] } diff --git a/nix/pkgs/default/default.nix b/nix/pkgs/default/default.nix index 4577fea9..3f680bb9 100644 --- a/nix/pkgs/default/default.nix +++ b/nix/pkgs/default/default.nix @@ -54,6 +54,7 @@ let include = [ "Cargo.lock" "Cargo.toml" + "conduit-example.toml" "src" ]; }; diff --git a/src/clap.rs b/src/clap.rs index 170d2a17..68c2cca0 100644 --- a/src/clap.rs +++ b/src/clap.rs @@ -1,6 +1,6 @@ //! Integration with `clap` -use clap::Parser; +use clap::{Parser, Subcommand}; /// Returns the current version of the crate with extra info if supplied /// @@ -19,7 +19,16 @@ fn version() -> String { /// Command line arguments #[derive(Parser)] #[clap(about, version = version())] -pub struct Args {} +pub struct Args { + #[command(subcommand)] + pub command: Option, +} + +#[derive(Subcommand)] +pub enum Commands { + /// Generates a default config file + GenerateConfig, +} /// Parse command line arguments into structured data pub fn parse() -> Args { diff --git a/src/lib.rs b/src/lib.rs index 5a89f805..424387a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ pub mod clap; mod config; mod database; mod service; -mod utils; +pub mod utils; // Not async due to services() being used in many closures, and async closures are not stable as of writing // This is the case for every other occurence of sync Mutex/RwLock, except for database related ones, where diff --git a/src/main.rs b/src/main.rs index 5d60a6bf..3af475d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,11 @@ use axum::{ Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; -use conduit::api::{client_server, server_server}; +use conduit::{ + api::{client_server, server_server}, + clap::Commands, + utils::random_string, +}; use figment::{ providers::{Env, Format, Toml}, Figment, @@ -23,7 +27,12 @@ use ruma::api::{ }, IncomingRequest, }; -use tokio::signal; +use tokio::{ + fs::{try_exists, File}, + io::AsyncWriteExt, + signal, +}; +use toml_edit::DocumentMut; use tower::ServiceBuilder; use tower_http::{ cors::{self, CorsLayer}, @@ -44,100 +53,128 @@ static GLOBAL: Jemalloc = Jemalloc; #[tokio::main] async fn main() { - clap::parse(); + let cli = clap::parse(); - // Initialize config - let raw_config = - Figment::new() - .merge( - Toml::file(Env::var("CONDUIT_CONFIG").expect( - "The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml", - )) - .nested(), - ) - .merge(Env::prefixed("CONDUIT_").global()); + let path = + Env::var("CONDUIT_CONFIG") + .expect("The config path must either be set via the -c/--config flag or the CONDUIT_CONFIG env var. Example: /etc/conduit.toml") + ; - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occurred: {e}"); - std::process::exit(1); + match cli.command { + Some(Commands::GenerateConfig) => { + let toml = include_str!("../conduit-example.toml"); + let mut doc = toml.parse::().expect("invalid doc"); + doc["global"]["registration_token"] = toml_edit::value(random_string(64)); + + if let Ok(true) = try_exists(path.clone()).await { + panic!("Error: file '{}' already exists", path); + // Any possible error should be caught on creation + } else { + match File::create(path).await { + Ok(mut file) => match file.write(&doc.to_string().into_bytes()).await { + Err(e) => panic!("Error writing config file: {e}"), + Ok(_) => { + println!("Successfully generated config file"); + } + }, + Err(e) => panic!("Error creating config file: {e}"), + } + } } - }; + None => { + // Initialize config + let raw_config = Figment::new() + .merge(Toml::file(path).nested()) + .merge(Env::prefixed("CONDUIT_").global()); - config.warn_deprecated(); + let config = match raw_config.extract::() { + Ok(s) => s, + Err(e) => { + eprintln!( + "It looks like your config is invalid. The following error occurred: {e}" + ); + std::process::exit(1); + } + }; - if config.allow_jaeger { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); - let tracer = opentelemetry_jaeger::new_agent_pipeline() - .with_auto_split_batch(true) - .with_service_name("conduit") - .install_batch(opentelemetry::runtime::Tokio) - .unwrap(); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + config.warn_deprecated(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!( + if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator( + opentelemetry_jaeger::Propagator::new(), + ); + let tracer = opentelemetry_jaeger::new_agent_pipeline() + .with_auto_split_batch(true) + .with_service_name("conduit") + .install_batch(opentelemetry::runtime::Tokio) + .unwrap(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!( "It looks like your log config is invalid. The following error occurred: {e}" ); - EnvFilter::try_new("warn").unwrap() + EnvFilter::try_new("warn").unwrap() + } + }; + + let subscriber = tracing_subscriber::Registry::default() + .with(filter_layer) + .with(telemetry); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } else if config.tracing_flame { + let registry = tracing_subscriber::Registry::default(); + let (flame_layer, _guard) = + tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); + let flame_layer = flame_layer.with_empty_samples(false); + + let filter_layer = EnvFilter::new("trace,h2=off"); + + let subscriber = registry.with(filter_layer).with(flame_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } else { + let registry = tracing_subscriber::Registry::default(); + let fmt_layer = tracing_subscriber::fmt::Layer::new(); + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); + EnvFilter::try_new("warn").unwrap() + } + }; + + let subscriber = registry.with(filter_layer).with(fmt_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); } - }; - let subscriber = tracing_subscriber::Registry::default() - .with(filter_layer) - .with(telemetry); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } else if config.tracing_flame { - let registry = tracing_subscriber::Registry::default(); - let (flame_layer, _guard) = - tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); - let flame_layer = flame_layer.with_empty_samples(false); + // This is needed for opening lots of file descriptors, which tends to + // happen more often when using RocksDB and making lots of federation + // connections at startup. The soft limit is usually 1024, and the hard + // limit is usually 512000; I've personally seen it hit >2000. + // + // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 + // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 + #[cfg(unix)] + maximize_fd_limit() + .expect("should be able to increase the soft limit to the hard limit"); - let filter_layer = EnvFilter::new("trace,h2=off"); + info!("Loading database"); + if let Err(error) = KeyValueDatabase::load_or_create(config).await { + error!(?error, "The database couldn't be loaded or created"); - let subscriber = registry.with(filter_layer).with(flame_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } else { - let registry = tracing_subscriber::Registry::default(); - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new("warn").unwrap() + std::process::exit(1); + }; + let config = &services().globals.config; + + info!("Starting server"); + run_server().await.unwrap(); + + if config.allow_jaeger { + opentelemetry::global::shutdown_tracer_provider(); } - }; - - let subscriber = registry.with(filter_layer).with(fmt_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } - - // This is needed for opening lots of file descriptors, which tends to - // happen more often when using RocksDB and making lots of federation - // connections at startup. The soft limit is usually 1024, and the hard - // limit is usually 512000; I've personally seen it hit >2000. - // - // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 - // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 - #[cfg(unix)] - maximize_fd_limit().expect("should be able to increase the soft limit to the hard limit"); - - info!("Loading database"); - if let Err(error) = KeyValueDatabase::load_or_create(config).await { - error!(?error, "The database couldn't be loaded or created"); - - std::process::exit(1); - }; - let config = &services().globals.config; - - info!("Starting server"); - run_server().await.unwrap(); - - if config.allow_jaeger { - opentelemetry::global::shutdown_tracer_provider(); + } } }