diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index 298c2655..68d31191 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -18,7 +18,9 @@ use tracing::{error, info, warn}; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ api::client_server::{self, join_room_by_id_helper}, - service, services, utils, Error, Result, Ruma, + service, services, + utils::{self, user_id::user_is_local}, + Error, Result, Ruma, }; const RANDOM_USER_ID_LENGTH: usize = 10; @@ -40,7 +42,7 @@ pub(crate) async fn get_register_available_route( // Validate user id let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) .ok() - .filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name()) + .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; // Check if username is creative enough @@ -125,9 +127,7 @@ pub(crate) async fn register_route(body: Ruma) -> Result< let proposed_user_id = UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) .ok() - .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == services().globals.server_name() - }) + .filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; if services().users.exists(&proposed_user_id)? { diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 1fb78792..02d004d7 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -12,13 +12,13 @@ use ruma::{ }; use tracing::debug; -use crate::{debug_info, debug_warn, services, Error, Result, Ruma}; +use crate::{debug_info, debug_warn, services, utils::server_name::server_is_ours, Error, Result, Ruma}; /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` /// /// Creates a new room alias on this server. pub(crate) async fn create_alias_route(body: Ruma) -> Result { - if body.room_alias.server_name() != services().globals.server_name() { + if !server_is_ours(body.room_alias.server_name()) { return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); } @@ -73,7 +73,7 @@ pub(crate) async fn create_alias_route(body: Ruma) -> /// - TODO: additional access control checks /// - TODO: Update canonical alias event pub(crate) async fn delete_alias_route(body: Ruma) -> Result { - if body.room_alias.server_name() != services().globals.server_name() { + if !server_is_ours(body.room_alias.server_name()) { return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); } @@ -126,7 +126,7 @@ pub(crate) async fn get_alias_helper( room_alias: OwnedRoomAliasId, servers: Option>, ) -> Result { debug!("get_alias_helper servers: {servers:?}"); - if room_alias.server_name() != services().globals.server_name() + if !server_is_ours(room_alias.server_name()) && (!servers .as_ref() .is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned())) @@ -204,7 +204,7 @@ pub(crate) async fn get_alias_helper( // room alias server first if let Some(server_index) = servers .iter() - .position(|server| server == services().globals.server_name()) + .position(|server_name| server_is_ours(server_name)) { servers.remove(server_index); servers.insert(0, services().globals.server_name().to_owned()); @@ -277,7 +277,7 @@ pub(crate) async fn get_alias_helper( // insert our server as the very first choice if in list if let Some(server_index) = servers .iter() - .position(|server| server == services().globals.server_name()) + .position(|server_name| server_is_ours(server_name)) { servers.remove(server_index); servers.insert(0, services().globals.server_name().to_owned()); diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index 4527fddf..d0206dfc 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -20,7 +20,11 @@ use serde_json::json; use tracing::debug; use super::SESSION_ID_LENGTH; -use crate::{services, utils, Error, Result, Ruma}; +use crate::{ + services, + utils::{self, user_id::user_is_local}, + Error, Result, Ruma, +}; /// # `POST /_matrix/client/r0/keys/upload` /// @@ -260,7 +264,7 @@ pub(crate) async fn get_keys_helper bool>( for (user_id, device_ids) in device_keys_input { let user_id: &UserId = user_id; - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(user_id) { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 64ba5dce..225a021e 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -34,7 +34,9 @@ use tracing::{debug, error, info, trace, warn}; use super::get_alias_helper; use crate::{ service::pdu::{gen_event_id_canonical_json, PduBuilder}, - services, utils, Error, PduEvent, Result, Ruma, + services, + utils::{self, server_name::server_is_ours, user_id::user_is_local}, + Error, PduEvent, Result, Ruma, }; /// # `POST /_matrix/client/r0/rooms/{roomId}/join` @@ -1088,7 +1090,7 @@ pub(crate) async fn join_room_by_id_helper( .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|user| user.server_name() == services().globals.server_name()) + .filter(|user| user_is_local(user)) .collect::>(); let mut authorized_user: Option = None; @@ -1150,7 +1152,7 @@ pub(crate) async fn join_room_by_id_helper( if !restriction_rooms.is_empty() && servers .iter() - .any(|s| *s != services().globals.server_name()) + .any(|server_name| !server_is_ours(server_name)) { info!( "We couldn't do the join locally, maybe federation can help to satisfy the restricted join \ @@ -1303,7 +1305,7 @@ async fn make_join_request( let mut incompatible_room_version_count = 0; for remote_server in servers { - if remote_server == services().globals.server_name() { + if server_is_ours(remote_server) { continue; } info!("Asking {remote_server} for make_join ({make_join_counter})"); @@ -1436,7 +1438,7 @@ pub(crate) async fn invite_helper( )); } - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(user_id) { let (pdu, pdu_json, invite_room_state) = { let mutex_state = Arc::clone( services() diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 0e1da9d5..5651cbde 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -55,7 +55,9 @@ use crate::{ api::client_server::{self, claim_keys_helper, get_keys_helper}, debug_error, service::pdu::{gen_event_id_canonical_json, PduBuilder}, - services, utils, Error, PduEvent, Result, Ruma, + services, + utils::{self, user_id::user_is_local}, + Error, PduEvent, Result, Ruma, }; /// # `GET /_matrix/federation/v1/version` @@ -978,7 +980,7 @@ pub(crate) async fn create_join_event_template_route( .state_cache .room_members(&body.room_id) .filter_map(Result::ok) - .filter(|user| user.server_name() == services().globals.server_name()) + .filter(|user| user_is_local(user)) .collect(); let mut auth_user = None; diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 99fcfe85..5c5df762 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -11,7 +11,9 @@ use tracing::error; use crate::{ database::KeyValueDatabase, service::{self, appservice::RegistrationInfo}, - services, utils, Error, Result, + services, + utils::{self, user_id::user_is_local}, + Error, Result, }; type StrippedStateEventIter<'a> = Box>)>> + 'a>; @@ -149,9 +151,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { for joined in self.room_members(room_id).filter_map(Result::ok) { joined_servers.insert(joined.server_name().to_owned()); - if joined.server_name() == services().globals.server_name() - && !services().users.is_deactivated(&joined).unwrap_or(true) - { + if user_is_local(&joined) && !services().users.is_deactivated(&joined).unwrap_or(true) { real_users.insert(joined); } joinedcount += 1; diff --git a/src/main.rs b/src/main.rs index f5da0f40..47126628 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,7 +24,6 @@ use http::{ header::{self, HeaderName}, Method, StatusCode, Uri, }; -#[cfg(unix)] use ruma::api::client::{ error::{Error as RumaError, ErrorBody, ErrorKind}, uiaa::UiaaResponse, diff --git a/src/service/admin/room/room_moderation_commands.rs b/src/service/admin/room/room_moderation_commands.rs index 8fe6a583..5c62e360 100644 --- a/src/service/admin/room/room_moderation_commands.rs +++ b/src/service/admin/room/room_moderation_commands.rs @@ -9,7 +9,9 @@ use super::RoomModerationCommand; use crate::{ api::client_server::{get_alias_helper, leave_room}, service::admin::{escape_html, Service}, - services, Result, + services, + utils::user_id::user_is_local, + Result, }; pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> Result { @@ -102,11 +104,10 @@ pub(crate) async fn process(command: RoomModerationCommand, body: Vec<&str>) -> .room_members(&room_id) .filter_map(|user| { user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + user_is_local(local_user) // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) - && (local_user.server_name() - == services().globals.server_name() + && (user_is_local(local_user) && services() .users .is_admin(local_user) diff --git a/src/service/presence/mod.rs b/src/service/presence/mod.rs index 4c71d013..d2d1bb39 100644 --- a/src/service/presence/mod.rs +++ b/src/service/presence/mod.rs @@ -13,7 +13,11 @@ use serde::{Deserialize, Serialize}; use tokio::{sync::Mutex, time::sleep}; use tracing::{debug, error}; -use crate::{services, utils, Config, Error, Result}; +use crate::{ + services, + utils::{self, user_id::user_is_local}, + Config, Error, Result, +}; /// Represents data required to be kept in order to implement the presence /// specification. @@ -154,7 +158,7 @@ impl Service { self.db .set_presence(user_id, presence_state, currently_active, last_active_ago, status_msg)?; - if self.timeout_remote_users || user_id.server_name() == services().globals.server_name() { + if self.timeout_remote_users || user_is_local(user_id) { let timeout = match presence_state { PresenceState::Online => services().globals.config.presence_idle_timeout_s, _ => services().globals.config.presence_offline_timeout_s, diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index da087835..3434c55a 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -8,7 +8,11 @@ use ruma::{ use tokio::sync::{broadcast, RwLock}; use tracing::debug; -use crate::{services, utils, Result}; +use crate::{ + services, + utils::{self, user_id::user_is_local}, + Result, +}; pub(crate) struct Service { pub(crate) typing: RwLock>>, // u64 is unix timestamp of timeout @@ -37,7 +41,7 @@ impl Service { _ = self.typing_update_sender.send(room_id.to_owned()); // update federation - if user_id.server_name() == services().globals.server_name() { + if user_is_local(user_id) { self.federation_send(room_id, user_id, true)?; } @@ -61,7 +65,7 @@ impl Service { _ = self.typing_update_sender.send(room_id.to_owned()); // update federation - if user_id.server_name() == services().globals.server_name() { + if user_is_local(user_id) { self.federation_send(room_id, user_id, false)?; } @@ -115,7 +119,7 @@ impl Service { // update federation for user in removable { - if user.server_name() == services().globals.server_name() { + if user_is_local(&user) { self.federation_send(room_id, &user, false)?; } } @@ -154,10 +158,7 @@ impl Service { } fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> { - debug_assert!( - user_id.server_name() == services().globals.server_name(), - "tried to broadcast typing status of remote user", - ); + debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",); if !services().globals.config.allow_outgoing_typing { return Ok(()); } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index d18101a9..5ec7ebad 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -23,7 +23,12 @@ use ruma::{ use tracing::{debug, error, warn}; use super::{appservice, send, Destination, Msg, SendingEvent, Service}; -use crate::{service::presence::Presence, services, utils::calculate_hash, Error, PduEvent, Result}; +use crate::{ + service::presence::Presence, + services, + utils::{calculate_hash, user_id::user_is_local}, + Error, PduEvent, Result, +}; #[derive(Debug)] enum TransactionStatus { @@ -244,7 +249,7 @@ impl Service { .users .keys_changed(room_id.as_ref(), since, None) .filter_map(Result::ok) - .filter(|user_id| user_id.server_name() == services().globals.server_name()), + .filter(|user_id| user_is_local(user_id)), ); if services().globals.allow_outgoing_read_receipts() diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 6b98baa1..8b988151 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,6 +1,8 @@ pub(crate) mod clap; pub(crate) mod debug; pub(crate) mod error; +pub(crate) mod server_name; +pub(crate) mod user_id; use std::{ cmp, diff --git a/src/utils/server_name.rs b/src/utils/server_name.rs new file mode 100644 index 00000000..11303f9a --- /dev/null +++ b/src/utils/server_name.rs @@ -0,0 +1,8 @@ +//! utilities for doing/checking things with ServerName's/server_name's + +use ruma::ServerName; + +use crate::services; + +/// checks if `server_name` is ours +pub(crate) fn server_is_ours(server_name: &ServerName) -> bool { server_name == services().globals.config.server_name } diff --git a/src/utils/user_id.rs b/src/utils/user_id.rs new file mode 100644 index 00000000..ae312792 --- /dev/null +++ b/src/utils/user_id.rs @@ -0,0 +1,8 @@ +//! utilities for doing things with UserId's / usernames + +use ruma::UserId; + +use crate::services; + +/// checks if `user_id` is local to us via server_name comparison +pub(crate) fn user_is_local(user_id: &UserId) -> bool { user_id.server_name() == services().globals.config.server_name }