feat: close registration with ROCKET_REGISTRATION_DISABLED=true

This commit is contained in:
timokoesters 2020-06-06 19:02:31 +02:00
parent c85d363d71
commit 0067f49d52
No known key found for this signature in database
GPG key ID: 24DA7517711A2BA4
3 changed files with 27 additions and 13 deletions

View file

@ -117,6 +117,14 @@ pub fn register_route(
db: State<'_, Database>,
body: Ruma<register::Request>,
) -> MatrixResult<register::Response, UiaaResponse> {
if db.globals.registration_disabled() {
return MatrixResult(Err(UiaaResponse::MatrixError(Error {
kind: ErrorKind::Unknown,
message: "Registration has been disabled.".to_owned(),
status_code: http::StatusCode::FORBIDDEN,
})));
}
// Validate user id
let user_id = match UserId::parse_with_server_name(
body.username

View file

@ -7,6 +7,7 @@ pub(self) mod uiaa;
pub(self) mod users;
use directories::ProjectDirs;
use log::info;
use std::fs::remove_dir_all;
use rocket::Config;
@ -49,13 +50,10 @@ impl Database {
});
let db = sled::open(&path).unwrap();
log::info!("Opened sled database at {}", path);
info!("Opened sled database at {}", path);
Self {
globals: globals::Globals::load(
db.open_tree("global").unwrap(),
server_name.to_owned(),
),
globals: globals::Globals::load(db.open_tree("global").unwrap(), config),
users: users::Users {
userid_password: db.open_tree("userid_password").unwrap(),
userid_displayname: db.open_tree("userid_displayname").unwrap(),

View file

@ -4,13 +4,14 @@ pub const COUNTER: &str = "c";
pub struct Globals {
pub(super) globals: sled::Tree,
server_name: String,
keypair: ruma::signatures::Ed25519KeyPair,
reqwest_client: reqwest::Client,
server_name: String,
registration_disabled: bool,
}
impl Globals {
pub fn load(globals: sled::Tree, server_name: String) -> Self {
pub fn load(globals: sled::Tree, config: &rocket::Config) -> Self {
let keypair = ruma::signatures::Ed25519KeyPair::new(
&*globals
.update_and_fetch("keypair", utils::generate_keypair)
@ -22,17 +23,16 @@ impl Globals {
Self {
globals,
server_name,
keypair,
reqwest_client: reqwest::Client::new(),
server_name: config
.get_str("server_name")
.unwrap_or("localhost")
.to_owned(),
registration_disabled: config.get_bool("registration_disabled").unwrap_or(false),
}
}
/// Returns the server_name of the server.
pub fn server_name(&self) -> &str {
&self.server_name
}
/// Returns this server's keypair.
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair {
&self.keypair
@ -58,4 +58,12 @@ impl Globals {
.get(COUNTER)?
.map_or(0_u64, |bytes| utils::u64_from_bytes(&bytes)))
}
pub fn server_name(&self) -> &str {
&self.server_name
}
pub fn registration_disabled(&self) -> bool {
self.registration_disabled
}
}