de-global services for services

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-18 06:37:47 +00:00
parent 992c0a1e58
commit 010e4ee35a
85 changed files with 2480 additions and 1887 deletions

View file

@ -15,7 +15,7 @@ use ruma::{
events::room::message::RoomMessageEventContent,
CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName,
};
use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent};
use service::services;
use tokio::sync::RwLock;
use tracing_subscriber::EnvFilter;
@ -189,7 +189,10 @@ pub(super) async fn get_remote_pdu(
debug!("Attempting to parse PDU: {:?}", &response.pdu);
let parsed_pdu = {
let parsed_result = parse_incoming_pdu(&response.pdu);
let parsed_result = services()
.rooms
.event_handler
.parse_incoming_pdu(&response.pdu);
let (event_id, value, room_id) = match parsed_result {
Ok(t) => t,
Err(e) => {
@ -510,7 +513,7 @@ pub(super) async fn force_set_room_state_from_server(
let mut events = Vec::with_capacity(remote_state_response.pdus.len());
for pdu in remote_state_response.pdus.clone() {
events.push(match parse_incoming_pdu(&pdu) {
events.push(match services().rooms.event_handler.parse_incoming_pdu(&pdu) {
Ok(t) => t,
Err(e) => {
warn!("Could not parse PDU, ignoring: {e}");

View file

@ -1,3 +1,4 @@
#![recursion_limit = "168"]
#![allow(clippy::wildcard_imports)]
pub(crate) mod appservice;

View file

@ -22,7 +22,11 @@ pub(crate) async fn create_alias_route(
) -> Result<create_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?;
services
.rooms
.alias
.appservice_checks(&body.room_alias, &body.appservice_info)
.await?;
// this isn't apart of alias_checks or delete alias route because we should
// allow removing forbidden room aliases
@ -61,7 +65,11 @@ pub(crate) async fn delete_alias_route(
) -> Result<delete_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?;
services
.rooms
.alias
.appservice_checks(&body.room_alias, &body.appservice_info)
.await?;
if services
.rooms

View file

@ -43,7 +43,6 @@ use crate::{
service::{
pdu::{gen_event_id_canonical_json, PduBuilder},
rooms::state::RoomMutexGuard,
sending::convert_to_outgoing_federation_event,
server_is_ours, user_is_local, Services,
},
Ruma,
@ -791,7 +790,9 @@ async fn join_room_by_id_helper_remote(
federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
pdu: convert_to_outgoing_federation_event(join_event.clone()),
pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false,
},
)
@ -1203,7 +1204,9 @@ async fn join_room_by_id_helper_local(
federation::membership::create_join_event::v2::Request {
room_id: room_id.to_owned(),
event_id: event_id.to_owned(),
pdu: convert_to_outgoing_federation_event(join_event.clone()),
pdu: services
.sending
.convert_to_outgoing_federation_event(join_event.clone()),
omit_members: false,
},
)
@ -1431,7 +1434,9 @@ pub(crate) async fn invite_helper(
room_id: room_id.to_owned(),
event_id: (*pdu.event_id).to_owned(),
room_version: room_version_id.clone(),
event: convert_to_outgoing_federation_event(pdu_json.clone()),
event: services
.sending
.convert_to_outgoing_federation_event(pdu_json.clone()),
invite_room_state,
via: services.rooms.state_cache.servers_route_via(room_id).ok(),
},
@ -1763,7 +1768,9 @@ async fn remote_leave_room(services: &Services, user_id: &UserId, room_id: &Room
federation::membership::create_leave_event::v2::Request {
room_id: room_id.to_owned(),
event_id,
pdu: convert_to_outgoing_federation_event(leave_event.clone()),
pdu: services
.sending
.convert_to_outgoing_federation_event(leave_event.clone()),
},
)
.await?;

View file

@ -475,8 +475,6 @@ async fn handle_left_room(
async fn process_presence_updates(
services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId,
) -> Result<()> {
use crate::service::presence::Presence;
// Take presence updates
for (user_id, _, presence_bytes) in services.presence.presence_since(since) {
if !services
@ -487,7 +485,9 @@ async fn process_presence_updates(
continue;
}
let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?;
let presence_event = services
.presence
.from_json_bytes_to_event(&presence_bytes, &user_id)?;
match presence_updates.entry(user_id) {
Entry::Vacant(slot) => {
slot.insert(presence_event);

View file

@ -1,3 +1,5 @@
#![recursion_limit = "160"]
pub mod client;
pub mod router;
pub mod server;

View file

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::backfill::get_backfill},
uint, user_id, MilliSecondsSinceUnixEpoch,
};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma;
@ -67,7 +66,7 @@ pub(crate) async fn get_backfill_route(
})
.map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id))
.filter_map(|r| r.ok().flatten())
.map(convert_to_outgoing_federation_event)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect();
Ok(get_backfill::v1::Response {

View file

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::event::get_event},
MilliSecondsSinceUnixEpoch, RoomId,
};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma;
@ -50,6 +49,6 @@ pub(crate) async fn get_event_route(
Ok(get_event::v1::Response {
origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdu: convert_to_outgoing_federation_event(event),
pdu: services.sending.convert_to_outgoing_federation_event(event),
})
}

View file

@ -6,7 +6,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
RoomId,
};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma;
@ -60,7 +59,7 @@ pub(crate) async fn get_event_authorization_route(
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?)
.map(convert_to_outgoing_federation_event)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
})
}

View file

@ -4,7 +4,6 @@ use ruma::{
api::{client::error::ErrorKind, federation::event::get_missing_events},
OwnedEventId, RoomId,
};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma;
@ -82,7 +81,7 @@ pub(crate) async fn get_missing_events_route(
)
.map_err(|_| Error::bad_database("Invalid prev_events in event in database."))?,
);
events.push(convert_to_outgoing_federation_event(pdu));
events.push(services.sending.convert_to_outgoing_federation_event(pdu));
}
i = i.saturating_add(1);
}

View file

@ -7,7 +7,7 @@ use ruma::{
serde::JsonObject,
CanonicalJsonValue, EventId, OwnedUserId,
};
use service::{sending::convert_to_outgoing_federation_event, server_is_ours};
use service::server_is_ours;
use crate::Ruma;
@ -174,6 +174,8 @@ pub(crate) async fn create_invite_route(
}
Ok(create_invite::v2::Response {
event: convert_to_outgoing_federation_event(signed_event),
event: services
.sending
.convert_to_outgoing_federation_event(signed_event),
})
}

View file

@ -21,7 +21,6 @@ use ruma::{
use tokio::sync::RwLock;
use crate::{
service::rooms::event_handler::parse_incoming_pdu,
services::Services,
utils::{self},
Error, Result, Ruma,
@ -89,7 +88,7 @@ async fn handle_pdus(
) -> Result<ResolvedMap> {
let mut parsed_pdus = Vec::with_capacity(body.pdus.len());
for pdu in &body.pdus {
parsed_pdus.push(match parse_incoming_pdu(pdu) {
parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu) {
Ok(t) => t,
Err(e) => {
debug_warn!("Could not parse PDU: {e}");

View file

@ -13,9 +13,7 @@ use ruma::{
CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::{
pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services,
};
use service::{pdu::gen_event_id_canonical_json, user_is_local, Services};
use tokio::sync::RwLock;
use tracing::warn;
@ -186,12 +184,12 @@ async fn create_join_event(
Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids
.filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten())
.map(convert_to_outgoing_federation_event)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
state: state_ids
.iter()
.filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten())
.map(convert_to_outgoing_federation_event)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
.collect(),
// Event field is required if the room version supports restricted join rules.
event: Some(

View file

@ -3,7 +3,6 @@ use std::sync::Arc;
use axum::extract::State;
use conduit::{Error, Result};
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state};
use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma;
@ -44,7 +43,11 @@ pub(crate) async fn get_room_state_route(
.state_full_ids(shortstatehash)
.await?
.into_values()
.map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()))
.map(|id| {
services
.sending
.convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap())
})
.collect();
let auth_chain_ids = services
@ -61,7 +64,7 @@ pub(crate) async fn get_room_state_route(
.timeline
.get_pdu_json(&id)
.ok()?
.map(convert_to_outgoing_federation_event)
.map(|pdu| services.sending.convert_to_outgoing_federation_event(pdu))
})
.collect(),
pdus,

View file

@ -1,3 +1,5 @@
#![recursion_limit = "160"]
mod layers;
mod request;
mod router;

View file

@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use conduit::{utils, warn, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{
api::client::error::ErrorKind,
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
@ -9,18 +9,27 @@ use ruma::{
RoomId, UserId,
};
use crate::services;
use crate::{globals, Dep};
pub(super) struct Data {
roomuserdataid_accountdata: Arc<Map>,
roomusertype_roomuserdataid: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
roomuserdataid_accountdata: db["roomuserdataid_accountdata"].clone(),
roomusertype_roomuserdataid: db["roomusertype_roomuserdataid"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
@ -40,7 +49,7 @@ impl Data {
prefix.push(0xFF);
let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xFF);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());

View file

@ -17,7 +17,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -11,10 +11,11 @@ use rustyline_async::{Readline, ReadlineError, ReadlineEvent};
use termimad::MadSkin;
use tokio::task::JoinHandle;
use crate::services;
use crate::{admin, Dep};
pub struct Console {
server: Arc<Server>,
admin: Dep<admin::Service>,
worker_join: Mutex<Option<JoinHandle<()>>>,
input_abort: Mutex<Option<AbortHandle>>,
command_abort: Mutex<Option<AbortHandle>>,
@ -29,6 +30,7 @@ impl Console {
pub(super) fn new(args: &crate::Args<'_>) -> Arc<Self> {
Arc::new(Self {
server: args.server.clone(),
admin: args.depend::<admin::Service>("admin"),
worker_join: None.into(),
input_abort: None.into(),
command_abort: None.into(),
@ -116,7 +118,8 @@ impl Console {
let _suppression = log::Suppress::new(&self.server);
let (mut readline, _writer) = Readline::new(PROMPT.to_owned())?;
readline.set_tab_completer(Self::tab_complete);
let self_ = Arc::clone(self);
readline.set_tab_completer(move |line| self_.tab_complete(line));
self.set_history(&mut readline);
let future = readline.readline();
@ -154,7 +157,7 @@ impl Console {
}
async fn process(self: Arc<Self>, line: String) {
match services().admin.command_in_place(line, None).await {
match self.admin.command_in_place(line, None).await {
Ok(Some(content)) => self.output(content).await,
Err(e) => error!("processing command: {e}"),
_ => (),
@ -184,9 +187,8 @@ impl Console {
history.truncate(HISTORY_LIMIT);
}
fn tab_complete(line: &str) -> String {
services()
.admin
fn tab_complete(&self, line: &str) -> String {
self.admin
.complete_command(line)
.unwrap_or_else(|| line.to_owned())
}

View file

@ -25,13 +25,14 @@ impl super::Service {
return Ok(());
};
let state_lock = self.state.mutex.lock(&room_id).await;
let state_lock = self.services.state.mutex.lock(&room_id).await;
// Use the server user to grant the new admin's power level
let server_user = &self.globals.server_user;
let server_user = &self.services.globals.server_user;
// Invite and join the real user
self.timeline
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
@ -55,7 +56,8 @@ impl super::Service {
&state_lock,
)
.await?;
self.timeline
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMember,
@ -85,7 +87,8 @@ impl super::Service {
users.insert(server_user.clone(), 100.into());
users.insert(user_id.to_owned(), 100.into());
self.timeline
self.services
.timeline
.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomPowerLevels,
@ -105,12 +108,12 @@ impl super::Service {
.await?;
// Send welcome message
self.timeline.build_and_append_pdu(
self.services.timeline.build_and_append_pdu(
PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&RoomMessageEventContent::text_html(
format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.globals.server_name()),
format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.globals.server_name()),
format!("## Thank you for trying out conduwuit!\n\nconduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Git and Documentation: https://github.com/girlbossceo/conduwuit\n> Report issues: https://github.com/girlbossceo/conduwuit/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nconduwuit room (Ask questions and get notified on updates):\n`/join #conduwuit:puppygock.gay`", self.services.globals.server_name()),
format!("<h2>Thank you for trying out conduwuit!</h2>\n<p>conduwuit is a fork of upstream Conduit which is in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Git and Documentation: https://github.com/girlbossceo/conduwuit<br>Report issues: https://github.com/girlbossceo/conduwuit/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>conduwuit room (Ask questions and get notified on updates):<br><code>/join #conduwuit:puppygock.gay</code></p>\n", self.services.globals.server_name()),
))
.expect("event is valid, we just created it"),
unsigned: None,

View file

@ -22,15 +22,10 @@ use ruma::{
use serde_json::value::to_raw_value;
use tokio::sync::{Mutex, RwLock};
use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local};
use crate::{globals, rooms, rooms::state::RoomMutexGuard, user_is_local, Dep};
pub struct Service {
server: Arc<Server>,
globals: Arc<globals::Service>,
alias: Arc<rooms::alias::Service>,
timeline: Arc<rooms::timeline::Service>,
state: Arc<rooms::state::Service>,
state_cache: Arc<rooms::state_cache::Service>,
services: Services,
sender: Sender<Command>,
receiver: Mutex<Receiver<Command>>,
pub handle: RwLock<Option<Handler>>,
@ -39,6 +34,15 @@ pub struct Service {
pub console: Arc<console::Console>,
}
struct Services {
server: Arc<Server>,
globals: Dep<globals::Service>,
alias: Dep<rooms::alias::Service>,
timeline: Dep<rooms::timeline::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
#[derive(Debug)]
pub struct Command {
pub command: String,
@ -58,12 +62,14 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let (sender, receiver) = loole::bounded(COMMAND_QUEUE_LIMIT);
Ok(Arc::new(Self {
server: args.server.clone(),
globals: args.require_service::<globals::Service>("globals"),
alias: args.require_service::<rooms::alias::Service>("rooms::alias"),
timeline: args.require_service::<rooms::timeline::Service>("rooms::timeline"),
state: args.require_service::<rooms::state::Service>("rooms::state"),
state_cache: args.require_service::<rooms::state_cache::Service>("rooms::state_cache"),
services: Services {
server: args.server.clone(),
globals: args.depend::<globals::Service>("globals"),
alias: args.depend::<rooms::alias::Service>("rooms::alias"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
sender,
receiver: Mutex::new(receiver),
handle: RwLock::new(None),
@ -75,7 +81,7 @@ impl crate::Service for Service {
async fn worker(self: Arc<Self>) -> Result<()> {
let receiver = self.receiver.lock().await;
let mut signals = self.server.signal.subscribe();
let mut signals = self.services.server.signal.subscribe();
loop {
tokio::select! {
command = receiver.recv_async() => match command {
@ -116,7 +122,7 @@ impl Service {
pub async fn send_message(&self, message_content: RoomMessageEventContent) {
if let Ok(Some(room_id)) = self.get_admin_room() {
let user_id = &self.globals.server_user;
let user_id = &self.services.globals.server_user;
self.respond_to_room(message_content, &room_id, user_id)
.await;
}
@ -176,7 +182,7 @@ impl Service {
/// Checks whether a given user is an admin of this server
pub async fn user_is_admin(&self, user_id: &UserId) -> Result<bool> {
if let Ok(Some(admin_room)) = self.get_admin_room() {
self.state_cache.is_joined(user_id, &admin_room)
self.services.state_cache.is_joined(user_id, &admin_room)
} else {
Ok(false)
}
@ -187,10 +193,15 @@ impl Service {
/// Errors are propagated from the database, and will have None if there is
/// no admin room
pub fn get_admin_room(&self) -> Result<Option<OwnedRoomId>> {
if let Some(room_id) = self.alias.resolve_local_alias(&self.globals.admin_alias)? {
if let Some(room_id) = self
.services
.alias
.resolve_local_alias(&self.services.globals.admin_alias)?
{
if self
.services
.state_cache
.is_joined(&self.globals.server_user, &room_id)?
.is_joined(&self.services.globals.server_user, &room_id)?
{
return Ok(Some(room_id));
}
@ -207,12 +218,12 @@ impl Service {
return;
};
let Ok(Some(pdu)) = self.timeline.get_pdu(&in_reply_to.event_id) else {
let Ok(Some(pdu)) = self.services.timeline.get_pdu(&in_reply_to.event_id) else {
return;
};
let response_sender = if self.is_admin_room(&pdu.room_id) {
&self.globals.server_user
&self.services.globals.server_user
} else {
&pdu.sender
};
@ -229,7 +240,7 @@ impl Service {
"sender is not admin"
);
let state_lock = self.state.mutex.lock(room_id).await;
let state_lock = self.services.state.mutex.lock(room_id).await;
let response_pdu = PduBuilder {
event_type: TimelineEventType::RoomMessage,
content: to_raw_value(&content).expect("event is valid, we just created it"),
@ -239,6 +250,7 @@ impl Service {
};
if let Err(e) = self
.services
.timeline
.build_and_append_pdu(response_pdu, user_id, room_id, &state_lock)
.await
@ -266,7 +278,8 @@ impl Service {
redacts: None,
};
self.timeline
self.services
.timeline
.build_and_append_pdu(response_pdu, user_id, room_id, state_lock)
.await?;
@ -279,7 +292,7 @@ impl Service {
let is_public_escape = is_escape && body.trim_start_matches('\\').starts_with("!admin");
// Admin command with public echo (in admin room)
let server_user = &self.globals.server_user;
let server_user = &self.services.globals.server_user;
let is_public_prefix = body.starts_with("!admin") || body.starts_with(server_user.as_str());
// Expected backward branch
@ -293,7 +306,7 @@ impl Service {
}
// Check if server-side command-escape is disabled by configuration
if is_public_escape && !self.globals.config.admin_escape_commands {
if is_public_escape && !self.services.globals.config.admin_escape_commands {
return false;
}
@ -309,7 +322,7 @@ impl Service {
// This will evaluate to false if the emergency password is set up so that
// the administrator can execute commands as conduit
let emergency_password_set = self.globals.emergency_password().is_some();
let emergency_password_set = self.services.globals.emergency_password().is_some();
let from_server = pdu.sender == *server_user && !emergency_password_set;
if from_server && self.is_admin_room(&pdu.room_id) {
return false;

View file

@ -12,7 +12,7 @@ use ruma::{
};
use tokio::sync::RwLock;
use crate::services;
use crate::{sending, Dep};
/// Compiled regular expressions for a namespace
#[derive(Clone, Debug)]
@ -118,9 +118,14 @@ impl TryFrom<Registration> for RegistrationInfo {
pub struct Service {
pub db: Data,
services: Services,
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
}
struct Services {
sending: Dep<sending::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let mut registration_info = BTreeMap::new();
@ -138,6 +143,9 @@ impl crate::Service for Service {
Ok(Arc::new(Self {
db,
services: Services {
sending: args.depend::<sending::Service>("sending"),
},
registration_info: RwLock::new(registration_info),
}))
}
@ -178,7 +186,9 @@ impl Service {
// deletes all active requests for the appservice if there are any so we stop
// sending to the URL
services().sending.cleanup_events(service_name.to_owned())?;
self.services
.sending
.cleanup_events(service_name.to_owned())?;
Ok(())
}

View file

@ -18,7 +18,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
let resolver = args.require_service::<resolver::Service>("resolver");
let resolver = args.require::<resolver::Service>("resolver");
Ok(Arc::new(Self {
default: base(config)

View file

@ -0,0 +1,83 @@
use std::sync::Arc;
use async_trait::async_trait;
use conduit::{error, warn, Result};
use ruma::{
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
push::Ruleset,
};
use crate::{account_data, globals, users, Dep};
pub struct Service {
services: Services,
}
struct Services {
account_data: Dep<account_data::Service>,
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}))
}
async fn worker(self: Arc<Self>) -> Result<()> {
self.set_emergency_access()
.inspect_err(|e| error!("Could not set the configured emergency password for the conduit user: {e}"))?;
Ok(())
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Sets the emergency password and push rules for the @conduit account in
/// case emergency password is set
fn set_emergency_access(&self) -> Result<bool> {
let conduit_user = &self.services.globals.server_user;
self.services
.users
.set_password(conduit_user, self.services.globals.emergency_password().as_deref())?;
let (ruleset, pwd_set) = match self.services.globals.emergency_password() {
Some(_) => (Ruleset::server_default(conduit_user), true),
None => (Ruleset::new(), false),
};
self.services.account_data.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
if pwd_set {
warn!(
"The server account emergency password is set! Please unset it as soon as you finish admin account \
recovery! You will be logged out of the server service account when you finish."
);
} else {
// logs out any users still in the server service account and removes sessions
self.services.users.deactivate_account(conduit_user)?;
}
Ok(pwd_set)
}
}

View file

@ -3,7 +3,7 @@ use std::{
sync::{Arc, RwLock},
};
use conduit::{trace, utils, Error, Result};
use conduit::{trace, utils, Error, Result, Server};
use database::{Database, Map};
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
@ -12,7 +12,7 @@ use ruma::{
DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId,
};
use crate::services;
use crate::{rooms, Dep};
pub struct Data {
global: Arc<Map>,
@ -28,14 +28,23 @@ pub struct Data {
server_signingkeys: Arc<Map>,
readreceiptid_readreceipt: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>,
pub(super) db: Arc<Database>,
counter: RwLock<u64>,
pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
server: Arc<Server>,
short: Dep<rooms::short::Service>,
state_cache: Dep<rooms::state_cache::Service>,
typing: Dep<rooms::typing::Service>,
}
const COUNTER: &[u8] = b"c";
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
global: db["global"].clone(),
todeviceid_events: db["todeviceid_events"].clone(),
@ -50,8 +59,14 @@ impl Data {
server_signingkeys: db["server_signingkeys"].clone(),
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
db: db.clone(),
counter: RwLock::new(Self::stored_count(&db["global"]).expect("initialized global counter")),
db: args.db.clone(),
services: Services {
server: args.server.clone(),
short: args.depend::<rooms::short::Service>("rooms::short"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
typing: args.depend::<rooms::typing::Service>("rooms::typing"),
},
}
}
@ -118,14 +133,14 @@ impl Data {
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
// Events for rooms we are in
for room_id in services()
.rooms
for room_id in self
.services
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
let short_roomid = services()
.rooms
let short_roomid = self
.services
.short
.get_shortroomid(&room_id)
.ok()
@ -143,7 +158,7 @@ impl Data {
// EDUs
futures.push(Box::pin(async move {
let _result = services().rooms.typing.wait_for_update(&room_id).await;
let _result = self.services.typing.wait_for_update(&room_id).await;
}));
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
@ -176,12 +191,12 @@ impl Data {
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
futures.push(Box::pin(async move {
while services().server.running() {
let _result = services().server.signal.subscribe().recv().await;
while self.services.server.running() {
let _result = self.services.server.signal.subscribe().recv().await;
}
}));
if !services().server.running() {
if !self.services.server.running() {
return Ok(());
}

View file

@ -1,54 +0,0 @@
use conduit::Result;
use ruma::{
events::{push_rules::PushRulesEventContent, GlobalAccountDataEvent, GlobalAccountDataEventType},
push::Ruleset,
};
use tracing::{error, warn};
use crate::services;
/// Set emergency access for the conduit user
pub(crate) fn init_emergency_access() {
if let Err(e) = set_emergency_access() {
error!("Could not set the configured emergency password for the conduit user: {e}");
}
}
/// Sets the emergency password and push rules for the @conduit account in case
/// emergency password is set
fn set_emergency_access() -> Result<bool> {
let conduit_user = &services().globals.server_user;
services()
.users
.set_password(conduit_user, services().globals.emergency_password().as_deref())?;
let (ruleset, pwd_set) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(conduit_user), true),
None => (Ruleset::new(), false),
};
services().account_data.update(
None,
conduit_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {
global: ruleset,
},
})
.expect("to json value always works"),
)?;
if pwd_set {
warn!(
"The server account emergency password is set! Please unset it as soon as you finish admin account \
recovery! You will be logged out of the server service account when you finish."
);
} else {
// logs out any users still in the server service account and removes sessions
services().users.deactivate_account(conduit_user)?;
}
Ok(pwd_set)
}

View file

@ -10,7 +10,6 @@ use std::{
};
use conduit::{debug, debug_info, debug_warn, error, info, utils, warn, Config, Error, Result};
use database::Database;
use itertools::Itertools;
use ruma::{
events::{push_rules::PushRulesEvent, room::member::MembershipState, GlobalAccountDataEventType},
@ -18,7 +17,7 @@ use ruma::{
EventId, OwnedRoomId, RoomId, UserId,
};
use crate::services;
use crate::Services;
/// The current schema version.
/// - If database is opened at greater version we reject with error. The
@ -28,13 +27,13 @@ use crate::services;
/// equal or lesser version. These are expected to be backward-compatible.
const DATABASE_VERSION: u64 = 13;
pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<()> {
pub(crate) async fn migrations(services: &Services) -> Result<()> {
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services().users.count()? > 0 {
let conduit_user = &services().globals.server_user;
if services.users.count()? > 0 {
let conduit_user = &services.globals.server_user;
if !services().users.exists(conduit_user)? {
if !services.users.exists(conduit_user)? {
error!("The {} server user does not exist, and the database is not new.", conduit_user);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the server name, please delete the old one first.",
@ -42,15 +41,18 @@ pub(crate) async fn migrations(db: &Arc<Database>, config: &Config) -> Result<()
}
}
if services().users.count()? > 0 {
migrate(db, config).await
if services.users.count()? > 0 {
migrate(services).await
} else {
fresh(db, config).await
fresh(services).await
}
}
async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> {
services()
async fn fresh(services: &Services) -> Result<()> {
let db = &services.db;
let config = &services.server.config;
services
.globals
.db
.bump_database_version(DATABASE_VERSION)?;
@ -70,97 +72,100 @@ async fn fresh(db: &Arc<Database>, config: &Config) -> Result<()> {
}
/// Apply any migrations
async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> {
if services().globals.db.database_version()? < 1 {
db_lt_1(db, config).await?;
async fn migrate(services: &Services) -> Result<()> {
let db = &services.db;
let config = &services.server.config;
if services.globals.db.database_version()? < 1 {
db_lt_1(services).await?;
}
if services().globals.db.database_version()? < 2 {
db_lt_2(db, config).await?;
if services.globals.db.database_version()? < 2 {
db_lt_2(services).await?;
}
if services().globals.db.database_version()? < 3 {
db_lt_3(db, config).await?;
if services.globals.db.database_version()? < 3 {
db_lt_3(services).await?;
}
if services().globals.db.database_version()? < 4 {
db_lt_4(db, config).await?;
if services.globals.db.database_version()? < 4 {
db_lt_4(services).await?;
}
if services().globals.db.database_version()? < 5 {
db_lt_5(db, config).await?;
if services.globals.db.database_version()? < 5 {
db_lt_5(services).await?;
}
if services().globals.db.database_version()? < 6 {
db_lt_6(db, config).await?;
if services.globals.db.database_version()? < 6 {
db_lt_6(services).await?;
}
if services().globals.db.database_version()? < 7 {
db_lt_7(db, config).await?;
if services.globals.db.database_version()? < 7 {
db_lt_7(services).await?;
}
if services().globals.db.database_version()? < 8 {
db_lt_8(db, config).await?;
if services.globals.db.database_version()? < 8 {
db_lt_8(services).await?;
}
if services().globals.db.database_version()? < 9 {
db_lt_9(db, config).await?;
if services.globals.db.database_version()? < 9 {
db_lt_9(services).await?;
}
if services().globals.db.database_version()? < 10 {
db_lt_10(db, config).await?;
if services.globals.db.database_version()? < 10 {
db_lt_10(services).await?;
}
if services().globals.db.database_version()? < 11 {
db_lt_11(db, config).await?;
if services.globals.db.database_version()? < 11 {
db_lt_11(services).await?;
}
if services().globals.db.database_version()? < 12 {
db_lt_12(db, config).await?;
if services.globals.db.database_version()? < 12 {
db_lt_12(services).await?;
}
// This migration can be reused as-is anytime the server-default rules are
// updated.
if services().globals.db.database_version()? < 13 {
db_lt_13(db, config).await?;
if services.globals.db.database_version()? < 13 {
db_lt_13(services).await?;
}
if db["global"].get(b"feat_sha256_media")?.is_none() {
migrate_sha256_media(db, config).await?;
migrate_sha256_media(services).await?;
} else if config.media_startup_check {
checkup_sha256_media(db, config).await?;
checkup_sha256_media(services).await?;
}
if db["global"]
.get(b"fix_bad_double_separator_in_state_cache")?
.is_none()
{
fix_bad_double_separator_in_state_cache(db, config).await?;
fix_bad_double_separator_in_state_cache(services).await?;
}
if db["global"]
.get(b"retroactively_fix_bad_data_from_roomuserid_joined")?
.is_none()
{
retroactively_fix_bad_data_from_roomuserid_joined(db, config).await?;
retroactively_fix_bad_data_from_roomuserid_joined(services).await?;
}
assert_eq!(
services().globals.db.database_version().unwrap(),
services.globals.db.database_version().unwrap(),
DATABASE_VERSION,
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
services().globals.db.database_version().unwrap(),
services.globals.db.database_version().unwrap(),
DATABASE_VERSION,
);
{
let patterns = services().globals.forbidden_usernames();
let patterns = services.globals.forbidden_usernames();
if !patterns.is_empty() {
for user_id in services()
for user_id in services
.users
.iter()
.filter_map(Result::ok)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true))
.filter(|user| !services.users.is_deactivated(user).unwrap_or(true))
.filter(|user| user.server_name() == config.server_name)
{
let matches = patterns.matches(user_id.localpart());
@ -179,11 +184,11 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> {
}
{
let patterns = services().globals.forbidden_alias_names();
let patterns = services.globals.forbidden_alias_names();
if !patterns.is_empty() {
for address in services().rooms.metadata.iter_ids() {
for address in services.rooms.metadata.iter_ids() {
let room_id = address?;
let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id);
let room_aliases = services.rooms.alias.local_aliases_for_room(&room_id);
for room_alias_result in room_aliases {
let room_alias = room_alias_result?;
let matches = patterns.matches(room_alias.alias());
@ -211,7 +216,9 @@ async fn migrate(db: &Arc<Database>, config: &Config) -> Result<()> {
Ok(())
}
async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_1(services: &Services) -> Result<()> {
let db = &services.db;
let roomserverids = &db["roomserverids"];
let serverroomids = &db["serverroomids"];
for (roomserverid, _) in roomserverids.iter() {
@ -228,12 +235,14 @@ async fn db_lt_1(db: &Arc<Database>, _config: &Config) -> Result<()> {
serverroomids.insert(&serverroomid, &[])?;
}
services().globals.db.bump_database_version(1)?;
services.globals.db.bump_database_version(1)?;
info!("Migration: 0 -> 1 finished");
Ok(())
}
async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_2(services: &Services) -> Result<()> {
let db = &services.db;
// We accidentally inserted hashed versions of "" into the db instead of just ""
let userid_password = &db["roomserverids"];
for (userid, password) in userid_password.iter() {
@ -245,12 +254,14 @@ async fn db_lt_2(db: &Arc<Database>, _config: &Config) -> Result<()> {
}
}
services().globals.db.bump_database_version(2)?;
services.globals.db.bump_database_version(2)?;
info!("Migration: 1 -> 2 finished");
Ok(())
}
async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_3(services: &Services) -> Result<()> {
let db = &services.db;
// Move media to filesystem
let mediaid_file = &db["mediaid_file"];
for (key, content) in mediaid_file.iter() {
@ -259,41 +270,45 @@ async fn db_lt_3(db: &Arc<Database>, _config: &Config) -> Result<()> {
}
#[allow(deprecated)]
let path = services().media.get_media_file(&key);
let path = services.media.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
mediaid_file.insert(&key, &[])?;
}
services().globals.db.bump_database_version(3)?;
services.globals.db.bump_database_version(3)?;
info!("Migration: 2 -> 3 finished");
Ok(())
}
async fn db_lt_4(_db: &Arc<Database>, config: &Config) -> Result<()> {
// Add federated users to services() as deactivated
for our_user in services().users.iter() {
async fn db_lt_4(services: &Services) -> Result<()> {
let config = &services.server.config;
// Add federated users to services as deactivated
for our_user in services.users.iter() {
let our_user = our_user?;
if services().users.is_deactivated(&our_user)? {
if services.users.is_deactivated(&our_user)? {
continue;
}
for room in services().rooms.state_cache.rooms_joined(&our_user) {
for user in services().rooms.state_cache.room_members(&room?) {
for room in services.rooms.state_cache.rooms_joined(&our_user) {
for user in services.rooms.state_cache.room_members(&room?) {
let user = user?;
if user.server_name() != config.server_name {
info!(?user, "Migration: creating user");
services().users.create(&user, None)?;
services.users.create(&user, None)?;
}
}
}
}
services().globals.db.bump_database_version(4)?;
services.globals.db.bump_database_version(4)?;
info!("Migration: 3 -> 4 finished");
Ok(())
}
async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_5(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade user data store
let roomuserdataid_accountdata = &db["roomuserdataid_accountdata"];
let roomusertype_roomuserdataid = &db["roomusertype_roomuserdataid"];
@ -312,26 +327,30 @@ async fn db_lt_5(db: &Arc<Database>, _config: &Config) -> Result<()> {
roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
}
services().globals.db.bump_database_version(5)?;
services.globals.db.bump_database_version(5)?;
info!("Migration: 4 -> 5 finished");
Ok(())
}
async fn db_lt_6(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_6(services: &Services) -> Result<()> {
let db = &services.db;
// Set room member count
let roomid_shortstatehash = &db["roomid_shortstatehash"];
for (roomid, _) in roomid_shortstatehash.iter() {
let string = utils::string_from_bytes(&roomid).unwrap();
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
services().rooms.state_cache.update_joined_count(room_id)?;
services.rooms.state_cache.update_joined_count(room_id)?;
}
services().globals.db.bump_database_version(6)?;
services.globals.db.bump_database_version(6)?;
info!("Migration: 5 -> 6 finished");
Ok(())
}
async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_7(services: &Services) -> Result<()> {
let db = &services.db;
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, u64> = HashMap::new();
let mut current_sstatehash: Option<u64> = None;
@ -347,7 +366,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
let states_parents = last_roomsstatehash.map_or_else(
|| Ok(Vec::new()),
|&last_roomsstatehash| {
services()
services
.rooms
.state_compressor
.load_shortstatehash_info(last_roomsstatehash)
@ -371,7 +390,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
(current_state, HashSet::new())
};
services().rooms.state_compressor.save_state_from_diff(
services.rooms.state_compressor.save_state_from_diff(
current_sstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
@ -380,7 +399,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
)?;
/*
let mut tmp = services().rooms.load_shortstatehash_info(&current_sstatehash)?;
let mut tmp = services.rooms.load_shortstatehash_info(&current_sstatehash)?;
let state = tmp.pop().unwrap();
println!(
"{}\t{}{:?}: {:?} + {:?} - {:?}",
@ -425,12 +444,7 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
let event_id = shorteventid_eventid.get(&seventid).unwrap().unwrap();
let string = utils::string_from_bytes(&event_id).unwrap();
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
let pdu = services()
.rooms
.timeline
.get_pdu(event_id)
.unwrap()
.unwrap();
let pdu = services.rooms.timeline.get_pdu(event_id).unwrap().unwrap();
if Some(&pdu.room_id) != current_room.as_ref() {
current_room = Some(pdu.room_id.clone());
@ -451,12 +465,14 @@ async fn db_lt_7(db: &Arc<Database>, _config: &Config) -> Result<()> {
)?;
}
services().globals.db.bump_database_version(7)?;
services.globals.db.bump_database_version(7)?;
info!("Migration: 6 -> 7 finished");
Ok(())
}
async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_8(services: &Services) -> Result<()> {
let db = &services.db;
let roomid_shortstatehash = &db["roomid_shortstatehash"];
let roomid_shortroomid = &db["roomid_shortroomid"];
let pduid_pdu = &db["pduid_pdu"];
@ -464,7 +480,7 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> {
// Generate short room ids for all rooms
for (room_id, _) in roomid_shortstatehash.iter() {
let shortroomid = services().globals.next_count()?.to_be_bytes();
let shortroomid = services.globals.next_count()?.to_be_bytes();
roomid_shortroomid.insert(&room_id, &shortroomid)?;
info!("Migration: 8");
}
@ -517,12 +533,14 @@ async fn db_lt_8(db: &Arc<Database>, _config: &Config) -> Result<()> {
eventid_pduid.insert_batch(batch2.iter().map(database::KeyVal::from))?;
services().globals.db.bump_database_version(8)?;
services.globals.db.bump_database_version(8)?;
info!("Migration: 7 -> 8 finished");
Ok(())
}
async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_9(services: &Services) -> Result<()> {
let db = &services.db;
let tokenids = &db["tokenids"];
let roomid_shortroomid = &db["roomid_shortroomid"];
@ -574,12 +592,14 @@ async fn db_lt_9(db: &Arc<Database>, _config: &Config) -> Result<()> {
tokenids.remove(&key)?;
}
services().globals.db.bump_database_version(9)?;
services.globals.db.bump_database_version(9)?;
info!("Migration: 8 -> 9 finished");
Ok(())
}
async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn db_lt_10(services: &Services) -> Result<()> {
let db = &services.db;
let statekey_shortstatekey = &db["statekey_shortstatekey"];
let shortstatekey_statekey = &db["shortstatekey_statekey"];
@ -589,28 +609,30 @@ async fn db_lt_10(db: &Arc<Database>, _config: &Config) -> Result<()> {
}
// Force E2EE device list updates so we can send them over federation
for user_id in services().users.iter().filter_map(Result::ok) {
services().users.mark_device_key_update(&user_id)?;
for user_id in services.users.iter().filter_map(Result::ok) {
services.users.mark_device_key_update(&user_id)?;
}
services().globals.db.bump_database_version(10)?;
services.globals.db.bump_database_version(10)?;
info!("Migration: 9 -> 10 finished");
Ok(())
}
#[allow(unreachable_code)]
async fn db_lt_11(_db: &Arc<Database>, _config: &Config) -> Result<()> {
todo!("Dropping a column to clear data is not implemented yet.");
async fn db_lt_11(services: &Services) -> Result<()> {
error!("Dropping a column to clear data is not implemented yet.");
//let userdevicesessionid_uiaarequest = &db["userdevicesessionid_uiaarequest"];
//userdevicesessionid_uiaarequest.clear()?;
services().globals.db.bump_database_version(11)?;
services.globals.db.bump_database_version(11)?;
info!("Migration: 10 -> 11 finished");
Ok(())
}
async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
for username in services().users.list_local_users()? {
async fn db_lt_12(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
@ -619,7 +641,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
},
};
let raw_rules_list = services()
let raw_rules_list = services
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
@ -664,7 +686,7 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
}
}
services().account_data.update(
services.account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
@ -672,13 +694,15 @@ async fn db_lt_12(_db: &Arc<Database>, config: &Config) -> Result<()> {
)?;
}
services().globals.db.bump_database_version(12)?;
services.globals.db.bump_database_version(12)?;
info!("Migration: 11 -> 12 finished");
Ok(())
}
async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
for username in services().users.list_local_users()? {
async fn db_lt_13(services: &Services) -> Result<()> {
let config = &services.server.config;
for username in services.users.list_local_users()? {
let user = match UserId::parse_with_server_name(username.clone(), &config.server_name) {
Ok(u) => u,
Err(e) => {
@ -687,7 +711,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
},
};
let raw_rules_list = services()
let raw_rules_list = services
.account_data
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap()
@ -701,7 +725,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
.global
.update_with_server_default(user_default_rules);
services().account_data.update(
services.account_data.update(
None,
&user,
GlobalAccountDataEventType::PushRules.to_string().into(),
@ -709,7 +733,7 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
)?;
}
services().globals.db.bump_database_version(13)?;
services.globals.db.bump_database_version(13)?;
info!("Migration: 12 -> 13 finished");
Ok(())
}
@ -717,15 +741,17 @@ async fn db_lt_13(_db: &Arc<Database>, config: &Config) -> Result<()> {
/// Migrates a media directory from legacy base64 file names to sha2 file names.
/// All errors are fatal. Upon success the database is keyed to not perform this
/// again.
async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn migrate_sha256_media(services: &Services) -> Result<()> {
let db = &services.db;
warn!("Migrating legacy base64 file names to sha256 file names");
let mediaid_file = &db["mediaid_file"];
// Move old media files to new names
let mut changes = Vec::<(PathBuf, PathBuf)>::new();
for (key, _) in mediaid_file.iter() {
let old = services().media.get_media_file_b64(&key);
let new = services().media.get_media_file_sha256(&key);
let old = services.media.get_media_file_b64(&key);
let new = services.media.get_media_file_sha256(&key);
debug!(?key, ?old, ?new, num = changes.len(), "change");
changes.push((old, new));
}
@ -739,8 +765,8 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()
// Apply fix from when sha256_media was backward-incompat and bumped the schema
// version from 13 to 14. For users satisfying these conditions we can go back.
if services().globals.db.database_version()? == 14 && DATABASE_VERSION == 13 {
services().globals.db.bump_database_version(13)?;
if services.globals.db.database_version()? == 14 && DATABASE_VERSION == 13 {
services.globals.db.bump_database_version(13)?;
}
db["global"].insert(b"feat_sha256_media", &[])?;
@ -752,14 +778,16 @@ async fn migrate_sha256_media(db: &Arc<Database>, _config: &Config) -> Result<()
/// - Going back and forth to non-sha256 legacy binaries (e.g. upstream).
/// - Deletion of artifacts in the media directory which will then fall out of
/// sync with the database.
async fn checkup_sha256_media(db: &Arc<Database>, config: &Config) -> Result<()> {
async fn checkup_sha256_media(services: &Services) -> Result<()> {
use crate::media::encode_key;
debug!("Checking integrity of media directory");
let db = &services.db;
let media = &services.media;
let config = &services.server.config;
let mediaid_file = &db["mediaid_file"];
let mediaid_user = &db["mediaid_user"];
let dbs = (mediaid_file, mediaid_user);
let media = &services().media;
let timer = Instant::now();
let dir = media.get_media_dir();
@ -791,6 +819,7 @@ async fn handle_media_check(
new_path: &OsStr, old_path: &OsStr,
) -> Result<()> {
use crate::media::encode_key;
let (mediaid_file, mediaid_user) = dbs;
let old_exists = files.contains(old_path);
@ -827,8 +856,10 @@ async fn handle_media_check(
Ok(())
}
async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn fix_bad_double_separator_in_state_cache(services: &Services) -> Result<()> {
warn!("Fixing bad double separator in state_cache roomuserid_joined");
let db = &services.db;
let roomuserid_joined = &db["roomuserid_joined"];
let _cork = db.cork_and_sync();
@ -864,11 +895,13 @@ async fn fix_bad_double_separator_in_state_cache(db: &Arc<Database>, _config: &C
Ok(())
}
async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _config: &Config) -> Result<()> {
async fn retroactively_fix_bad_data_from_roomuserid_joined(services: &Services) -> Result<()> {
warn!("Retroactively fixing bad data from broken roomuserid_joined");
let db = &services.db;
let _cork = db.cork_and_sync();
let room_ids = services()
let room_ids = services
.rooms
.metadata
.iter_ids()
@ -878,7 +911,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for room_id in room_ids.clone() {
debug_info!("Fixing room {room_id}");
let users_in_room = services()
let users_in_room = services
.rooms
.state_cache
.room_members(&room_id)
@ -888,7 +921,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
let joined_members = users_in_room
.iter()
.filter(|user_id| {
services()
services
.rooms
.state_accessor
.get_member(&room_id, user_id)
@ -900,7 +933,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
let non_joined_members = users_in_room
.iter()
.filter(|user_id| {
services()
services
.rooms
.state_accessor
.get_member(&room_id, user_id)
@ -913,7 +946,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for user_id in joined_members {
debug_info!("User is joined, marking as joined");
services()
services
.rooms
.state_cache
.mark_as_joined(user_id, &room_id)?;
@ -921,10 +954,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
for user_id in non_joined_members {
debug_info!("User is left or banned, marking as left");
services()
.rooms
.state_cache
.mark_as_left(user_id, &room_id)?;
services.rooms.state_cache.mark_as_left(user_id, &room_id)?;
}
}
@ -933,7 +963,7 @@ async fn retroactively_fix_bad_data_from_roomuserid_joined(db: &Arc<Database>, _
"Updating joined count for room {room_id} to fix servers in room after correcting membership states"
);
services().rooms.state_cache.update_joined_count(&room_id)?;
services.rooms.state_cache.update_joined_count(&room_id)?;
}
db.db.cleanup()?;

View file

@ -1,5 +1,4 @@
mod data;
mod emerg_access;
pub(super) mod migrations;
use std::{
@ -9,7 +8,6 @@ use std::{
time::Instant,
};
use async_trait::async_trait;
use conduit::{error, trace, Config, Result};
use data::Data;
use ipaddress::IPAddress;
@ -43,11 +41,10 @@ pub struct Service {
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let db = Data::new(&args);
let config = &args.server.config;
let db = Data::new(args.db);
let keypair = db.load_keypair();
let keypair = match keypair {
@ -104,19 +101,13 @@ impl crate::Service for Service {
.supported_room_versions()
.contains(&s.config.default_room_version)
{
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = crate::config::default_default_room_version();
error!(config=?s.config.default_room_version, fallback=?conduit::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
s.config.default_room_version = conduit::config::default_default_room_version();
};
Ok(Arc::new(s))
}
async fn worker(self: Arc<Self>) -> Result<()> {
emerg_access::init_emergency_access();
Ok(())
}
fn memory_usage(&self, out: &mut dyn Write) -> Result<()> {
let bad_event_ratelimiter = self
.bad_event_ratelimiter

View file

@ -1,7 +1,7 @@
use std::{collections::BTreeMap, sync::Arc};
use conduit::{utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{
api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
@ -11,25 +11,34 @@ use ruma::{
OwnedRoomId, RoomId, UserId,
};
use crate::services;
use crate::{globals, Dep};
pub(super) struct Data {
backupid_algorithm: Arc<Map>,
backupid_etag: Arc<Map>,
backupkeyid_backup: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
backupid_algorithm: db["backupid_algorithm"].clone(),
backupid_etag: db["backupid_etag"].clone(),
backupkeyid_backup: db["backupkeyid_backup"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
pub(super) fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();
let version = self.services.globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
@ -40,7 +49,7 @@ impl Data {
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version)
}
@ -75,7 +84,7 @@ impl Data {
self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned())
}
@ -152,7 +161,7 @@ impl Data {
}
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())?;
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());

View file

@ -17,7 +17,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -8,13 +8,13 @@ use tokio::{
time::sleep,
};
use crate::{service::Service, Services};
use crate::{service, service::Service, Services};
pub(crate) struct Manager {
manager: Mutex<Option<JoinHandle<Result<()>>>>,
workers: Mutex<Workers>,
server: Arc<Server>,
services: &'static Services,
service: Arc<service::Map>,
}
type Workers = JoinSet<WorkerResult>;
@ -29,7 +29,7 @@ impl Manager {
manager: Mutex::new(None),
workers: Mutex::new(JoinSet::new()),
server: services.server.clone(),
services: crate::services(),
service: services.service.clone(),
})
}
@ -53,9 +53,19 @@ impl Manager {
.spawn(async move { self_.worker().await }),
);
// we can't hold the lock during the iteration with start_worker so the values
// are snapshotted here
let services: Vec<Arc<dyn Service>> = self
.service
.read()
.expect("locked for reading")
.values()
.map(|v| v.0.clone())
.collect();
debug!("Starting service workers...");
for (service, ..) in self.services.service.values() {
self.start_worker(&mut workers, service).await?;
for service in services {
self.start_worker(&mut workers, &service).await?;
}
Ok(())

View file

@ -1,10 +1,10 @@
use std::sync::Arc;
use conduit::{debug, debug_info, Error, Result};
use conduit::{debug, debug_info, utils::string_from_bytes, Error, Result};
use database::{Database, Map};
use ruma::api::client::error::ErrorKind;
use crate::{media::UrlPreviewData, utils::string_from_bytes};
use crate::media::UrlPreviewData;
pub(crate) struct Data {
mediaid_file: Arc<Map>,

View file

@ -15,7 +15,7 @@ use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufReader},
};
use crate::services;
use crate::{globals, Dep};
#[derive(Debug)]
pub struct FileMeta {
@ -41,16 +41,24 @@ pub struct UrlPreviewData {
}
pub struct Service {
server: Arc<Server>,
services: Services,
pub(crate) db: Data,
pub url_preview_mutex: MutexMap<String, ()>,
}
struct Services {
server: Arc<Server>,
globals: Dep<globals::Service>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
server: args.server.clone(),
services: Services {
server: args.server.clone(),
globals: args.depend::<globals::Service>("globals"),
},
db: Data::new(args.db),
url_preview_mutex: MutexMap::new(),
}))
@ -164,7 +172,7 @@ impl Service {
debug!("Parsed MXC key to URL: {mxc_s}");
let mxc = OwnedMxcUri::from(mxc_s);
if mxc.server_name() == Ok(services().globals.server_name()) {
if mxc.server_name() == Ok(self.services.globals.server_name()) {
debug!("Ignoring local media MXC: {mxc}");
// ignore our own MXC URLs as this would be local media.
continue;
@ -246,7 +254,7 @@ impl Service {
let legacy_rm = fs::remove_file(&legacy);
let (file_rm, legacy_rm) = tokio::join!(file_rm, legacy_rm);
if let Err(e) = legacy_rm {
if self.server.config.media_compat_file_link {
if self.services.server.config.media_compat_file_link {
debug_error!(?key, ?legacy, "Failed to remove legacy media symlink: {e}");
}
}
@ -259,7 +267,7 @@ impl Service {
debug!(?key, ?path, "Creating media file");
let file = fs::File::create(&path).await?;
if self.server.config.media_compat_file_link {
if self.services.server.config.media_compat_file_link {
let legacy = self.get_media_file_b64(key);
if let Err(e) = fs::symlink(&path, &legacy).await {
debug_error!(
@ -304,7 +312,7 @@ impl Service {
#[must_use]
pub fn get_media_dir(&self) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.server.config.database_path.clone());
r.push(self.services.server.config.database_path.clone());
r.push("media");
r
}

View file

@ -1,3 +1,4 @@
#![recursion_limit = "160"]
#![allow(refining_impl_trait)]
mod manager;
@ -8,6 +9,7 @@ pub mod account_data;
pub mod admin;
pub mod appservice;
pub mod client;
pub mod emergency;
pub mod globals;
pub mod key_backups;
pub mod media;
@ -26,8 +28,8 @@ extern crate conduit_database as database;
use std::sync::{Arc, RwLock};
pub(crate) use conduit::{config, debug_error, debug_warn, utils, Error, Result, Server};
pub use conduit::{pdu, PduBuilder, PduCount, PduEvent};
use conduit::{Result, Server};
use database::Database;
pub(crate) use service::{Args, Dep, Service};

View file

@ -1,21 +1,32 @@
use std::sync::Arc;
use conduit::{debug_warn, utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, UInt, UserId};
use crate::{presence::Presence, services};
use crate::{globals, presence::Presence, users, Dep};
pub struct Data {
presenceid_presence: Arc<Map>,
userid_presenceid: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
presenceid_presence: db["presenceid_presence"].clone(),
userid_presenceid: db["userid_presenceid"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}
}
@ -28,7 +39,10 @@ impl Data {
self.presenceid_presence
.get(&key)?
.map(|presence_bytes| -> Result<(u64, PresenceEvent)> {
Ok((count, Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id)?))
Ok((
count,
Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id, &self.services.users)?,
))
})
.transpose()
} else {
@ -80,7 +94,7 @@ impl Data {
last_active_ts,
status_msg,
);
let count = services().globals.next_count()?;
let count = self.services.globals.next_count()?;
let key = presenceid_key(count, user_id);
self.presenceid_presence

View file

@ -3,8 +3,7 @@ mod data;
use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use conduit::{checked, debug, error, utils, Error, Result};
use data::Data;
use conduit::{checked, debug, error, utils, Error, Result, Server};
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
events::presence::{PresenceEvent, PresenceEventContent},
@ -14,7 +13,8 @@ use ruma::{
use serde::{Deserialize, Serialize};
use tokio::{sync::Mutex, time::sleep};
use crate::{services, user_is_local};
use self::data::Data;
use crate::{user_is_local, users, Dep};
/// Represents data required to be kept in order to implement the presence
/// specification.
@ -37,11 +37,6 @@ impl Presence {
}
}
pub fn from_json_bytes_to_event(bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Self::from_json_bytes(bytes)?;
presence.to_presence_event(user_id)
}
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
}
@ -51,7 +46,7 @@ impl Presence {
}
/// Creates a PresenceEvent from available data.
pub fn to_presence_event(&self, user_id: &UserId) -> Result<PresenceEvent> {
pub fn to_presence_event(&self, user_id: &UserId, users: &users::Service) -> Result<PresenceEvent> {
let now = utils::millis_since_unix_epoch();
let last_active_ago = if self.currently_active {
None
@ -66,14 +61,15 @@ impl Presence {
status_msg: self.status_msg.clone(),
currently_active: Some(self.currently_active),
last_active_ago,
displayname: services().users.displayname(user_id)?,
avatar_url: services().users.avatar_url(user_id)?,
displayname: users.displayname(user_id)?,
avatar_url: users.avatar_url(user_id)?,
},
})
}
}
pub struct Service {
services: Services,
pub db: Data,
pub timer_sender: loole::Sender<(OwnedUserId, Duration)>,
timer_receiver: Mutex<loole::Receiver<(OwnedUserId, Duration)>>,
@ -82,6 +78,11 @@ pub struct Service {
offline_timeout: u64,
}
struct Services {
server: Arc<Server>,
users: Dep<users::Service>,
}
#[async_trait]
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -90,7 +91,11 @@ impl crate::Service for Service {
let offline_timeout_s = config.presence_offline_timeout_s;
let (timer_sender, timer_receiver) = loole::unbounded();
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
server: args.server.clone(),
users: args.depend::<users::Service>("users"),
},
db: Data::new(&args),
timer_sender,
timer_receiver: Mutex::new(timer_receiver),
timeout_remote_users: config.presence_timeout_remote_users,
@ -182,8 +187,8 @@ impl Service {
if self.timeout_remote_users || user_is_local(user_id) {
let timeout = match presence_state {
PresenceState::Online => services().globals.config.presence_idle_timeout_s,
_ => services().globals.config.presence_offline_timeout_s,
PresenceState::Online => self.services.server.config.presence_idle_timeout_s,
_ => self.services.server.config.presence_offline_timeout_s,
};
self.timer_sender
@ -210,6 +215,11 @@ impl Service {
self.db.presence_since(since)
}
pub fn from_json_bytes_to_event(&self, bytes: &[u8], user_id: &UserId) -> Result<PresenceEvent> {
let presence = Presence::from_json_bytes(bytes)?;
presence.to_presence_event(user_id, &self.services.users)
}
fn process_presence_timer(&self, user_id: &OwnedUserId) -> Result<()> {
let mut presence_state = PresenceState::Offline;
let mut last_active_ago = None;

View file

@ -3,8 +3,7 @@ mod data;
use std::{fmt::Debug, mem, sync::Arc};
use bytes::BytesMut;
use conduit::{debug_info, info, trace, warn, Error, Result};
use data::Data;
use conduit::{debug_info, info, trace, utils::string_from_bytes, warn, Error, PduEvent, Result};
use ipaddress::IPAddress;
use ruma::{
api::{
@ -23,15 +22,32 @@ use ruma::{
uint, RoomId, UInt, UserId,
};
use crate::{services, PduEvent};
use self::data::Data;
use crate::{client, globals, rooms, users, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
globals: Dep<globals::Service>,
client: Dep<client::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
client: args.depend::<client::Service>("client"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db),
}))
}
@ -62,7 +78,7 @@ impl Service {
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_0];
let dest = dest.replace(services().globals.notification_push_path(), "");
let dest = dest.replace(self.services.globals.notification_push_path(), "");
trace!("Push gateway destination: {dest}");
let http_request = request
@ -78,13 +94,13 @@ impl Service {
if let Some(url_host) = reqwest_request.url().host_str() {
trace!("Checking request URL for IP");
if let Ok(ip) = IPAddress::parse(url_host) {
if !services().globals.valid_cidr_range(&ip) {
if !self.services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
}
let response = services().client.pusher.execute(reqwest_request).await;
let response = self.services.client.pusher.execute(reqwest_request).await;
match response {
Ok(mut response) => {
@ -93,7 +109,7 @@ impl Service {
trace!("Checking response destination's IP");
if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !services().globals.valid_cidr_range(&ip) {
if !self.services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
@ -114,7 +130,7 @@ impl Service {
if !status.is_success() {
info!("Push gateway {dest} returned unsuccessful HTTP response ({status})");
debug_info!("Push gateway response body: {:?}", crate::utils::string_from_bytes(&body));
debug_info!("Push gateway response body: {:?}", string_from_bytes(&body));
return Err(Error::BadServerResponse("Push gateway returned unsuccessful response"));
}
@ -143,8 +159,8 @@ impl Service {
let mut notify = None;
let mut tweaks = Vec::new();
let power_levels: RoomPowerLevelsEventContent = services()
.rooms
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
@ -195,15 +211,15 @@ impl Service {
let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(),
member_count: UInt::try_from(
services()
.rooms
self.services
.state_cache
.room_joined_count(room_id)?
.unwrap_or(1),
)
.unwrap_or_else(|_| uint!(0)),
user_id: user.to_owned(),
user_display_name: services()
user_display_name: self
.services
.users
.displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()),
@ -263,9 +279,9 @@ impl Service {
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
}
notifi.sender_display_name = services().users.displayname(&event.sender)?;
notifi.sender_display_name = self.services.users.displayname(&event.sender)?;
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
notifi.room_name = self.services.state_accessor.get_name(&event.room_id)?;
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
.await?;

View file

@ -9,12 +9,9 @@ use hickory_resolver::{error::ResolveError, lookup::SrvLookup};
use ipaddress::IPAddress;
use ruma::ServerName;
use crate::{
resolver::{
cache::{CachedDest, CachedOverride},
fed::{add_port_to_hostname, get_ip_with_port, FedDest},
},
services,
use crate::resolver::{
cache::{CachedDest, CachedOverride},
fed::{add_port_to_hostname, get_ip_with_port, FedDest},
};
#[derive(Clone, Debug)]
@ -40,7 +37,7 @@ impl super::Service {
result
} else {
cached = false;
validate_dest(server_name)?;
self.validate_dest(server_name)?;
self.resolve_actual_dest(server_name, true).await?
};
@ -188,7 +185,8 @@ impl super::Service {
self.query_and_cache_override(dest, dest, 8448).await?;
}
let response = services()
let response = self
.services
.client
.well_known
.get(&format!("https://{dest}/.well-known/matrix/server"))
@ -245,19 +243,14 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "ip")]
async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
match services()
.resolver
.raw()
.lookup_ip(hostname.to_owned())
.await
{
Err(e) => handle_resolve_error(&e),
match self.raw().lookup_ip(hostname.to_owned()).await {
Err(e) => Self::handle_resolve_error(&e),
Ok(override_ip) => {
if hostname != overname {
debug_info!("{overname:?} overriden by {hostname:?}");
}
services().resolver.set_cached_override(
self.set_cached_override(
overname.to_owned(),
CachedOverride {
ips: override_ip.iter().collect(),
@ -295,62 +288,62 @@ impl super::Service {
for hostname in hostnames {
match lookup_srv(self.raw(), &hostname).await {
Ok(result) => return Ok(handle_successful_srv(&result)),
Err(e) => handle_resolve_error(&e)?,
Err(e) => Self::handle_resolve_error(&e)?,
}
}
Ok(None)
}
}
#[allow(clippy::single_match_else)]
fn handle_resolve_error(e: &ResolveError) -> Result<()> {
use hickory_resolver::error::ResolveErrorKind;
#[allow(clippy::single_match_else)]
fn handle_resolve_error(e: &ResolveError) -> Result<()> {
use hickory_resolver::error::ResolveErrorKind;
match *e.kind() {
ResolveErrorKind::NoRecordsFound {
..
} => {
// Raise to debug_warn if we can find out the result wasn't from cache
debug!("{e}");
Ok(())
},
_ => Err!(error!("DNS {e}")),
}
}
fn validate_dest(dest: &ServerName) -> Result<()> {
if dest == services().globals.server_name() {
return Err!("Won't send federation request to ourselves");
match *e.kind() {
ResolveErrorKind::NoRecordsFound {
..
} => {
// Raise to debug_warn if we can find out the result wasn't from cache
debug!("{e}");
Ok(())
},
_ => Err!(error!("DNS {e}")),
}
}
if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) {
validate_dest_ip_literal(dest)?;
fn validate_dest(&self, dest: &ServerName) -> Result<()> {
if dest == self.services.server.config.server_name {
return Err!("Won't send federation request to ourselves");
}
if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) {
self.validate_dest_ip_literal(dest)?;
}
Ok(())
}
Ok(())
}
fn validate_dest_ip_literal(&self, dest: &ServerName) -> Result<()> {
trace!("Destination is an IP literal, checking against IP range denylist.",);
debug_assert!(
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
"Destination is not an IP literal."
);
let ip = IPAddress::parse(dest.host()).map_err(|e| {
debug_error!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> {
trace!("Destination is an IP literal, checking against IP range denylist.",);
debug_assert!(
dest.is_ip_literal() || !IPAddress::is_valid(dest.host()),
"Destination is not an IP literal."
);
let ip = IPAddress::parse(dest.host()).map_err(|e| {
debug_error!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
self.validate_ip(&ip)?;
validate_ip(&ip)?;
Ok(())
}
pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> {
if !services().globals.valid_cidr_range(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
Ok(())
}
Ok(())
pub(crate) fn validate_ip(&self, ip: &IPAddress) -> Result<()> {
if !self.services.globals.valid_cidr_range(ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
Ok(())
}
}

View file

@ -5,11 +5,10 @@ use std::{
time::SystemTime,
};
use conduit::trace;
use conduit::{trace, utils::rand};
use ruma::{OwnedServerName, ServerName};
use super::fed::FedDest;
use crate::utils::rand;
pub struct Cache {
pub destinations: RwLock<WellKnownMap>, // actual_destination, host

View file

@ -6,14 +6,22 @@ mod tests;
use std::{fmt::Write, sync::Arc};
use conduit::Result;
use conduit::{Result, Server};
use hickory_resolver::TokioAsyncResolver;
use self::{cache::Cache, dns::Resolver};
use crate::{client, globals, Dep};
pub struct Service {
pub cache: Arc<Cache>,
pub resolver: Arc<Resolver>,
services: Services,
}
struct Services {
server: Arc<Server>,
client: Dep<client::Service>,
globals: Dep<globals::Service>,
}
impl crate::Service for Service {
@ -23,6 +31,11 @@ impl crate::Service for Service {
Ok(Arc::new(Self {
cache: cache.clone(),
resolver: Resolver::build(args.server, cache)?,
services: Services {
server: args.server.clone(),
client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
},
}))
}

View file

@ -1,23 +1,32 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
use crate::services;
use crate::{globals, Dep};
pub(super) struct Data {
alias_userid: Arc<Map>,
alias_roomid: Arc<Map>,
aliasid_alias: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
alias_userid: db["alias_userid"].clone(),
alias_roomid: db["alias_roomid"].clone(),
aliasid_alias: db["aliasid_alias"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
@ -31,7 +40,7 @@ impl Data {
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())

View file

@ -4,9 +4,8 @@ mod remote;
use std::sync::Arc;
use conduit::{err, Error, Result};
use data::Data;
use ruma::{
api::{appservice, client::error::ErrorKind},
api::client::error::ErrorKind,
events::{
room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
StateEventType,
@ -14,16 +13,33 @@ use ruma::{
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId,
};
use crate::{appservice::RegistrationInfo, server_is_ours, services};
use self::data::Data;
use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, server_is_ours, Dep};
pub struct Service {
db: Data,
services: Services,
}
struct Services {
admin: Dep<admin::Service>,
appservice: Dep<appservice::Service>,
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
services: Services {
admin: args.depend::<admin::Service>("admin"),
appservice: args.depend::<appservice::Service>("appservice"),
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
}))
}
@ -33,7 +49,7 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == services().globals.admin_alias && user_id != services().globals.server_user {
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
Err(Error::BadRequest(
ErrorKind::forbidden(),
"Only the server user can set this alias",
@ -72,10 +88,10 @@ impl Service {
if !server_is_ours(room_alias.server_name())
&& (!servers
.as_ref()
.is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned()))
.is_some_and(|servers| servers.contains(&self.services.globals.server_name().to_owned()))
|| servers.as_ref().is_none())
{
return remote::resolve(room_alias, servers).await;
return self.remote_resolve(room_alias, servers).await;
}
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? {
@ -111,7 +127,7 @@ impl Service {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found."));
};
let server_user = &services().globals.server_user;
let server_user = &self.services.globals.server_user;
// The creator of an alias can remove it
if self
@ -119,7 +135,7 @@ impl Service {
.who_created_alias(alias)?
.is_some_and(|user| user == user_id)
// Server admins can remove any local alias
|| services().admin.user_is_admin(user_id).await?
|| self.services.admin.user_is_admin(user_id).await?
// Always allow the server service account to remove the alias, since there may not be an admin room
|| server_user == user_id
{
@ -127,8 +143,7 @@ impl Service {
// Checking whether the user is able to change canonical aliases of the
// room
} else if let Some(event) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
{
@ -140,8 +155,7 @@ impl Service {
// If there is no power levels event, only the room creator can change
// canonical aliases
} else if let Some(event) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
{
@ -152,14 +166,16 @@ impl Service {
}
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
for appservice in services().appservice.read().await.values() {
use ruma::api::appservice::query::query_room_alias;
for appservice in self.services.appservice.read().await.values() {
if appservice.aliases.is_match(room_alias.as_str())
&& matches!(
services()
self.services
.sending
.send_appservice_request(
appservice.registration.clone(),
appservice::query::query_room_alias::v1::Request {
query_room_alias::v1::Request {
room_alias: room_alias.to_owned(),
},
)
@ -167,10 +183,7 @@ impl Service {
Ok(Some(_opt_result))
) {
return Ok(Some(
services()
.rooms
.alias
.resolve_local_alias(room_alias)?
self.resolve_local_alias(room_alias)?
.ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?,
));
}
@ -178,20 +191,27 @@ impl Service {
Ok(None)
}
}
pub async fn appservice_checks(room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>) -> Result<()> {
if !server_is_ours(room_alias.server_name()) {
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
}
if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
pub async fn appservice_checks(
&self, room_alias: &RoomAliasId, appservice_info: &Option<RegistrationInfo>,
) -> Result<()> {
if !server_is_ours(room_alias.server_name()) {
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
}
} else if services().appservice.is_exclusive_alias(room_alias).await {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
}
Ok(())
if let Some(ref info) = appservice_info {
if !info.aliases.is_match(room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
}
} else if self
.services
.appservice
.is_exclusive_alias(room_alias)
.await
{
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias reserved by appservice."));
}
Ok(())
}
}

View file

@ -1,71 +1,75 @@
use conduit::{debug, debug_info, debug_warn, Error, Result};
use conduit::{debug, debug_warn, Error, Result};
use ruma::{
api::{client::error::ErrorKind, federation},
OwnedRoomId, OwnedServerName, RoomAliasId,
};
use crate::services;
impl super::Service {
pub(super) async fn remote_resolve(
&self, room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> {
debug!(?room_alias, ?servers, "resolve");
pub(super) async fn resolve(
room_alias: &RoomAliasId, servers: Option<&Vec<OwnedServerName>>,
) -> Result<(OwnedRoomId, Option<Vec<OwnedServerName>>)> {
debug!(?room_alias, ?servers, "resolve");
let mut response = self
.services
.sending
.send_federation_request(
room_alias.server_name(),
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
let mut response = services()
.sending
.send_federation_request(
room_alias.server_name(),
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("room alias server_name get_alias_helper response: {response:?}");
debug!("room alias server_name get_alias_helper response: {response:?}");
if let Err(ref e) = response {
debug_warn!(
"Server {} of the original room alias failed to assist in resolving room alias: {e}",
room_alias.server_name(),
);
}
if let Err(ref e) = response {
debug_info!(
"Server {} of the original room alias failed to assist in resolving room alias: {e}",
room_alias.server_name()
);
}
if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() {
if let Some(servers) = servers {
for server in servers {
response = self
.services
.sending
.send_federation_request(
server,
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("Got response from server {server} for room aliases: {response:?}");
if response.as_ref().is_ok_and(|resp| resp.servers.is_empty()) || response.as_ref().is_err() {
if let Some(servers) = servers {
for server in servers {
response = services()
.sending
.send_federation_request(
server,
federation::query::get_room_information::v1::Request {
room_alias: room_alias.to_owned(),
},
)
.await;
debug!("Got response from server {server} for room aliases: {response:?}");
if let Ok(ref response) = response {
if !response.servers.is_empty() {
break;
if let Ok(ref response) = response {
if !response.servers.is_empty() {
break;
}
debug_warn!(
"Server {server} responded with room aliases, but was empty? Response: {response:?}"
);
}
debug_warn!("Server {server} responded with room aliases, but was empty? Response: {response:?}");
}
}
}
if let Ok(response) = response {
let room_id = response.room_id;
let mut pre_servers = response.servers;
// since the room alis server responded, insert it into the list
pre_servers.push(room_alias.server_name().into());
return Ok((room_id, Some(pre_servers)));
}
Err(Error::BadRequest(
ErrorKind::NotFound,
"No servers could assist in resolving the room alias",
))
}
if let Ok(response) = response {
let room_id = response.room_id;
let mut pre_servers = response.servers;
// since the room alis server responded, insert it into the list
pre_servers.push(room_alias.server_name().into());
return Ok((room_id, Some(pre_servers)));
}
Err(Error::BadRequest(
ErrorKind::NotFound,
"No servers could assist in resolving the room alias",
))
}

View file

@ -3,8 +3,8 @@ use std::{
sync::{Arc, Mutex},
};
use conduit::{utils, utils::math::usize_from_f64, Result, Server};
use database::{Database, Map};
use conduit::{utils, utils::math::usize_from_f64, Result};
use database::Map;
use lru_cache::LruCache;
pub(super) struct Data {
@ -13,8 +13,9 @@ pub(super) struct Data {
}
impl Data {
pub(super) fn new(server: &Arc<Server>, db: &Arc<Database>) -> Self {
let config = &server.config;
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
let config = &args.server.config;
let cache_size = f64::from(config.auth_chain_cache_capacity);
let cache_size = usize_from_f64(cache_size * config.cache_capacity_modifier).expect("valid cache size");
Self {

View file

@ -6,19 +6,29 @@ use std::{
};
use conduit::{debug, error, trace, validated, warn, Err, Result};
use data::Data;
use ruma::{EventId, RoomId};
use crate::services;
use self::data::Data;
use crate::{rooms, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
short: Dep<rooms::short::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.server, args.db),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
}))
}
@ -27,7 +37,7 @@ impl crate::Service for Service {
impl Service {
pub async fn event_ids_iter<'a>(
&self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
&'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
for starting_event in &starting_events_ {
@ -38,7 +48,7 @@ impl Service {
.get_auth_chain(room_id, &starting_events)
.await?
.into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
.filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok()))
}
#[tracing::instrument(skip_all, name = "auth_chain")]
@ -48,8 +58,8 @@ impl Service {
let started = std::time::Instant::now();
let mut buckets = [BUCKET; NUM_BUCKETS];
for (i, &short) in services()
.rooms
for (i, &short) in self
.services
.short
.multi_get_or_create_shorteventid(starting_events)?
.iter()
@ -140,7 +150,7 @@ impl Service {
while let Some(event_id) = todo.pop() {
trace!(?event_id, "processing auth event");
match services().rooms.timeline.get_pdu(&event_id) {
match self.services.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(
@ -150,10 +160,7 @@ impl Service {
)));
}
for auth_event in &pdu.auth_events {
let sauthevent = services()
.rooms
.short
.get_or_create_shorteventid(auth_event)?;
let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?;
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");

View file

@ -2,10 +2,10 @@ mod data;
use std::sync::Arc;
use data::Data;
use conduit::Result;
use ruma::{OwnedRoomId, RoomId};
use crate::Result;
use self::data::Data;
pub struct Service {
db: Data,

View file

@ -10,12 +10,11 @@ use std::{
};
use conduit::{
debug, debug_error, debug_info, err, error, info, trace,
debug, debug_error, debug_info, err, error, info, pdu, trace,
utils::{math::continue_exponential_backoff_secs, MutexMap},
warn, Error, Result,
warn, Error, PduEvent, Result,
};
use futures_util::Future;
pub use parse_incoming_pdu::parse_incoming_pdu;
use ruma::{
api::{
client::error::ErrorKind,
@ -36,13 +35,28 @@ use ruma::{
use tokio::sync::RwLock;
use super::state_compressor::CompressedStateEvent;
use crate::{pdu, services, PduEvent};
use crate::{globals, rooms, sending, Dep};
pub struct Service {
services: Services,
pub federation_handletime: StdRwLock<HandleTimeMap>,
pub mutex_federation: RoomMutexMap,
}
struct Services {
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
auth_chain: Dep<rooms::auth_chain::Service>,
metadata: Dep<rooms::metadata::Service>,
outlier: Dep<rooms::outlier::Service>,
pdu_metadata: Dep<rooms::pdu_metadata::Service>,
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
type HandleTimeMap = HashMap<OwnedRoomId, (OwnedEventId, Instant)>;
@ -55,8 +69,21 @@ type AsyncRecursiveCanonicalJsonResult<'a> =
AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>>;
impl crate::Service for Service {
fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
auth_chain: args.depend::<rooms::auth_chain::Service>("rooms::auth_chain"),
metadata: args.depend::<rooms::metadata::Service>("rooms::metadata"),
outlier: args.depend::<rooms::outlier::Service>("rooms::outlier"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
federation_handletime: HandleTimeMap::new().into(),
mutex_federation: RoomMutexMap::new(),
}))
@ -114,17 +141,17 @@ impl Service {
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? {
if let Some(pdu_id) = self.services.timeline.get_pdu_id(event_id)? {
return Ok(Some(pdu_id.to_vec()));
}
// 1.1 Check the server is in the room
if !services().rooms.metadata.exists(room_id)? {
if !self.services.metadata.exists(room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"));
}
// 1.2 Check if the room is disabled
if services().rooms.metadata.is_disabled(room_id)? {
if self.services.metadata.is_disabled(room_id)? {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Federation of this room is currently disabled on this server.",
@ -147,8 +174,8 @@ impl Service {
self.acl_check(sender.server_name(), room_id)?;
// Fetch create event
let create_event = services()
.rooms
let create_event = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Failed to find create event in db."))?;
@ -156,8 +183,8 @@ impl Service {
// Procure the room version
let room_version_id = Self::get_room_version_id(&create_event)?;
let first_pdu_in_room = services()
.rooms
let first_pdu_in_room = self
.services
.timeline
.first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
@ -208,7 +235,8 @@ impl Service {
Ok(()) => continue,
Err(e) => {
warn!("Prev event {} failed: {}", prev_id, e);
match services()
match self
.services
.globals
.bad_event_ratelimiter
.write()
@ -258,7 +286,7 @@ impl Service {
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
) -> Result<()> {
// Check for disabled again because it might have changed
if services().rooms.metadata.is_disabled(room_id)? {
if self.services.metadata.is_disabled(room_id)? {
debug!(
"Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \
event ID {event_id}"
@ -269,7 +297,8 @@ impl Service {
));
}
if let Some((time, tries)) = services()
if let Some((time, tries)) = self
.services
.globals
.bad_event_ratelimiter
.read()
@ -349,7 +378,7 @@ impl Service {
};
// Skip the PDU if it is redacted and we already have it as an outlier event
if services().rooms.timeline.get_pdu_json(event_id)?.is_some() {
if self.services.timeline.get_pdu_json(event_id)?.is_some() {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Event was redacted and we already knew about it",
@ -401,7 +430,7 @@ impl Service {
// Build map of auth events
let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len());
for id in &incoming_pdu.auth_events {
let Some(auth_event) = services().rooms.timeline.get_pdu(id)? else {
let Some(auth_event) = self.services.timeline.get_pdu(id)? else {
warn!("Could not find auth event {}", id);
continue;
};
@ -454,8 +483,7 @@ impl Service {
trace!("Validation successful.");
// 7. Persist the event as an outlier.
services()
.rooms
self.services
.outlier
.add_pdu_outlier(&incoming_pdu.event_id, &val)?;
@ -470,12 +498,12 @@ impl Service {
origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) {
if let Ok(Some(pduid)) = self.services.timeline.get_pdu_id(&incoming_pdu.event_id) {
return Ok(Some(pduid.to_vec()));
}
if services()
.rooms
if self
.services
.pdu_metadata
.is_event_soft_failed(&incoming_pdu.event_id)?
{
@ -521,14 +549,13 @@ impl Service {
&incoming_pdu,
None::<PduEvent>, // TODO: third party invite
|k, s| {
services()
.rooms
self.services
.short
.get_shortstatekey(&k.to_string().into(), s)
.ok()
.flatten()
.and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
.and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten())
.and_then(|event_id| self.services.timeline.get_pdu(event_id).ok().flatten())
},
)
.map_err(|_e| Error::BadRequest(ErrorKind::forbidden(), "Auth check failed."))?;
@ -541,7 +568,7 @@ impl Service {
}
debug!("Gathering auth events");
let auth_events = services().rooms.state.get_auth_events(
let auth_events = self.services.state.get_auth_events(
room_id,
&incoming_pdu.kind,
&incoming_pdu.sender,
@ -562,7 +589,7 @@ impl Service {
&& match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &incoming_pdu.redacts {
!services().rooms.state_accessor.user_can_redact(
!self.services.state_accessor.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
@ -577,7 +604,7 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?;
if let Some(redact_id) = &content.redacts {
!services().rooms.state_accessor.user_can_redact(
!self.services.state_accessor.user_can_redact(
redact_id,
&incoming_pdu.sender,
&incoming_pdu.room_id,
@ -594,12 +621,12 @@ impl Service {
// We start looking at current room state now, so lets lock the room
trace!("Locking the room");
let state_lock = services().rooms.state.mutex.lock(room_id).await;
let state_lock = self.services.state.mutex.lock(room_id).await;
// Now we calculate the set of extremities this room has after the incoming
// event has been applied. We start with the previous extremities (aka leaves)
trace!("Calculating extremities");
let mut extremities = services().rooms.state.get_forward_extremities(room_id)?;
let mut extremities = self.services.state.get_forward_extremities(room_id)?;
trace!("Calculated {} extremities", extremities.len());
// Remove any forward extremities that are referenced by this incoming event's
@ -609,22 +636,13 @@ impl Service {
}
// Only keep those extremities were not referenced yet
extremities.retain(|id| {
!matches!(
services()
.rooms
.pdu_metadata
.is_event_referenced(room_id, id),
Ok(true)
)
});
extremities.retain(|id| !matches!(self.services.pdu_metadata.is_event_referenced(room_id, id), Ok(true)));
debug!("Retained {} extremities. Compressing state", extremities.len());
let state_ids_compressed = Arc::new(
state_at_incoming_event
.iter()
.map(|(shortstatekey, id)| {
services()
.rooms
self.services
.state_compressor
.compress_state_event(*shortstatekey, id)
})
@ -637,8 +655,8 @@ impl Service {
// We also add state after incoming event to the fork states
let mut state_after = state_at_incoming_event.clone();
if let Some(state_key) = &incoming_pdu.state_key {
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?;
@ -651,13 +669,12 @@ impl Service {
// Set the new room state to the resolved state
debug!("Forcing new room state");
let (sstatehash, new, removed) = services()
.rooms
let (sstatehash, new, removed) = self
.services
.state_compressor
.save_state(room_id, new_room_state)?;
services()
.rooms
self.services
.state
.force_state(room_id, sstatehash, new, removed, &state_lock)
.await?;
@ -667,8 +684,7 @@ impl Service {
// if not soft fail it
if soft_fail {
debug!("Soft failing event");
services()
.rooms
self.services
.timeline
.append_incoming_pdu(
&incoming_pdu,
@ -682,8 +698,7 @@ impl Service {
// Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu);
services()
.rooms
self.services
.pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id)?;
@ -696,8 +711,8 @@ impl Service {
// Now that the event has passed all auth it is added into the timeline.
// We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event.
let pdu_id = services()
.rooms
let pdu_id = self
.services
.timeline
.append_incoming_pdu(
&incoming_pdu,
@ -723,14 +738,14 @@ impl Service {
&self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids");
let current_sstatehash = services()
.rooms
let current_sstatehash = self
.services
.state
.get_room_shortstatehash(room_id)?
.expect("every room has state");
let current_state_ids = services()
.rooms
let current_state_ids = self
.services
.state_accessor
.state_full_ids(current_sstatehash)
.await?;
@ -740,8 +755,7 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(fork_states.len());
for state in &fork_states {
auth_chain_sets.push(
services()
.rooms
self.services
.auth_chain
.event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect())
.await?
@ -755,8 +769,7 @@ impl Service {
.map(|map| {
map.into_iter()
.filter_map(|(k, id)| {
services()
.rooms
self.services
.short
.get_statekey_from_short(k)
.map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
@ -766,11 +779,11 @@ impl Service {
})
.collect();
let lock = services().globals.stateres_mutex.lock();
let lock = self.services.globals.stateres_mutex.lock();
debug!("Resolving state");
let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.timeline.get_pdu(id);
let res = self.services.timeline.get_pdu(id);
if let Err(e) = &res {
error!("Failed to fetch event: {}", e);
}
@ -793,12 +806,11 @@ impl Service {
let new_room_state = state
.into_iter()
.map(|((event_type, state_key), event_id)| {
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
services()
.rooms
self.services
.state_compressor
.compress_state_event(shortstatekey, &event_id)
})
@ -814,15 +826,14 @@ impl Service {
&self, incoming_pdu: &Arc<PduEvent>,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
let prev_event = &*incoming_pdu.prev_events[0];
let prev_event_sstatehash = services()
.rooms
let prev_event_sstatehash = self
.services
.state_accessor
.pdu_shortstatehash(prev_event)?;
let state = if let Some(shortstatehash) = prev_event_sstatehash {
Some(
services()
.rooms
self.services
.state_accessor
.state_full_ids(shortstatehash)
.await,
@ -833,8 +844,8 @@ impl Service {
if let Some(Ok(mut state)) = state {
debug!("Using cached state");
let prev_pdu = services()
.rooms
let prev_pdu = self
.services
.timeline
.get_pdu(prev_event)
.ok()
@ -842,8 +853,8 @@ impl Service {
.ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?;
if let Some(state_key) = &prev_pdu.state_key {
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?;
@ -866,13 +877,13 @@ impl Service {
let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events {
let Ok(Some(prev_event)) = services().rooms.timeline.get_pdu(prev_eventid) else {
let Ok(Some(prev_event)) = self.services.timeline.get_pdu(prev_eventid) else {
okay = false;
break;
};
let Ok(Some(sstatehash)) = services()
.rooms
let Ok(Some(sstatehash)) = self
.services
.state_accessor
.pdu_shortstatehash(prev_eventid)
else {
@ -891,15 +902,15 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes {
let mut leaf_state: HashMap<_, _> = services()
.rooms
let mut leaf_state: HashMap<_, _> = self
.services
.state_accessor
.state_full_ids(sstatehash)
.await?;
if let Some(state_key) = &prev_event.state_key {
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?;
leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id));
@ -910,7 +921,7 @@ impl Service {
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state {
if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) {
if let Ok((ty, st_key)) = self.services.short.get_statekey_from_short(k) {
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone());
@ -921,8 +932,7 @@ impl Service {
}
auth_chain_sets.push(
services()
.rooms
self.services
.auth_chain
.event_ids_iter(room_id, starting_events)
.await?
@ -932,9 +942,9 @@ impl Service {
fork_states.push(state);
}
let lock = services().globals.stateres_mutex.lock();
let lock = self.services.globals.stateres_mutex.lock();
let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.timeline.get_pdu(id);
let res = self.services.timeline.get_pdu(id);
if let Err(e) = &res {
error!("Failed to fetch event: {}", e);
}
@ -947,8 +957,8 @@ impl Service {
new_state
.into_iter()
.map(|((event_type, state_key), event_id)| {
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
Ok((shortstatekey, event_id))
@ -974,7 +984,8 @@ impl Service {
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids");
match services()
match self
.services
.sending
.send_federation_request(
origin,
@ -1004,8 +1015,8 @@ impl Service {
.clone()
.ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?;
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?;
@ -1022,8 +1033,8 @@ impl Service {
}
// The original create event must still be in the state
let create_shortstatekey = services()
.rooms
let create_shortstatekey = self
.services
.short
.get_shortstatekey(&StateEventType::RoomCreate, "")?
.expect("Room exists");
@ -1056,7 +1067,8 @@ impl Service {
) -> AsyncRecursiveCanonicalJsonVec<'a> {
Box::pin(async move {
let back_off = |id| async {
match services()
match self
.services
.globals
.bad_event_ratelimiter
.write()
@ -1075,7 +1087,7 @@ impl Service {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
if let Ok(Some(local_pdu)) = self.services.timeline.get_pdu(id) {
trace!("Found {} in db", id);
events_with_auth_events.push((id, Some(local_pdu), vec![]));
continue;
@ -1089,7 +1101,8 @@ impl Service {
let mut events_all = HashSet::with_capacity(todo_auth_events.len());
let mut i: u64 = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services()
if let Some((time, tries)) = self
.services
.globals
.bad_event_ratelimiter
.read()
@ -1114,13 +1127,14 @@ impl Service {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
if let Ok(Some(_)) = self.services.timeline.get_pdu(&next_id) {
trace!("Found {} in db", next_id);
continue;
}
debug!("Fetching {} over federation.", next_id);
match services()
match self
.services
.sending
.send_federation_request(
origin,
@ -1195,7 +1209,8 @@ impl Service {
pdus.push((local_pdu, None));
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) = services()
if let Some((time, tries)) = self
.services
.globals
.bad_event_ratelimiter
.read()
@ -1244,8 +1259,8 @@ impl Service {
let mut eventid_info = HashMap::new();
let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set;
let first_pdu_in_room = services()
.rooms
let first_pdu_in_room = self
.services
.timeline
.first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
@ -1267,19 +1282,18 @@ impl Service {
{
Self::check_room_id(room_id, &pdu)?;
if amount > services().globals.max_fetch_prev_events() {
if amount > self.services.globals.max_fetch_prev_events() {
// Max limit reached
debug!(
"Max prev event limit reached! Limit: {}",
services().globals.max_fetch_prev_events()
self.services.globals.max_fetch_prev_events()
);
graph.insert(prev_event_id.clone(), HashSet::new());
continue;
}
if let Some(json) = json_opt.or_else(|| {
services()
.rooms
self.services
.outlier
.get_outlier_pdu_json(&prev_event_id)
.ok()
@ -1335,8 +1349,7 @@ impl Service {
#[tracing::instrument(skip_all)]
pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> {
let acl_event = if let Some(acl) =
services()
.rooms
self.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomServerAcl, "")?
{

View file

@ -1,29 +1,28 @@
use conduit::{Err, Error, Result};
use conduit::{pdu::gen_event_id_canonical_json, warn, Err, Error, Result};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
use serde_json::value::RawValue as RawJsonValue;
use tracing::warn;
use crate::{pdu::gen_event_id_canonical_json, services};
impl super::Service {
pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {pdu:?}: {e:?}");
Error::BadServerResponse("Invalid PDU in server response")
})?;
pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {pdu:?}: {e:?}");
Error::BadServerResponse("Invalid PDU in server response")
})?;
let room_id: OwnedRoomId = value
.get("room_id")
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
let room_id: OwnedRoomId = value
.get("room_id")
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?;
let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else {
return Err!("Server is not in room {room_id}");
};
let Ok(room_version_id) = services().rooms.state.get_room_version(&room_id) else {
return Err!("Server is not in room {room_id}");
};
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
return Err!(Request(InvalidParam("Could not convert event to canonical json.")));
};
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json
return Err!(Request(InvalidParam("Could not convert event to canonical json.")));
};
Ok((event_id, value, room_id))
Ok((event_id, value, room_id))
}
}

View file

@ -3,7 +3,7 @@ use std::{
time::{Duration, SystemTime},
};
use conduit::{debug, error, info, trace, warn};
use conduit::{debug, error, info, trace, warn, Error, Result};
use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{
api::federation::{
@ -21,8 +21,6 @@ use ruma::{
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard};
use crate::{services, Error, Result};
impl super::Service {
pub async fn fetch_required_signing_keys<'a, E>(
&'a self, events: E, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
@ -147,7 +145,8 @@ impl super::Service {
debug!("Loading signing keys for {}", origin);
let result: BTreeMap<_, _> = services()
let result: BTreeMap<_, _> = self
.services
.globals
.signing_keys_for(origin)?
.into_iter()
@ -171,9 +170,10 @@ impl super::Service {
&self, mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
for server in services().globals.trusted_servers() {
for server in self.services.globals.trusted_servers() {
debug!("Asking batch signing keys from trusted server {}", server);
match services()
match self
.services
.sending
.send_federation_request(
server,
@ -199,7 +199,8 @@ impl super::Service {
// TODO: Check signature from trusted server?
servers.remove(&k.server_name);
let result = services()
let result = self
.services
.globals
.db
.add_signing_key(&k.server_name, k.clone())?
@ -234,7 +235,7 @@ impl super::Service {
.into_keys()
.map(|server| async move {
(
services()
self.services
.sending
.send_federation_request(&server, get_server_keys::v2::Request::new())
.await,
@ -248,7 +249,8 @@ impl super::Service {
if let (Ok(get_keys_response), origin) = result {
debug!("Result is from {origin}");
if let Ok(key) = get_keys_response.server_key.deserialize() {
let result: BTreeMap<_, _> = services()
let result: BTreeMap<_, _> = self
.services
.globals
.db
.add_signing_key(&origin, key)?
@ -297,7 +299,7 @@ impl super::Service {
return Ok(());
}
if services().globals.query_trusted_key_servers_first() {
if self.services.globals.query_trusted_key_servers_first() {
info!(
"query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \
homeserver signing keys."
@ -349,7 +351,8 @@ impl super::Service {
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let mut result: BTreeMap<_, _> = services()
let mut result: BTreeMap<_, _> = self
.services
.globals
.signing_keys_for(origin)?
.into_iter()
@ -362,15 +365,16 @@ impl super::Service {
}
// i didnt split this out into their own functions because it's relatively small
if services().globals.query_trusted_key_servers_first() {
if self.services.globals.query_trusted_key_servers_first() {
info!(
"query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \
keys"
);
for server in services().globals.trusted_servers() {
for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = services()
if let Some(server_keys) = self
.services
.sending
.send_federation_request(
server,
@ -394,7 +398,10 @@ impl super::Service {
}) {
debug!("Got signing keys: {:?}", server_keys);
for k in server_keys {
services().globals.db.add_signing_key(origin, k.clone())?;
self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
result.extend(
k.verify_keys
.into_iter()
@ -414,14 +421,15 @@ impl super::Service {
}
debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = services()
if let Some(server_key) = self
.services
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
services()
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())?;
@ -447,14 +455,15 @@ impl super::Service {
info!("query_trusted_key_servers_first is set to false, querying {origin} first");
debug!("Asking {origin} for their signing keys over federation");
if let Some(server_key) = services()
if let Some(server_key) = self
.services
.sending
.send_federation_request(origin, get_server_keys::v2::Request::new())
.await
.ok()
.and_then(|resp| resp.server_key.deserialize().ok())
{
services()
self.services
.globals
.db
.add_signing_key(origin, server_key.clone())?;
@ -477,9 +486,10 @@ impl super::Service {
}
}
for server in services().globals.trusted_servers() {
for server in self.services.globals.trusted_servers() {
debug!("Asking notary server {server} for {origin}'s signing key");
if let Some(server_keys) = services()
if let Some(server_keys) = self
.services
.sending
.send_federation_request(
server,
@ -503,7 +513,10 @@ impl super::Service {
}) {
debug!("Got signing keys: {:?}", server_keys);
for k in server_keys {
services().globals.db.add_signing_key(origin, k.clone())?;
self.services
.globals
.db
.add_signing_key(origin, k.clone())?;
result.extend(
k.verify_keys
.into_iter()

View file

@ -6,10 +6,10 @@ use std::{
sync::{Arc, Mutex},
};
use data::Data;
use conduit::{PduCount, Result};
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{PduCount, Result};
use self::data::Data;
pub struct Service {
db: Data,

View file

@ -1,30 +1,39 @@
use std::sync::Arc;
use conduit::{error, utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{OwnedRoomId, RoomId};
use crate::services;
use crate::{rooms, Dep};
pub(super) struct Data {
disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
disabledroomids: db["disabledroomids"].clone(),
bannedroomids: db["bannedroomids"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
let prefix = match self.services.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};

View file

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use ruma::{OwnedRoomId, RoomId};
use self::data::Data;
pub struct Service {
db: Data,
}
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -33,13 +33,13 @@ pub struct Service {
pub read_receipt: Arc<read_receipt::Service>,
pub search: Arc<search::Service>,
pub short: Arc<short::Service>,
pub spaces: Arc<spaces::Service>,
pub state: Arc<state::Service>,
pub state_accessor: Arc<state_accessor::Service>,
pub state_cache: Arc<state_cache::Service>,
pub state_compressor: Arc<state_compressor::Service>,
pub timeline: Arc<timeline::Service>,
pub threads: Arc<threads::Service>,
pub timeline: Arc<timeline::Service>,
pub typing: Arc<typing::Service>,
pub spaces: Arc<spaces::Service>,
pub user: Arc<user::Service>,
}

View file

@ -1,26 +1,35 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result};
use database::{Database, Map};
use conduit::{utils, Error, PduCount, PduEvent, Result};
use database::Map;
use ruma::{EventId, RoomId, UserId};
use crate::{services, PduCount, PduEvent};
use crate::{rooms, Dep};
pub(super) struct Data {
tofrom_relation: Arc<Map>,
referencedevents: Arc<Map>,
softfailedeventids: Arc<Map>,
services: Services,
}
struct Services {
timeline: Dep<rooms::timeline::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
tofrom_relation: db["tofrom_relation"].clone(),
referencedevents: db["referencedevents"].clone(),
softfailedeventids: db["softfailedeventids"].clone(),
services: Services {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
@ -57,8 +66,8 @@ impl Data {
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = services()
.rooms
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;

View file

@ -2,8 +2,7 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use conduit::{PduCount, PduEvent, Result};
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},
@ -11,12 +10,20 @@ use ruma::{
};
use serde::Deserialize;
use crate::{services, PduCount, PduEvent};
use self::data::Data;
use crate::{rooms, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
short: Dep<rooms::short::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
#[derive(Clone, Debug, Deserialize)]
struct ExtractRelType {
rel_type: RelationType,
@ -30,7 +37,12 @@ struct ExtractRelatesToEventId {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
}))
}
@ -101,8 +113,7 @@ impl Service {
})
.take(limit)
.filter(|(_, pdu)| {
services()
.rooms
self.services
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false)
@ -147,8 +158,7 @@ impl Service {
})
.take(limit)
.filter(|(_, pdu)| {
services()
.rooms
self.services
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false)
@ -180,10 +190,10 @@ impl Service {
pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
let room_id = self.services.short.get_or_create_shortroomid(room_id)?;
#[allow(unknown_lints)]
#[allow(clippy::manual_unwrap_or_default)]
let target = match services().rooms.timeline.get_pdu_count(target)? {
let target = match self.services.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator

View file

@ -1,14 +1,14 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw,
CanonicalJsonObject, OwnedUserId, RoomId, UserId,
};
use crate::services;
use crate::{globals, Dep};
type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
@ -16,15 +16,24 @@ type AnySyncEphemeralRoomEventIter<'a> =
pub(super) struct Data {
roomuserid_privateread: Arc<Map>,
roomuserid_lastprivatereadupdate: Arc<Map>,
services: Services,
readreceiptid_readreceipt: Arc<Map>,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
roomuserid_privateread: db["roomuserid_privateread"].clone(),
roomuserid_lastprivatereadupdate: db["roomuserid_lastprivatereadupdate"].clone(),
readreceiptid_readreceipt: db["readreceiptid_readreceipt"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
@ -51,7 +60,7 @@ impl Data {
}
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
@ -108,7 +117,7 @@ impl Data {
.insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate
.insert(&key, &services().globals.next_count()?.to_be_bytes())
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())
}
pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {

View file

@ -6,16 +6,24 @@ use conduit::Result;
use data::Data;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId};
use crate::services;
use crate::{sending, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
sending: Dep<sending::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
sending: args.depend::<sending::Service>("sending"),
},
db: Data::new(&args),
}))
}
@ -26,7 +34,7 @@ impl Service {
/// Replaces the previous read receipt.
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event)?;
services().sending.flush_room(room_id)?;
self.services.sending.flush_room(room_id)?;
Ok(())
}

View file

@ -1,21 +1,30 @@
use std::sync::Arc;
use conduit::{utils, Result};
use database::{Database, Map};
use database::Map;
use ruma::RoomId;
use crate::services;
use crate::{rooms, Dep};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
pub(super) struct Data {
tokenids: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
tokenids: db["tokenids"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
@ -51,8 +60,8 @@ impl Data {
}
pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
let prefix = services()
.rooms
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")

View file

@ -13,7 +13,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -1,10 +1,10 @@
use std::sync::Arc;
use conduit::{utils, warn, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::services;
use crate::{globals, Dep};
pub(super) struct Data {
eventid_shorteventid: Arc<Map>,
@ -13,10 +13,16 @@ pub(super) struct Data {
shortstatekey_statekey: Arc<Map>,
roomid_shortroomid: Arc<Map>,
statehash_shortstatehash: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
eventid_shorteventid: db["eventid_shorteventid"].clone(),
shorteventid_eventid: db["shorteventid_eventid"].clone(),
@ -24,6 +30,9 @@ impl Data {
shortstatekey_statekey: db["shortstatekey_statekey"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(),
statehash_shortstatehash: db["statehash_shortstatehash"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
@ -31,7 +40,7 @@ impl Data {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = services().globals.next_count()?;
let shorteventid = self.services.globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
@ -59,7 +68,7 @@ impl Data {
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
),
None => {
let short = services().globals.next_count()?;
let short = self.services.globals.next_count()?;
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
self.shorteventid_eventid
@ -98,7 +107,7 @@ impl Data {
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = services().globals.next_count()?;
let shortstatekey = self.services.globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
@ -158,7 +167,7 @@ impl Data {
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
let shortstatehash = self.services.globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
@ -176,7 +185,7 @@ impl Data {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = services().globals.next_count()?;
let short = self.services.globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short

View file

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use ruma::{events::StateEventType, EventId, RoomId};
use self::data::Data;
pub struct Service {
db: Data,
}
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -28,7 +28,7 @@ use ruma::{
};
use tokio::sync::Mutex;
use crate::services;
use crate::{rooms, sending, Dep};
pub struct CachedSpaceHierarchySummary {
summary: SpaceHierarchyParentSummary,
@ -119,42 +119,18 @@ enum Identifier<'a> {
}
pub struct Service {
services: Services,
pub roomid_spacehierarchy_cache: Mutex<LruCache<OwnedRoomId, Option<CachedSpaceHierarchySummary>>>,
}
// Here because cannot implement `From` across ruma-federation-api and
// ruma-client-api types
impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
fn from(value: CachedSpaceHierarchySummary) -> Self {
let SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
..
} = value.summary;
Self {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
}
}
struct Services {
state_accessor: Dep<rooms::state_accessor::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state: Dep<rooms::state::Service>,
short: Dep<rooms::short::Service>,
event_handler: Dep<rooms::event_handler::Service>,
timeline: Dep<rooms::timeline::Service>,
sending: Dep<sending::Service>,
}
impl crate::Service for Service {
@ -163,6 +139,15 @@ impl crate::Service for Service {
let cache_size = f64::from(config.roomid_spacehierarchy_cache_capacity);
let cache_size = cache_size * config.cache_capacity_modifier;
Ok(Arc::new(Self {
services: Services {
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state: args.depend::<rooms::state::Service>("rooms::state"),
short: args.depend::<rooms::short::Service>("rooms::short"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
sending: args.depend::<sending::Service>("sending"),
},
roomid_spacehierarchy_cache: Mutex::new(LruCache::new(usize_from_f64(cache_size)?)),
}))
}
@ -226,7 +211,7 @@ impl Service {
.as_ref()
{
return Ok(if let Some(cached) = cached {
if is_accessable_child(
if self.is_accessible_child(
current_room,
&cached.summary.join_rule,
&identifier,
@ -242,8 +227,8 @@ impl Service {
}
Ok(
if let Some(children_pdus) = get_stripped_space_child_events(current_room).await? {
let summary = Self::get_room_summary(current_room, children_pdus, &identifier);
if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? {
let summary = self.get_room_summary(current_room, children_pdus, &identifier);
if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
@ -269,7 +254,8 @@ impl Service {
) -> Result<Option<SummaryAccessibility>> {
for server in via {
debug_info!("Asking {server} for /hierarchy");
let Ok(response) = services()
let Ok(response) = self
.services
.sending
.send_federation_request(
server,
@ -325,7 +311,10 @@ impl Service {
avatar_url,
join_rule,
room_type,
children_state: get_stripped_space_child_events(&room_id).await?.unwrap(),
children_state: self
.get_stripped_space_child_events(&room_id)
.await?
.unwrap(),
allowed_room_ids,
}
},
@ -333,7 +322,7 @@ impl Service {
);
}
}
if is_accessable_child(
if self.is_accessible_child(
current_room,
&response.room.join_rule,
&Identifier::UserId(user_id),
@ -370,12 +359,13 @@ impl Service {
}
fn get_room_summary(
current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>, identifier: &Identifier<'_>,
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
let room_id: &RoomId = current_room;
let join_rule = services()
.rooms
let join_rule = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
@ -386,12 +376,12 @@ impl Service {
.transpose()?
.unwrap_or(JoinRule::Invite);
let allowed_room_ids = services()
.rooms
let allowed_room_ids = self
.services
.state_accessor
.allowed_room_ids(join_rule.clone());
if !is_accessable_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) {
if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) {
debug!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
@ -400,18 +390,18 @@ impl Service {
let join_rule = join_rule.into();
Ok(SpaceHierarchyParentSummary {
canonical_alias: services()
.rooms
canonical_alias: self
.services
.state_accessor
.get_canonical_alias(room_id)
.unwrap_or(None),
name: services()
.rooms
name: self
.services
.state_accessor
.get_name(room_id)
.unwrap_or(None),
num_joined_members: services()
.rooms
num_joined_members: self
.services
.state_cache
.room_joined_count(room_id)
.unwrap_or_default()
@ -422,22 +412,22 @@ impl Service {
.try_into()
.expect("user count should not be that big"),
room_id: room_id.to_owned(),
topic: services()
.rooms
topic: self
.services
.state_accessor
.get_room_topic(room_id)
.unwrap_or(None),
world_readable: services().rooms.state_accessor.is_world_readable(room_id)?,
guest_can_join: services().rooms.state_accessor.guest_can_join(room_id)?,
avatar_url: services()
.rooms
world_readable: self.services.state_accessor.is_world_readable(room_id)?,
guest_can_join: self.services.state_accessor.guest_can_join(room_id)?,
avatar_url: self
.services
.state_accessor
.get_avatar(room_id)?
.into_option()
.unwrap_or_default()
.url,
join_rule,
room_type: services().rooms.state_accessor.get_room_type(room_id)?,
room_type: self.services.state_accessor.get_room_type(room_id)?,
children_state,
allowed_room_ids,
})
@ -487,7 +477,7 @@ impl Service {
.into_iter()
.rev()
.skip_while(|(room, _)| {
if let Ok(short) = services().rooms.short.get_shortroomid(room)
if let Ok(short) = self.services.short.get_shortroomid(room)
{
short.as_ref() != short_room_ids.get(parents.len())
} else {
@ -541,7 +531,7 @@ impl Service {
let mut short_room_ids = vec![];
for room in parents {
short_room_ids.push(services().rooms.short.get_or_create_shortroomid(&room)?);
short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?);
}
Some(
@ -559,128 +549,152 @@ impl Service {
rooms: results,
})
}
}
fn next_room_to_traverse(
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
while stack.last().map_or(false, Vec::is_empty) {
stack.pop();
parents.pop_back();
/// Simply returns the stripped m.space.child events of a room
async fn get_stripped_space_child_events(
&self, room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else {
return Ok(None);
};
let state = self
.services
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = self
.services
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
.map(|c| c.via)
.map_or(true, |v| v.is_empty())
{
continue;
}
if OwnedRoomId::try_from(state_key).is_ok() {
children_pdus.push(pdu.to_stripped_spacechild_state_event());
}
}
Ok(Some(children_pdus))
}
stack.last_mut().and_then(Vec::pop)
}
/// With the given identifier, checks if a room is accessable
fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
// Note: unwrap_or_default for bool means false
match identifier {
Identifier::ServerName(server_name) => {
let room_id: &RoomId = current_room;
/// Simply returns the stripped m.space.child events of a room
async fn get_stripped_space_child_events(
room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else {
return Ok(None);
};
let state = services()
.rooms
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?;
if event_type != StateEventType::SpaceChild {
continue;
}
let pdu = services()
.rooms
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
.map(|c| c.via)
.map_or(true, |v| v.is_empty())
{
continue;
}
if OwnedRoomId::try_from(state_key).is_ok() {
children_pdus.push(pdu.to_stripped_spacechild_state_event());
}
}
Ok(Some(children_pdus))
}
/// With the given identifier, checks if a room is accessable
fn is_accessable_child(
current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
// Note: unwrap_or_default for bool means false
match identifier {
Identifier::ServerName(server_name) => {
let room_id: &RoomId = current_room;
// Checks if ACLs allow for the server to participate
if services()
.rooms
.event_handler
.acl_check(server_name, room_id)
.is_err()
{
return false;
}
},
Identifier::UserId(user_id) => {
if services()
.rooms
.state_cache
.is_joined(user_id, current_room)
.unwrap_or_default()
|| services()
.rooms
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
{
return true;
}
},
} // Takes care of join rules
match join_rule {
SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if services()
.rooms
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
return true;
}
},
Identifier::ServerName(server) => {
if services()
.rooms
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
return true;
}
},
// Checks if ACLs allow for the server to participate
if self
.services
.event_handler
.acl_check(server_name, room_id)
.is_err()
{
return false;
}
}
false
},
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
// Invite only, Private, or Custom join rule
_ => false,
},
Identifier::UserId(user_id) => {
if self
.services
.state_cache
.is_joined(user_id, current_room)
.unwrap_or_default()
|| self
.services
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
{
return true;
}
},
} // Takes care of join rules
match join_rule {
SpaceRoomJoinRule::Restricted => {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if self
.services
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
return true;
}
},
Identifier::ServerName(server) => {
if self
.services
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
return true;
}
},
}
}
false
},
SpaceRoomJoinRule::Public | SpaceRoomJoinRule::Knock | SpaceRoomJoinRule::KnockRestricted => true,
// Invite only, Private, or Custom join rule
_ => false,
}
}
}
// Here because cannot implement `From` across ruma-federation-api and
// ruma-client-api types
impl From<CachedSpaceHierarchySummary> for SpaceHierarchyRoomsChunk {
fn from(value: CachedSpaceHierarchySummary) -> Self {
let SpaceHierarchyParentSummary {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
..
} = value.summary;
Self {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
children_state,
}
}
}
@ -736,3 +750,14 @@ fn get_parent_children_via(
})
.collect()
}
fn next_room_to_traverse(
stack: &mut Vec<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>, parents: &mut VecDeque<OwnedRoomId>,
) -> Option<(OwnedRoomId, Vec<OwnedServerName>)> {
while stack.last().map_or(false, Vec::is_empty) {
stack.pop();
parents.pop_back();
}
stack.last_mut().and_then(Vec::pop)
}

View file

@ -8,7 +8,7 @@ use std::{
use conduit::{
utils::{calculate_hash, MutexMap, MutexMapGuard},
warn, Error, Result,
warn, Error, PduEvent, Result,
};
use data::Data;
use ruma::{
@ -23,19 +23,39 @@ use ruma::{
};
use super::state_compressor::CompressedStateEvent;
use crate::{services, PduEvent};
use crate::{globals, rooms, Dep};
pub struct Service {
services: Services,
db: Data,
pub mutex: RoomMutexMap,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
spaces: Dep<rooms::spaces::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(args.db),
mutex: RoomMutexMap::new(),
}))
@ -62,14 +82,13 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| {
services()
.rooms
self.services
.state_compressor
.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
let Some(pdu) = services().rooms.timeline.get_pdu_json(&event_id)? else {
let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else {
continue;
};
@ -94,7 +113,7 @@ impl Service {
continue;
};
services().rooms.state_cache.update_membership(
self.services.state_cache.update_membership(
room_id,
&user_id,
membership_event,
@ -105,8 +124,7 @@ impl Service {
)?;
},
TimelineEventType::SpaceChild => {
services()
.rooms
self.services
.spaces
.roomid_spacehierarchy_cache
.lock()
@ -117,7 +135,7 @@ impl Service {
}
}
services().rooms.state_cache.update_joined_count(room_id)?;
self.services.state_cache.update_joined_count(room_id)?;
self.db
.set_room_state(room_id, shortstatehash, state_lock)?;
@ -133,10 +151,7 @@ impl Service {
pub fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<u64> {
let shorteventid = services()
.rooms
.short
.get_or_create_shorteventid(event_id)?;
let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
@ -147,20 +162,15 @@ impl Service {
.collect::<Vec<_>>(),
);
let (shortstatehash, already_existed) = services()
.rooms
let (shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;
if !already_existed {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
|p| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
@ -179,7 +189,7 @@ impl Service {
} else {
(state_ids_compressed, Arc::new(HashSet::new()))
};
services().rooms.state_compressor.save_state_from_diff(
self.services.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@ -199,8 +209,8 @@ impl Service {
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu), level = "debug")]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = services()
.rooms
let shorteventid = self
.services
.short
.get_or_create_shorteventid(&new_pdu.event_id)?;
@ -214,21 +224,16 @@ impl Service {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
#[inline]
|p| {
services()
.rooms
.state_compressor
.load_shortstatehash_info(p)
},
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let shortstatekey = services()
.rooms
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?;
let new = services()
.rooms
let new = self
.services
.state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?;
@ -246,7 +251,7 @@ impl Service {
}
// TODO: statehash with deterministic inputs
let shortstatehash = services().globals.next_count()?;
let shortstatehash = self.services.globals.next_count()?;
let mut statediffnew = HashSet::new();
statediffnew.insert(new);
@ -256,7 +261,7 @@ impl Service {
statediffremoved.insert(*replaces);
}
services().rooms.state_compressor.save_state_from_diff(
self.services.state_compressor.save_state_from_diff(
shortstatehash,
Arc::new(statediffnew),
Arc::new(statediffremoved),
@ -275,22 +280,20 @@ impl Service {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = services().rooms.state_accessor.room_state_get(
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomCanonicalAlias,
"",
@ -298,22 +301,20 @@ impl Service {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
services()
.rooms
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = services().rooms.state_accessor.room_state_get(
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomMember,
invite_event.sender.as_str(),
@ -339,8 +340,8 @@ impl Service {
/// Returns the room's version.
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = services()
.rooms
let create_event = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
@ -393,8 +394,7 @@ impl Service {
let mut sauthevents = auth_events
.into_iter()
.filter_map(|(event_type, state_key)| {
services()
.rooms
self.services
.short
.get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok()
@ -403,8 +403,8 @@ impl Service {
})
.collect::<HashMap<_, _>>();
let full_state = services()
.rooms
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
@ -414,16 +414,14 @@ impl Service {
Ok(full_state
.iter()
.filter_map(|compressed| {
services()
.rooms
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
})
.filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id)))
.filter_map(|(k, event_id)| {
services()
.rooms
self.services
.timeline
.get_pdu(&event_id)
.ok()

View file

@ -1,28 +1,43 @@
use std::{collections::HashMap, sync::Arc};
use conduit::{utils, Error, Result};
use database::{Database, Map};
use conduit::{utils, Error, PduEvent, Result};
use database::Map;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{services, PduEvent};
use crate::{rooms, Dep};
pub(super) struct Data {
eventid_shorteventid: Arc<Map>,
shorteventid_shortstatehash: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_compressor: Dep<rooms::state_compressor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
eventid_shorteventid: db["eventid_shorteventid"].clone(),
shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_compressor: args.depend::<rooms::state_compressor::Service>("rooms::state_compressor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
#[allow(unused_qualifications)] // async traits
pub(super) async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
@ -31,8 +46,8 @@ impl Data {
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = services()
.rooms
let parsed = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
result.insert(parsed.0, parsed.1);
@ -49,8 +64,8 @@ impl Data {
pub(super) async fn state_full(
&self, shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
@ -60,11 +75,11 @@ impl Data {
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let (_, eventid) = services()
.rooms
let (_, eventid) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
@ -92,15 +107,15 @@ impl Data {
pub(super) fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = services()
.rooms
let Some(shortstatekey) = self
.services
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
let full_state = services()
.rooms
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
@ -110,8 +125,7 @@ impl Data {
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
@ -125,7 +139,7 @@ impl Data {
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
.map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id))
}
/// Returns the state hash for this pdu.
@ -149,7 +163,7 @@ impl Data {
pub(super) async fn room_state_full(
&self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
@ -161,7 +175,7 @@ impl Data {
pub(super) fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
@ -173,7 +187,7 @@ impl Data {
pub(super) fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)

View file

@ -6,7 +6,7 @@ use std::{
sync::{Arc, Mutex as StdMutex, Mutex},
};
use conduit::{err, error, utils::math::usize_from_f64, warn, Error, Result};
use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result};
use data::Data;
use lru_cache::LruCache;
use ruma::{
@ -33,14 +33,20 @@ use ruma::{
};
use serde_json::value::to_raw_value;
use crate::{pdu::PduBuilder, rooms::state::RoomMutexGuard, services, PduEvent};
use crate::{rooms, rooms::state::RoomMutexGuard, Dep};
pub struct Service {
services: Services,
db: Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
}
struct Services {
state_cache: Dep<rooms::state_cache::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
@ -50,7 +56,11 @@ impl crate::Service for Service {
f64::from(config.user_visibility_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
server_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(server_visibility_cache_capacity)?)),
user_visibility_cache: StdMutex::new(LruCache::new(usize_from_f64(user_visibility_cache_capacity)?)),
}))
@ -164,8 +174,8 @@ impl Service {
})
.unwrap_or(HistoryVisibility::Shared);
let mut current_server_members = services()
.rooms
let mut current_server_members = self
.services
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
@ -212,7 +222,7 @@ impl Service {
return Ok(*visibility);
}
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?;
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
@ -258,7 +268,7 @@ impl Service {
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id))]
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?;
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let history_visibility = self
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
@ -342,8 +352,8 @@ impl Service {
redacts: None,
};
Ok(services()
.rooms
Ok(self
.services
.timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.is_ok())
@ -413,7 +423,7 @@ impl Service {
// Falling back on m.room.create to judge power level
if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? {
Ok(pdu.sender == sender
|| if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) {
|| if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
pdu.sender == sender
} else {
false
@ -430,7 +440,7 @@ impl Service {
.map(|event: RoomPowerLevels| {
event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender)
&& if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(redacts) {
&& if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
if federation {
pdu.sender.server_name() == sender.server_name()
} else {

View file

@ -4,7 +4,7 @@ use std::{
};
use conduit::{utils, Error, Result};
use database::{Database, Map};
use database::Map;
use itertools::Itertools;
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
@ -12,44 +12,55 @@ use ruma::{
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use crate::{appservice::RegistrationInfo, services, user_is_local};
use crate::{appservice::RegistrationInfo, globals, user_is_local, users, Dep};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
pub(super) struct Data {
userroomid_joined: Arc<Map>,
roomuserid_joined: Arc<Map>,
userroomid_invitestate: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
userroomid_leftstate: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomid_invitedcount: Arc<Map>,
roomserverids: Arc<Map>,
serverroomids: Arc<Map>,
pub(super) appservice_in_room_cache: AppServiceInRoomCache,
roomid_invitedcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomserverids: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
userroomid_joined: db["userroomid_joined"].clone(),
roomuserid_joined: db["roomuserid_joined"].clone(),
userroomid_invitestate: db["userroomid_invitestate"].clone(),
roomuserid_invitecount: db["roomuserid_invitecount"].clone(),
userroomid_leftstate: db["userroomid_leftstate"].clone(),
roomuserid_leftcount: db["roomuserid_leftcount"].clone(),
roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(),
roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(),
roomid_joinedcount: db["roomid_joinedcount"].clone(),
roomid_invitedcount: db["roomid_invitedcount"].clone(),
roomserverids: db["roomserverids"].clone(),
serverroomids: db["serverroomids"].clone(),
appservice_in_room_cache: RwLock::new(HashMap::new()),
roomid_invitedcount: db["roomid_invitedcount"].clone(),
roomid_inviteviaservers: db["roomid_inviteviaservers"].clone(),
roomid_joinedcount: db["roomid_joinedcount"].clone(),
roomserverids: db["roomserverids"].clone(),
roomuserid_invitecount: db["roomuserid_invitecount"].clone(),
roomuserid_joined: db["roomuserid_joined"].clone(),
roomuserid_leftcount: db["roomuserid_leftcount"].clone(),
roomuseroncejoinedids: db["roomuseroncejoinedids"].clone(),
serverroomids: db["serverroomids"].clone(),
userroomid_invitestate: db["userroomid_invitestate"].clone(),
userroomid_joined: db["userroomid_joined"].clone(),
userroomid_leftstate: db["userroomid_leftstate"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}
}
@ -100,7 +111,7 @@ impl Data {
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
@ -144,7 +155,7 @@ impl Data {
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
@ -228,7 +239,7 @@ impl Data {
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
services().globals.server_name(),
self.services.globals.server_name(),
)
.ok();
@ -356,7 +367,7 @@ impl Data {
) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new(
self.local_users_in_room(room_id)
.filter(|user| !services().users.is_deactivated(user).unwrap_or(true)),
.filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)),
)
}

View file

@ -21,16 +21,28 @@ use ruma::{
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use crate::{appservice::RegistrationInfo, services, user_is_local};
use crate::{account_data, appservice::RegistrationInfo, rooms, user_is_local, users, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
account_data: Dep<account_data::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
account_data: args.depend::<account_data::Service>("account_data"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(&args),
}))
}
@ -54,18 +66,18 @@ impl Service {
// update
#[allow(clippy::collapsible_if)]
if !user_is_local(user_id) {
if !services().users.exists(user_id)? {
services().users.create(user_id, None)?;
if !self.services.users.exists(user_id)? {
self.services.users.create(user_id, None)?;
}
/*
// Try to update our local copy of the user if ours does not match
if ((services().users.displayname(user_id)? != membership_event.displayname)
|| (services().users.avatar_url(user_id)? != membership_event.avatar_url)
|| (services().users.blurhash(user_id)? != membership_event.blurhash))
if ((self.services.users.displayname(user_id)? != membership_event.displayname)
|| (self.services.users.avatar_url(user_id)? != membership_event.avatar_url)
|| (self.services.users.blurhash(user_id)? != membership_event.blurhash))
&& (membership != MembershipState::Leave)
{
let response = services()
let response = self.services
.sending
.send_federation_request(
user_id.server_name(),
@ -76,9 +88,9 @@ impl Service {
)
.await;
services().users.set_displayname(user_id, response.displayname.clone()).await?;
services().users.set_avatar_url(user_id, response.avatar_url).await?;
services().users.set_blurhash(user_id, response.blurhash).await?;
self.services.users.set_displayname(user_id, response.displayname.clone()).await?;
self.services.users.set_avatar_url(user_id, response.avatar_url).await?;
self.services.users.set_blurhash(user_id, response.blurhash).await?;
};
*/
}
@ -91,8 +103,8 @@ impl Service {
self.db.mark_as_once_joined(user_id, room_id)?;
// Check if the room has a predecessor
if let Some(predecessor) = services()
.rooms
if let Some(predecessor) = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok())
@ -124,21 +136,23 @@ impl Service {
// .ok();
// Copy old tags to new room
if let Some(tag_event) = services()
if let Some(tag_event) = self
.services
.account_data
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)?
.map(|event| {
serde_json::from_str(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) {
services()
self.services
.account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?)
.ok();
};
// Copy direct chat flag
if let Some(direct_event) = services()
if let Some(direct_event) = self
.services
.account_data
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())?
.map(|event| {
@ -156,7 +170,7 @@ impl Service {
}
if room_ids_updated {
services().account_data.update(
self.services.account_data.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
@ -171,7 +185,8 @@ impl Service {
},
MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
let is_ignored = services()
let is_ignored = self
.services
.account_data
.get(
None, // Ignored users are in global account data
@ -393,8 +408,8 @@ impl Service {
/// See <https://spec.matrix.org/v1.10/appendices/#routing>
#[tracing::instrument(skip(self))]
pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = services()
.rooms
let most_powerful_user_server = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|pdu| {

View file

@ -13,7 +13,7 @@ use lru_cache::LruCache;
use ruma::{EventId, RoomId};
use self::data::StateDiff;
use crate::services;
use crate::{rooms, Dep};
type StateInfoLruCache = Mutex<
LruCache<
@ -48,16 +48,25 @@ pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Service {
db: Data,
services: Services,
pub stateinfo_cache: StateInfoLruCache,
}
struct Services {
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let config = &args.server.config;
let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier;
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
},
stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)),
}))
}
@ -124,8 +133,8 @@ impl Service {
pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&services()
.rooms
&self
.services
.short
.get_or_create_shorteventid(event_id)?
.to_be_bytes(),
@ -138,7 +147,7 @@ impl Service {
pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> {
Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"),
services().rooms.short.get_eventid_from_short(
self.services.short.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"),
)?,
))
@ -282,7 +291,7 @@ impl Service {
pub fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> HashSetCompressStateEvent {
let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?;
let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?;
let state_hash = utils::calculate_hash(
&new_state_ids_compressed
@ -291,8 +300,8 @@ impl Service {
.collect::<Vec<_>>(),
);
let (new_shortstatehash, already_existed) = services()
.rooms
let (new_shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;

View file

@ -1,29 +1,40 @@
use std::{mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, Result};
use database::{Database, Map};
use conduit::{checked, utils, Error, PduEvent, Result};
use database::Map;
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{services, PduEvent};
use crate::{rooms, Dep};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
pub(super) struct Data {
threadid_userids: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
timeline: Dep<rooms::timeline::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
threadid_userids: db["threadid_userids"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}
}
pub(super) fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
let prefix = services()
.rooms
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")
@ -40,8 +51,8 @@ impl Data {
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = services()
.rooms
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;

View file

@ -2,7 +2,7 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{Error, Result};
use conduit::{Error, PduEvent, Result};
use data::Data;
use ruma::{
api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads},
@ -11,16 +11,24 @@ use ruma::{
};
use serde_json::json;
use crate::{services, PduEvent};
use crate::{rooms, Dep};
pub struct Service {
services: Services,
db: Data,
}
struct Services {
timeline: Dep<rooms::timeline::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
}))
}
@ -35,22 +43,22 @@ impl Service {
}
pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
let root_id = &services()
.rooms
let root_id = self
.services
.timeline
.get_pdu_id(root_event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?;
let root_pdu = services()
.rooms
let root_pdu = self
.services
.timeline
.get_pdu_from_id(root_id)?
.get_pdu_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
let mut root_pdu_json = services()
.rooms
let mut root_pdu_json = self
.services
.timeline
.get_pdu_json_from_id(root_id)?
.get_pdu_json_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
@ -93,20 +101,19 @@ impl Service {
);
}
services()
.rooms
self.services
.timeline
.replace_pdu(root_id, &root_pdu_json, &root_pdu)?;
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)?;
}
let mut users = Vec::new();
if let Some(userids) = self.db.get_participants(root_id)? {
if let Some(userids) = self.db.get_participants(&root_id)? {
users.extend_from_slice(&userids);
} else {
users.push(root_pdu.sender);
}
users.push(pdu.sender.clone());
self.db.update_participants(root_id, &users)
self.db.update_participants(&root_id, &users)
}
}

View file

@ -4,19 +4,25 @@ use std::{
sync::{Arc, Mutex},
};
use conduit::{checked, error, utils, Error, Result};
use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result};
use database::{Database, Map};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{services, PduCount, PduEvent};
use crate::{rooms, Dep};
pub(super) struct Data {
eventid_outlierpdu: Arc<Map>,
eventid_pduid: Arc<Map>,
pduid_pdu: Arc<Map>,
eventid_outlierpdu: Arc<Map>,
userroomid_notificationcount: Arc<Map>,
userroomid_highlightcount: Arc<Map>,
userroomid_notificationcount: Arc<Map>,
pub(super) lasttimelinecount_cache: LastTimelineCountCache,
pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
@ -24,14 +30,19 @@ type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>;
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
eventid_pduid: db["eventid_pduid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(),
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
db: args.db.clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
@ -210,7 +221,7 @@ impl Data {
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
let (prefix, current) = self.count_to_id(room_id, until, 1, true)?;
let user_id = user_id.to_owned();
@ -232,7 +243,7 @@ impl Data {
}
pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
let (prefix, current) = self.count_to_id(room_id, from, 1, false)?;
let user_id = user_id.to_owned();
@ -277,6 +288,41 @@ impl Data {
.increment_batch(highlights_batch.iter().map(Vec::as_slice))?;
Ok(())
}
pub(super) fn count_to_id(
&self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
}
/// Returns the `count` of this pdu's id.
@ -294,38 +340,3 @@ pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
Ok(PduCount::Normal(last_u64))
}
}
pub(super) fn count_to_id(
room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}

View file

@ -7,11 +7,12 @@ use std::{
};
use conduit::{
debug, error, info, utils,
debug, error, info,
pdu::{EventHash, PduBuilder, PduCount, PduEvent},
utils,
utils::{MutexMap, MutexMapGuard},
validated, warn, Error, Result,
validated, warn, Error, Result, Server,
};
use data::Data;
use itertools::Itertools;
use ruma::{
api::{client::error::ErrorKind, federation},
@ -37,11 +38,10 @@ use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use self::data::Data;
use crate::{
appservice::NamespaceRegex,
pdu::{EventHash, PduBuilder},
rooms::{event_handler::parse_incoming_pdu, state_compressor::CompressedStateEvent},
server_is_ours, services, PduCount, PduEvent,
account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms,
rooms::state_compressor::CompressedStateEvent, sending, server_is_ours, Dep,
};
// Update Relationships
@ -67,17 +67,61 @@ struct ExtractBody {
}
pub struct Service {
services: Services,
db: Data,
pub mutex_insert: RoomMutexMap,
}
struct Services {
server: Arc<Server>,
account_data: Dep<account_data::Service>,
appservice: Dep<appservice::Service>,
admin: Dep<admin::Service>,
alias: Dep<rooms::alias::Service>,
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
pdu_metadata: Dep<rooms::pdu_metadata::Service>,
read_receipt: Dep<rooms::read_receipt::Service>,
sending: Dep<sending::Service>,
user: Dep<rooms::user::Service>,
pusher: Dep<pusher::Service>,
threads: Dep<rooms::threads::Service>,
search: Dep<rooms::search::Service>,
spaces: Dep<rooms::spaces::Service>,
event_handler: Dep<rooms::event_handler::Service>,
}
type RoomMutexMap = MutexMap<OwnedRoomId, ()>;
pub type RoomMutexGuard = MutexMapGuard<OwnedRoomId, ()>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: Services {
server: args.server.clone(),
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<appservice::Service>("appservice"),
admin: args.depend::<admin::Service>("admin"),
alias: args.depend::<rooms::alias::Service>("rooms::alias"),
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
pdu_metadata: args.depend::<rooms::pdu_metadata::Service>("rooms::pdu_metadata"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
sending: args.depend::<sending::Service>("sending"),
user: args.depend::<rooms::user::Service>("rooms::user"),
pusher: args.depend::<pusher::Service>("pusher"),
threads: args.depend::<rooms::threads::Service>("rooms::threads"),
search: args.depend::<rooms::search::Service>("rooms::search"),
spaces: args.depend::<rooms::spaces::Service>("rooms::spaces"),
event_handler: args.depend::<rooms::event_handler::Service>("rooms::event_handler"),
},
db: Data::new(&args),
mutex_insert: RoomMutexMap::new(),
}))
}
@ -217,10 +261,10 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Vec<u8>> {
// Coalesce database writes for the remainder of this scope.
let _cork = services().db.cork_and_flush();
let _cork = self.db.db.cork_and_flush();
let shortroomid = services()
.rooms
let shortroomid = self
.services
.short
.get_shortroomid(&pdu.room_id)?
.expect("room exists");
@ -233,14 +277,14 @@ impl Service {
.entry("unsigned".to_owned())
.or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default()))
{
if let Some(shortstatehash) = services()
.rooms
if let Some(shortstatehash) = self
.services
.state_accessor
.pdu_shortstatehash(&pdu.event_id)
.unwrap()
{
if let Some(prev_state) = services()
.rooms
if let Some(prev_state) = self
.services
.state_accessor
.state_get(shortstatehash, &pdu.kind.to_string().into(), state_key)
.unwrap()
@ -270,30 +314,26 @@ impl Service {
}
// We must keep track of all events that have been referenced.
services()
.rooms
self.services
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services()
.rooms
self.services
.state
.set_forward_extremities(&pdu.room_id, leaves, state_lock)?;
let insert_lock = self.mutex_insert.lock(&pdu.room_id).await;
let count1 = services().globals.next_count()?;
let count1 = self.services.globals.next_count()?;
// Mark as read first so the sending client doesn't get a notification even if
// appending fails
services()
.rooms
self.services
.read_receipt
.private_read_set(&pdu.room_id, &pdu.sender, count1)?;
services()
.rooms
self.services
.user
.reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = services().globals.next_count()?;
let count2 = self.services.globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes());
@ -303,8 +343,8 @@ impl Service {
drop(insert_lock);
// See if the event matches any known pushers
let power_levels: RoomPowerLevelsEventContent = services()
.rooms
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
@ -319,8 +359,8 @@ impl Service {
let mut notifies = Vec::new();
let mut highlights = Vec::new();
let mut push_target = services()
.rooms
let mut push_target = self
.services
.state_cache
.active_local_users_in_room(&pdu.room_id)
.collect_vec();
@ -341,7 +381,8 @@ impl Service {
continue;
}
let rules_for_user = services()
let rules_for_user = self
.services
.account_data
.get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())?
.map(|event| {
@ -357,7 +398,7 @@ impl Service {
let mut notify = false;
for action in
services()
self.services
.pusher
.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)?
{
@ -378,8 +419,10 @@ impl Service {
highlights.push(user.clone());
}
for push_key in services().pusher.get_pushkeys(user) {
services().sending.send_pdu_push(&pdu_id, user, push_key?)?;
for push_key in self.services.pusher.get_pushkeys(user) {
self.services
.sending
.send_pdu_push(&pdu_id, user, push_key?)?;
}
}
@ -390,11 +433,11 @@ impl Service {
TimelineEventType::RoomRedaction => {
use RoomVersionId::*;
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
let room_version_id = self.services.state.get_room_version(&pdu.room_id)?;
match room_version_id {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if services().rooms.state_accessor.user_can_redact(
if self.services.state_accessor.user_can_redact(
redact_id,
&pdu.sender,
&pdu.room_id,
@ -412,7 +455,7 @@ impl Service {
})?;
if let Some(redact_id) = &content.redacts {
if services().rooms.state_accessor.user_can_redact(
if self.services.state_accessor.user_can_redact(
redact_id,
&pdu.sender,
&pdu.room_id,
@ -433,8 +476,7 @@ impl Service {
},
TimelineEventType::SpaceChild => {
if let Some(_state_key) = &pdu.state_key {
services()
.rooms
self.services
.spaces
.roomid_spacehierarchy_cache
.lock()
@ -455,7 +497,7 @@ impl Service {
let invite_state = match content.membership {
MembershipState::Invite => {
let state = services().rooms.state.calculate_invite_state(pdu)?;
let state = self.services.state.calculate_invite_state(pdu)?;
Some(state)
},
_ => None,
@ -463,7 +505,7 @@ impl Service {
// Update our membership info, we do this here incase a user is invited
// and immediately leaves we need the DB to record the invite event for auth
services().rooms.state_cache.update_membership(
self.services.state_cache.update_membership(
&pdu.room_id,
&target_user_id,
content,
@ -479,13 +521,12 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body {
services()
.rooms
self.services
.search
.index_pdu(shortroomid, &pdu_id, &body)?;
if services().admin.is_admin_command(pdu, &body).await {
services()
if self.services.admin.is_admin_command(pdu, &body).await {
self.services
.admin
.command(body, Some((*pdu.event_id).into()))
.await;
@ -497,8 +538,7 @@ impl Service {
if let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) {
if let Some(related_pducount) = self.get_pdu_count(&content.relates_to.event_id)? {
services()
.rooms
self.services
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?;
}
@ -512,29 +552,25 @@ impl Service {
// We need to do it again here, because replies don't have
// event_id as a top level field
if let Some(related_pducount) = self.get_pdu_count(&in_reply_to.event_id)? {
services()
.rooms
self.services
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount)?;
}
},
Relation::Thread(thread) => {
services()
.rooms
.threads
.add_to_thread(&thread.event_id, pdu)?;
self.services.threads.add_to_thread(&thread.event_id, pdu)?;
},
_ => {}, // TODO: Aggregate other types
}
}
for appservice in services().appservice.read().await.values() {
if services()
.rooms
for appservice in self.services.appservice.read().await.values() {
if self
.services
.state_cache
.appservice_in_room(&pdu.room_id, appservice)?
{
services()
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue;
@ -550,7 +586,7 @@ impl Service {
{
let appservice_uid = appservice.registration.sender_localpart.as_str();
if state_key_uid == appservice_uid {
services()
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue;
@ -567,8 +603,7 @@ impl Service {
.map_or(false, |state_key| users.is_match(state_key))
};
let matching_aliases = |aliases: &NamespaceRegex| {
services()
.rooms
self.services
.alias
.local_aliases_for_room(&pdu.room_id)
.filter_map(Result::ok)
@ -579,7 +614,7 @@ impl Service {
|| appservice.rooms.is_match(pdu.room_id.as_str())
|| matching_users(&appservice.users)
{
services()
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
}
@ -603,8 +638,8 @@ impl Service {
redacts,
} = pdu_builder;
let prev_events: Vec<_> = services()
.rooms
let prev_events: Vec<_> = self
.services
.state
.get_forward_extremities(room_id)?
.into_iter()
@ -612,28 +647,23 @@ impl Service {
.collect();
// If there was no create event yet, assume we are creating a room
let room_version_id = services()
.rooms
.state
.get_room_version(room_id)
.or_else(|_| {
if event_type == TimelineEventType::RoomCreate {
let content = serde_json::from_str::<RoomCreateEventContent>(content.get())
.expect("Invalid content in RoomCreate pdu.");
Ok(content.room_version)
} else {
Err(Error::InconsistentRoomState(
"non-create event for room of unknown version",
room_id.to_owned(),
))
}
})?;
let room_version_id = self.services.state.get_room_version(room_id).or_else(|_| {
if event_type == TimelineEventType::RoomCreate {
let content = serde_json::from_str::<RoomCreateEventContent>(content.get())
.expect("Invalid content in RoomCreate pdu.");
Ok(content.room_version)
} else {
Err(Error::InconsistentRoomState(
"non-create event for room of unknown version",
room_id.to_owned(),
))
}
})?;
let room_version = RoomVersion::new(&room_version_id).expect("room version is supported");
let auth_events =
services()
.rooms
self.services
.state
.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?;
@ -649,8 +679,7 @@ impl Service {
if let Some(state_key) = &state_key {
if let Some(prev_pdu) =
services()
.rooms
self.services
.state_accessor
.room_state_get(room_id, &event_type.to_string().into(), state_key)?
{
@ -730,12 +759,12 @@ impl Service {
// Add origin because synapse likes that (and it's required in the spec)
pdu_json.insert(
"origin".to_owned(),
to_canonical_value(services().globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
to_canonical_value(self.services.globals.server_name()).expect("server name is a valid CanonicalJsonValue"),
);
match ruma::signatures::hash_and_sign_event(
services().globals.server_name().as_str(),
services().globals.keypair(),
self.services.globals.server_name().as_str(),
self.services.globals.keypair(),
&mut pdu_json,
&room_version_id,
) {
@ -763,8 +792,8 @@ impl Service {
);
// Generate short event id
let _shorteventid = services()
.rooms
let _shorteventid = self
.services
.short
.get_or_create_shorteventid(&pdu.event_id)?;
@ -783,7 +812,7 @@ impl Service {
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Arc<EventId>> {
let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?;
if let Some(admin_room) = services().admin.get_admin_room()? {
if let Some(admin_room) = self.services.admin.get_admin_room()? {
if admin_room == room_id {
match pdu.event_type() {
TimelineEventType::RoomEncryption => {
@ -798,7 +827,7 @@ impl Service {
.state_key()
.filter(|v| v.starts_with('@'))
.unwrap_or(sender.as_str());
let server_user = &services().globals.server_user.to_string();
let server_user = &self.services.globals.server_user.to_string();
let content = serde_json::from_str::<RoomMemberEventContent>(pdu.content.get())
.map_err(|_| Error::bad_database("Invalid content in pdu"))?;
@ -812,8 +841,8 @@ impl Service {
));
}
let count = services()
.rooms
let count = self
.services
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
@ -837,8 +866,8 @@ impl Service {
));
}
let count = services()
.rooms
let count = self
.services
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
@ -861,15 +890,14 @@ impl Service {
// If redaction event is not authorized, do not append it to the timeline
if pdu.kind == TimelineEventType::RoomRedaction {
use RoomVersionId::*;
match services().rooms.state.get_room_version(&pdu.room_id)? {
match self.services.state.get_room_version(&pdu.room_id)? {
V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | V10 => {
if let Some(redact_id) = &pdu.redacts {
if !services().rooms.state_accessor.user_can_redact(
redact_id,
&pdu.sender,
&pdu.room_id,
false,
)? {
if !self
.services
.state_accessor
.user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)?
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event."));
}
};
@ -879,12 +907,11 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?;
if let Some(redact_id) = &content.redacts {
if !services().rooms.state_accessor.user_can_redact(
redact_id,
&pdu.sender,
&pdu.room_id,
false,
)? {
if !self
.services
.state_accessor
.user_can_redact(redact_id, &pdu.sender, &pdu.room_id, false)?
{
return Err(Error::BadRequest(ErrorKind::forbidden(), "User cannot redact this event."));
}
}
@ -895,7 +922,7 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
let statehashid = services().rooms.state.append_to_state(&pdu)?;
let statehashid = self.services.state.append_to_state(&pdu)?;
let pdu_id = self
.append_pdu(
@ -910,13 +937,12 @@ impl Service {
// We set the room state after inserting the pdu, so that we never have a moment
// in time where events in the current room state do not exist
services()
.rooms
self.services
.state
.set_room_state(room_id, statehashid, state_lock)?;
let mut servers: HashSet<OwnedServerName> = services()
.rooms
let mut servers: HashSet<OwnedServerName> = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
@ -936,9 +962,9 @@ impl Service {
// Remove our server from the server list since it will be added to it by
// room_servers() and/or the if statement above
servers.remove(services().globals.server_name());
servers.remove(self.services.globals.server_name());
services()
self.services
.sending
.send_pdu_servers(servers.into_iter(), &pdu_id)?;
@ -960,18 +986,15 @@ impl Service {
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
services()
.rooms
self.services
.state
.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?;
if soft_fail {
services()
.rooms
self.services
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services()
.rooms
self.services
.state
.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?;
return Ok(None);
@ -1022,14 +1045,13 @@ impl Service {
if let Ok(content) = serde_json::from_str::<ExtractBody>(pdu.content.get()) {
if let Some(body) = content.body {
services()
.rooms
self.services
.search
.deindex_pdu(shortroomid, &pdu_id, &body)?;
}
}
let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?;
let room_version_id = self.services.state.get_room_version(&pdu.room_id)?;
pdu.redact(room_version_id, reason)?;
@ -1058,8 +1080,8 @@ impl Service {
return Ok(());
}
let power_levels: RoomPowerLevelsEventContent = services()
.rooms
let power_levels: RoomPowerLevelsEventContent = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| {
@ -1077,8 +1099,8 @@ impl Service {
}
});
let room_alias_servers = services()
.rooms
let room_alias_servers = self
.services
.alias
.local_aliases_for_room(room_id)
.filter_map(|alias| {
@ -1090,14 +1112,13 @@ impl Service {
let servers = room_mods
.chain(room_alias_servers)
.chain(services().globals.config.trusted_servers.clone())
.chain(self.services.server.config.trusted_servers.clone())
.filter(|server_name| {
if server_is_ours(server_name) {
return false;
}
services()
.rooms
self.services
.state_cache
.server_in_room(server_name, room_id)
.unwrap_or(false)
@ -1105,7 +1126,8 @@ impl Service {
for backfill_server in servers {
info!("Asking {backfill_server} for backfill");
let response = services()
let response = self
.services
.sending
.send_federation_request(
&backfill_server,
@ -1141,11 +1163,11 @@ impl Service {
&self, origin: &ServerName, pdu: Box<RawJsonValue>,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<()> {
let (event_id, value, room_id) = parse_incoming_pdu(&pdu)?;
let (event_id, value, room_id) = self.services.event_handler.parse_incoming_pdu(&pdu)?;
// Lock so we cannot backfill the same pdu twice at the same time
let mutex_lock = services()
.rooms
let mutex_lock = self
.services
.event_handler
.mutex_federation
.lock(&room_id)
@ -1158,14 +1180,12 @@ impl Service {
return Ok(());
}
services()
.rooms
self.services
.event_handler
.fetch_required_signing_keys([&value], pub_key_map)
.await?;
services()
.rooms
self.services
.event_handler
.handle_incoming_pdu(origin, &room_id, &event_id, value, false, pub_key_map)
.await?;
@ -1173,8 +1193,8 @@ impl Service {
let value = self.get_pdu_json(&event_id)?.expect("We just created it");
let pdu = self.get_pdu(&event_id)?.expect("We just created it");
let shortroomid = services()
.rooms
let shortroomid = self
.services
.short
.get_shortroomid(&room_id)?
.expect("room exists");
@ -1182,7 +1202,7 @@ impl Service {
let insert_lock = self.mutex_insert.lock(&room_id).await;
let max = u64::MAX;
let count = services().globals.next_count()?;
let count = self.services.globals.next_count()?;
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(validated!(max - count)?).to_be_bytes());
@ -1197,8 +1217,7 @@ impl Service {
.map_err(|_| Error::bad_database("Invalid content in pdu."))?;
if let Some(body) = content.body {
services()
.rooms
self.services
.search
.index_pdu(shortroomid, &pdu_id, &body)?;
}

View file

@ -1,6 +1,6 @@
use std::{collections::BTreeMap, sync::Arc};
use conduit::{debug_info, trace, utils, Result};
use conduit::{debug_info, trace, utils, Result, Server};
use ruma::{
api::federation::transactions::edu::{Edu, TypingContent},
events::SyncEphemeralRoomEvent,
@ -8,19 +8,31 @@ use ruma::{
};
use tokio::sync::{broadcast, RwLock};
use crate::{services, user_is_local};
use crate::{globals, sending, user_is_local, Dep};
pub struct Service {
pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>, // u64 is unix timestamp of timeout
pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>, /* timestamp of the last change to
* typing
* users */
server: Arc<Server>,
services: Services,
/// u64 is unix timestamp of timeout
pub typing: RwLock<BTreeMap<OwnedRoomId, BTreeMap<OwnedUserId, u64>>>,
/// timestamp of the last change to typing users
pub last_typing_update: RwLock<BTreeMap<OwnedRoomId, u64>>,
pub typing_update_sender: broadcast::Sender<OwnedRoomId>,
}
struct Services {
globals: Dep<globals::Service>,
sending: Dep<sending::Service>,
}
impl crate::Service for Service {
fn build(_args: crate::Args<'_>) -> Result<Arc<Self>> {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
server: args.server.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
sending: args.depend::<sending::Service>("sending"),
},
typing: RwLock::new(BTreeMap::new()),
last_typing_update: RwLock::new(BTreeMap::new()),
typing_update_sender: broadcast::channel(100).0,
@ -45,14 +57,14 @@ impl Service {
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if user_is_local(user_id) {
Self::federation_send(room_id, user_id, true)?;
self.federation_send(room_id, user_id, true)?;
}
Ok(())
@ -71,14 +83,14 @@ impl Service {
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if user_is_local(user_id) {
Self::federation_send(room_id, user_id, false)?;
self.federation_send(room_id, user_id, false)?;
}
Ok(())
@ -126,7 +138,7 @@ impl Service {
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), services().globals.next_count()?);
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
@ -134,7 +146,7 @@ impl Service {
// update federation
for user in removable {
if user_is_local(&user) {
Self::federation_send(room_id, &user, false)?;
self.federation_send(room_id, &user, false)?;
}
}
}
@ -171,15 +183,15 @@ impl Service {
})
}
fn federation_send(room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
debug_assert!(user_is_local(user_id), "tried to broadcast typing status of remote user",);
if !services().globals.config.allow_outgoing_typing {
if !self.server.config.allow_outgoing_typing {
return Ok(());
}
let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing));
services()
self.services
.sending
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?;

View file

@ -1,10 +1,10 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use database::Map;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::services;
use crate::{globals, rooms, Dep};
pub(super) struct Data {
userroomid_notificationcount: Arc<Map>,
@ -12,16 +12,27 @@ pub(super) struct Data {
roomuserid_lastnotificationread: Arc<Map>,
roomsynctoken_shortstatehash: Arc<Map>,
userroomid_joined: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
userroomid_notificationcount: db["userroomid_notificationcount"].clone(),
userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit
roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(),
userroomid_joined: db["userroomid_joined"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
@ -39,7 +50,7 @@ impl Data {
.insert(&userroom_id, &0_u64.to_be_bytes())?;
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(())
}
@ -87,8 +98,8 @@ impl Data {
pub(super) fn associate_token_shortstatehash(
&self, room_id: &RoomId, token: u64, shortstatehash: u64,
) -> Result<()> {
let shortroomid = services()
.rooms
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists");
@ -101,8 +112,8 @@ impl Data {
}
pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists");

View file

@ -3,9 +3,10 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use self::data::Data;
pub struct Service {
db: Data,
}
@ -13,7 +14,7 @@ pub struct Service {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data::new(&args),
}))
}

View file

@ -1,16 +1,17 @@
use std::{fmt::Debug, mem};
use bytes::BytesMut;
use conduit::{debug_error, trace, utils, warn, Error, Result};
use reqwest::Client;
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use tracing::{trace, warn};
use crate::{debug_error, services, utils, Error, Result};
/// Sends a request to an appservice
///
/// Only returns Ok(None) if there is no url specified in the appservice
/// registration file
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Result<Option<T::IncomingResponse>>
pub(crate) async fn send_request<T>(
client: &Client, registration: Registration, request: T,
) -> Result<Option<T::IncomingResponse>>
where
T: OutgoingRequest + Debug + Send,
{
@ -48,15 +49,10 @@ where
let reqwest_request = reqwest::Request::try_from(http_request)?;
let mut response = services()
.client
.appservice
.execute(reqwest_request)
.await
.map_err(|e| {
warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id);
e
})?;
let mut response = client.execute(reqwest_request).await.map_err(|e| {
warn!("Could not send request to appservice \"{}\" at {dest}: {e}", registration.id);
e
})?;
// reqwest::Response -> http::Response conversion
let status = response.status();

View file

@ -5,7 +5,7 @@ use database::{Database, Map};
use ruma::{ServerName, UserId};
use super::{Destination, SendingEvent};
use crate::services;
use crate::{globals, Dep};
type OutgoingSendingIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, Destination, SendingEvent)>> + 'a>;
type SendingEventIter<'a> = Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEvent)>> + 'a>;
@ -15,15 +15,24 @@ pub struct Data {
servernameevent_data: Arc<Map>,
servername_educount: Arc<Map>,
pub(super) db: Arc<Database>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(db: Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
servercurrentevent_data: db["servercurrentevent_data"].clone(),
servernameevent_data: db["servernameevent_data"].clone(),
servername_educount: db["servername_educount"].clone(),
db,
db: args.db.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
@ -78,7 +87,7 @@ impl Data {
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
} else {
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value

View file

@ -6,26 +6,39 @@ mod sender;
use std::{fmt::Debug, sync::Arc};
use async_trait::async_trait;
use conduit::{err, Result, Server};
use conduit::{err, warn, Result, Server};
use ruma::{
api::{appservice::Registration, OutgoingRequest},
OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
pub use sender::convert_to_outgoing_federation_event;
use tokio::sync::Mutex;
use tracing::warn;
use crate::{server_is_ours, services};
use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_is_ours, users, Dep};
pub struct Service {
pub db: data::Data,
server: Arc<Server>,
/// The state for a given state hash.
services: Services,
pub db: data::Data,
sender: loole::Sender<Msg>,
receiver: Mutex<loole::Receiver<Msg>>,
}
struct Services {
client: Dep<client::Service>,
globals: Dep<globals::Service>,
resolver: Dep<resolver::Service>,
state: Dep<rooms::state::Service>,
state_cache: Dep<rooms::state_cache::Service>,
user: Dep<rooms::user::Service>,
users: Dep<users::Service>,
presence: Dep<presence::Service>,
read_receipt: Dep<rooms::read_receipt::Service>,
timeline: Dep<rooms::timeline::Service>,
account_data: Dep<account_data::Service>,
appservice: Dep<crate::appservice::Service>,
pusher: Dep<pusher::Service>,
}
#[derive(Clone, Debug, PartialEq, Eq)]
struct Msg {
dest: Destination,
@ -53,8 +66,23 @@ impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let (sender, receiver) = loole::unbounded();
Ok(Arc::new(Self {
db: data::Data::new(args.db.clone()),
server: args.server.clone(),
services: Services {
client: args.depend::<client::Service>("client"),
globals: args.depend::<globals::Service>("globals"),
resolver: args.depend::<resolver::Service>("resolver"),
state: args.depend::<rooms::state::Service>("rooms::state"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
user: args.depend::<rooms::user::Service>("rooms::user"),
users: args.depend::<users::Service>("users"),
presence: args.depend::<presence::Service>("presence"),
read_receipt: args.depend::<rooms::read_receipt::Service>("rooms::read_receipt"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
account_data: args.depend::<account_data::Service>("account_data"),
appservice: args.depend::<crate::appservice::Service>("appservice"),
pusher: args.depend::<pusher::Service>("pusher"),
},
db: data::Data::new(&args),
sender,
receiver: Mutex::new(receiver),
}))
@ -103,8 +131,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
pub fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
let servers = services()
.rooms
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
@ -152,8 +180,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id, serialized), level = "debug")]
pub fn send_edu_room(&self, room_id: &RoomId, serialized: Vec<u8>) -> Result<()> {
let servers = services()
.rooms
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
@ -189,8 +217,8 @@ impl Service {
#[tracing::instrument(skip(self, room_id), level = "debug")]
pub fn flush_room(&self, room_id: &RoomId) -> Result<()> {
let servers = services()
.rooms
let servers = self
.services
.state_cache
.room_servers(room_id)
.filter_map(Result::ok)
@ -213,13 +241,13 @@ impl Service {
Ok(())
}
#[tracing::instrument(skip(self, request), name = "request")]
#[tracing::instrument(skip_all, name = "request")]
pub async fn send_federation_request<T>(&self, dest: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
let client = &services().client.federation;
send::send(client, dest, request).await
let client = &self.services.client.federation;
self.send(client, dest, request).await
}
/// Sends a request to an appservice
@ -232,7 +260,8 @@ impl Service {
where
T: OutgoingRequest + Debug + Send,
{
appservice::send_request(registration, request).await
let client = &self.services.client.appservice;
appservice::send_request(client, registration, request).await
}
/// Cleanup event data

View file

@ -1,6 +1,8 @@
use std::{fmt::Debug, mem};
use conduit::Err;
use conduit::{
debug, debug_error, debug_warn, error::inspect_debug_log, trace, utils::string::EMPTY, Err, Error, Result,
};
use http::{header::AUTHORIZATION, HeaderValue};
use ipaddress::IPAddress;
use reqwest::{Client, Method, Request, Response, Url};
@ -13,75 +15,91 @@ use ruma::{
server_util::authorization::XMatrix,
ServerName,
};
use tracing::{debug, trace};
use crate::{
debug_error, debug_warn, resolver,
globals, resolver,
resolver::{actual::ActualDest, cache::CachedDest},
services, Error, Result,
};
#[tracing::instrument(skip_all, name = "send")]
pub async fn send<T>(client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
if !services().globals.allow_federation() {
return Err!(Config("allow_federation", "Federation is disabled."));
impl super::Service {
#[tracing::instrument(skip(self, client, req), name = "send")]
pub async fn send<T>(&self, client: &Client, dest: &ServerName, req: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
if !self.server.config.allow_federation {
return Err!(Config("allow_federation", "Federation is disabled."));
}
let actual = self.services.resolver.get_actual_dest(dest).await?;
let request = self.prepare::<T>(dest, &actual, req).await?;
self.execute::<T>(dest, &actual, request, client).await
}
let actual = services().resolver.get_actual_dest(dest).await?;
let request = prepare::<T>(dest, &actual, req).await?;
execute::<T>(client, dest, &actual, request).await
}
async fn execute<T>(
&self, dest: &ServerName, actual: &ActualDest, request: Request, client: &Client,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
let url = request.url().clone();
let method = request.method().clone();
async fn execute<T>(
client: &Client, dest: &ServerName, actual: &ActualDest, request: Request,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
let method = request.method().clone();
let url = request.url().clone();
debug!(
method = ?method,
url = ?url,
"Sending request",
);
match client.execute(request).await {
Ok(response) => handle_response::<T>(dest, actual, &method, &url, response).await,
Err(e) => handle_error::<T>(dest, actual, &method, &url, e),
debug!(?method, ?url, "Sending request");
match client.execute(request).await {
Ok(response) => handle_response::<T>(&self.services.resolver, dest, actual, &method, &url, response).await,
Err(error) => handle_error::<T>(dest, actual, &method, &url, error),
}
}
}
async fn prepare<T>(dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
where
T: OutgoingRequest + Debug + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5];
async fn prepare<T>(&self, dest: &ServerName, actual: &ActualDest, req: T) -> Result<Request>
where
T: OutgoingRequest + Debug + Send,
{
const VERSIONS: [MatrixVersion; 1] = [MatrixVersion::V1_5];
const SATIR: SendAccessToken<'_> = SendAccessToken::IfRequired(EMPTY);
trace!("Preparing request");
trace!("Preparing request");
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SATIR, &VERSIONS)
.map_err(|_| Error::BadServerResponse("Invalid destination"))?;
let mut http_request = req
.try_into_http_request::<Vec<u8>>(&actual.string, SendAccessToken::IfRequired(""), &VERSIONS)
.map_err(|_e| Error::BadServerResponse("Invalid destination"))?;
sign_request::<T>(&self.services.globals, dest, &mut http_request);
sign_request::<T>(dest, &mut http_request);
let request = Request::try_from(http_request)?;
self.validate_url(request.url())?;
let request = Request::try_from(http_request)?;
validate_url(request.url())?;
Ok(request)
}
Ok(request)
fn validate_url(&self, url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
self.services.resolver.validate_ip(&ip)?;
}
}
Ok(())
}
}
async fn handle_response<T>(
dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url, mut response: Response,
resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, url: &Url,
mut response: Response,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug + Send,
{
trace!("Received response from {} for {} with {}", actual.string, url, response.url());
let status = response.status();
trace!(
?status, ?method,
request_url = ?url,
response_url = ?response.url(),
"Received response from {}",
actual.string,
);
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
@ -92,11 +110,13 @@ where
.expect("http::response::Builder is usable"),
);
trace!("Waiting for response body");
let body = response.bytes().await.unwrap_or_else(|e| {
debug_error!("server error {}", e);
Vec::new().into()
}); // TODO: handle timeout
// TODO: handle timeout
trace!("Waiting for response body...");
let body = response
.bytes()
.await
.inspect_err(inspect_debug_log)
.unwrap_or_else(|_| Vec::new().into());
let http_response = http_response_builder
.body(body)
@ -109,7 +129,7 @@ where
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && !actual.cached {
services().resolver.set_cached_destination(
resolver.set_cached_destination(
dest.to_owned(),
CachedDest {
dest: actual.dest.clone(),
@ -120,7 +140,7 @@ where
}
match response {
Err(_e) => Err(Error::BadServerResponse("Server returned bad 200 response.")),
Err(_) => Err(Error::BadServerResponse("Server returned bad 200 response.")),
Ok(response) => Ok(response),
}
}
@ -150,7 +170,7 @@ where
Err(e.into())
}
fn sign_request<T>(dest: &ServerName, http_request: &mut http::Request<Vec<u8>>)
fn sign_request<T>(globals: &globals::Service, dest: &ServerName, http_request: &mut http::Request<Vec<u8>>)
where
T: OutgoingRequest + Debug + Send,
{
@ -172,16 +192,12 @@ where
.to_string()
.into(),
);
req_map.insert("origin".to_owned(), services().globals.server_name().as_str().into());
req_map.insert("origin".to_owned(), globals.server_name().as_str().into());
req_map.insert("destination".to_owned(), dest.as_str().into());
let mut req_json = serde_json::from_value(req_map.into()).expect("valid JSON is valid BTreeMap");
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut req_json,
)
.expect("our request json is what ruma expects");
ruma::signatures::sign_json(globals.server_name().as_str(), globals.keypair(), &mut req_json)
.expect("our request json is what ruma expects");
let req_json: serde_json::Map<String, serde_json::Value> =
serde_json::from_slice(&serde_json::to_vec(&req_json).unwrap()).unwrap();
@ -207,24 +223,8 @@ where
http_request.headers_mut().insert(
AUTHORIZATION,
HeaderValue::from(&XMatrix::new(
services().globals.config.server_name.clone(),
dest.to_owned(),
key,
sig,
)),
HeaderValue::from(&XMatrix::new(globals.config.server_name.clone(), dest.to_owned(), key, sig)),
);
}
}
}
fn validate_url(url: &Url) -> Result<()> {
if let Some(url_host) = url.host_str() {
if let Ok(ip) = IPAddress::parse(url_host) {
trace!("Checking request URL IP {ip:?}");
resolver::actual::validate_ip(&ip)?;
}
}
Ok(())
}

View file

@ -6,7 +6,11 @@ use std::{
};
use base64::{engine::general_purpose, Engine as _};
use conduit::{debug, debug_warn, error, trace, utils::math::continue_exponential_backoff_secs, warn};
use conduit::{
debug, debug_warn, error, trace,
utils::{calculate_hash, math::continue_exponential_backoff_secs},
warn, Error, Result,
};
use federation::transactions::send_transaction_message;
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use ruma::{
@ -24,8 +28,8 @@ use ruma::{
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::time::sleep_until;
use super::{appservice, send, Destination, Msg, SendingEvent, Service};
use crate::{presence::Presence, services, user_is_local, utils::calculate_hash, Error, Result};
use super::{appservice, Destination, Msg, SendingEvent, Service};
use crate::user_is_local;
#[derive(Debug)]
enum TransactionStatus {
@ -69,8 +73,8 @@ impl Service {
Ok(())
}
fn handle_response(
&self, response: SendingResult, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus,
fn handle_response<'a>(
&'a self, response: SendingResult, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) {
match response {
Ok(dest) => self.handle_response_ok(&dest, futures, statuses),
@ -91,8 +95,8 @@ impl Service {
});
}
fn handle_response_ok(
&self, dest: &Destination, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus,
fn handle_response_ok<'a>(
&'a self, dest: &Destination, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus,
) {
let _cork = self.db.db.cork();
self.db
@ -113,24 +117,24 @@ impl Service {
.mark_as_active(&new_events)
.expect("marked as active");
let new_events_vec = new_events.into_iter().map(|(event, _)| event).collect();
futures.push(Box::pin(send_events(dest.clone(), new_events_vec)));
futures.push(Box::pin(self.send_events(dest.clone(), new_events_vec)));
} else {
statuses.remove(dest);
}
}
fn handle_request(&self, msg: Msg, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) {
fn handle_request<'a>(&'a self, msg: Msg, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let iv = vec![(msg.event, msg.queue_id)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses) {
if !events.is_empty() {
futures.push(Box::pin(send_events(msg.dest, events)));
futures.push(Box::pin(self.send_events(msg.dest, events)));
} else {
statuses.remove(&msg.dest);
}
}
}
async fn finish_responses(&self, futures: &mut SendingFutures<'_>, statuses: &mut CurTransactionStatus) {
async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let now = Instant::now();
let timeout = Duration::from_millis(CLEANUP_TIMEOUT_MS);
let deadline = now.checked_add(timeout).unwrap_or(now);
@ -148,7 +152,7 @@ impl Service {
debug_warn!("Leaving with {} unfinished requests...", futures.len());
}
fn initial_requests(&self, futures: &SendingFutures<'_>, statuses: &mut CurTransactionStatus) {
fn initial_requests<'a>(&'a self, futures: &SendingFutures<'a>, statuses: &mut CurTransactionStatus) {
let keep = usize::try_from(self.server.config.startup_netburst_keep).unwrap_or(usize::MAX);
let mut txns = HashMap::<Destination, Vec<SendingEvent>>::new();
for (key, dest, event) in self.db.active_requests().filter_map(Result::ok) {
@ -166,12 +170,12 @@ impl Service {
for (dest, events) in txns {
if self.server.config.startup_netburst && !events.is_empty() {
statuses.insert(dest.clone(), TransactionStatus::Running);
futures.push(Box::pin(send_events(dest.clone(), events)));
futures.push(Box::pin(self.send_events(dest.clone(), events)));
}
}
}
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, level = "debug")]
fn select_events(
&self,
dest: &Destination,
@ -218,7 +222,7 @@ impl Service {
Ok(Some(events))
}
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, level = "debug")]
fn select_events_current(&self, dest: Destination, statuses: &mut CurTransactionStatus) -> Result<(bool, bool)> {
let (mut allow, mut retry) = (true, false);
statuses
@ -244,7 +248,7 @@ impl Service {
Ok((allow, retry))
}
#[tracing::instrument(skip_all)]
#[tracing::instrument(skip_all, level = "debug")]
fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
// u64: count of last edu
let since = self.db.get_latest_educount(server_name)?;
@ -252,11 +256,11 @@ impl Service {
let mut max_edu_count = since;
let mut device_list_changes = HashSet::new();
for room_id in services().rooms.state_cache.server_rooms(server_name) {
for room_id in self.services.state_cache.server_rooms(server_name) {
let room_id = room_id?;
// Look for device list updates in this room
device_list_changes.extend(
services()
self.services
.users
.keys_changed(room_id.as_ref(), since, None)
.filter_map(Result::ok)
@ -264,7 +268,7 @@ impl Service {
);
if self.server.config.allow_outgoing_read_receipts
&& !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)?
&& !self.select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)?
{
break;
}
@ -287,381 +291,390 @@ impl Service {
}
if self.server.config.allow_outgoing_presence {
select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?;
self.select_edus_presence(server_name, since, &mut max_edu_count, &mut events)?;
}
Ok((events, max_edu_count))
}
}
/// Look for presence
fn select_edus_presence(
server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
// Look for presence updates for this server
let mut presence_updates = Vec::new();
for (user_id, count, presence_bytes) in services().presence.presence_since(since) {
*max_edu_count = cmp::max(count, *max_edu_count);
/// Look for presence
fn select_edus_presence(
&self, server_name: &ServerName, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
// Look for presence updates for this server
let mut presence_updates = Vec::new();
for (user_id, count, presence_bytes) in self.services.presence.presence_since(since) {
*max_edu_count = cmp::max(count, *max_edu_count);
if !user_is_local(&user_id) {
continue;
}
if !user_is_local(&user_id) {
continue;
}
if !services()
.rooms
.state_cache
.server_sees_user(server_name, &user_id)?
{
continue;
}
if !self
.services
.state_cache
.server_sees_user(server_name, &user_id)?
{
continue;
}
let presence_event = Presence::from_json_bytes_to_event(&presence_bytes, &user_id)?;
presence_updates.push(PresenceUpdate {
user_id,
presence: presence_event.content.presence,
currently_active: presence_event.content.currently_active.unwrap_or(false),
last_active_ago: presence_event
.content
.last_active_ago
.unwrap_or_else(|| uint!(0)),
status_msg: presence_event.content.status_msg,
});
if presence_updates.len() >= SELECT_EDU_LIMIT {
break;
}
}
if !presence_updates.is_empty() {
let presence_content = Edu::Presence(PresenceContent::new(presence_updates));
events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"));
}
Ok(true)
}
/// Look for read receipts in this room
fn select_edus_receipts(
room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in services()
.rooms
.read_receipt
.readreceipts_since(room_id, since)
{
let (user_id, count, read_receipt) = r?;
*max_edu_count = cmp::max(count, *max_edu_count);
if !user_is_local(&user_id) {
continue;
}
let event = serde_json::from_str(read_receipt.json().get())
.map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
let mut read = BTreeMap::new();
let (event_id, mut receipt) = r
.content
.0
.into_iter()
.next()
.expect("we only use one event per read receipt");
let receipt = receipt
.remove(&ReceiptType::Read)
.expect("our read receipts always set this")
.remove(&user_id)
.expect("our read receipts always have the user here");
read.insert(
let presence_event = self
.services
.presence
.from_json_bytes_to_event(&presence_bytes, &user_id)?;
presence_updates.push(PresenceUpdate {
user_id,
ReceiptData {
data: receipt.clone(),
event_ids: vec![event_id.clone()],
},
);
presence: presence_event.content.presence,
currently_active: presence_event.content.currently_active.unwrap_or(false),
last_active_ago: presence_event
.content
.last_active_ago
.unwrap_or_else(|| uint!(0)),
status_msg: presence_event.content.status_msg,
});
let receipt_map = ReceiptMap {
read,
};
let mut receipts = BTreeMap::new();
receipts.insert(room_id.to_owned(), receipt_map);
Edu::Receipt(ReceiptContent {
receipts,
})
} else {
Error::bad_database("Invalid event type in read_receipts");
continue;
};
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
if events.len() >= SELECT_EDU_LIMIT {
return Ok(false);
}
}
Ok(true)
}
async fn send_events(dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
//debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
Destination::Normal(ref server) => send_events_dest_normal(&dest, server, events).await,
Destination::Appservice(ref id) => send_events_dest_appservice(&dest, id, events).await,
Destination::Push(ref userid, ref pushkey) => send_events_dest_push(&dest, userid, pushkey, events).await,
}
}
#[tracing::instrument(skip(dest, events))]
async fn send_events_dest_appservice(dest: &Destination, id: &str, events: Vec<SendingEvent>) -> SendingResult {
let mut pdu_jsons = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdu_jsons.push(
services()
.rooms
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
)
})?
.to_room_event(),
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Appservices don't need EDUs (?) and flush only;
// no new content
},
}
}
//debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction");
match appservice::send_request(
services()
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Could not load registration from db."),
)
})?,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
},
)
.await
{
Ok(_) => Ok(dest.clone()),
Err(e) => Err((dest.clone(), e)),
}
}
#[tracing::instrument(skip(dest, events))]
async fn send_events_dest_push(
dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdus.push(
services()
.rooms
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Push] Event in servernameevent_data not found in db."),
)
})?,
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
// no new content
},
}
}
for pdu in pdus {
// Redacted events are not notification targets (we don't send push for them)
if let Some(unsigned) = &pdu.unsigned {
if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) {
if unsigned.get("redacted_because").is_some() {
continue;
}
if presence_updates.len() >= SELECT_EDU_LIMIT {
break;
}
}
let Some(pusher) = services()
.pusher
.get_pusher(userid, pushkey)
.map_err(|e| (dest.clone(), e))?
else {
continue;
};
if !presence_updates.is_empty() {
let presence_content = Edu::Presence(PresenceContent::new(presence_updates));
events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"));
}
let rules_for_user = services()
.account_data
.get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap_or_default()
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok())
.map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global);
let unread: UInt = services()
.rooms
.user
.notification_count(userid, &pdu.room_id)
.map_err(|e| (dest.clone(), e))?
.try_into()
.expect("notification count can't go that high");
let _response = services()
.pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
.await
.map(|_response| dest.clone())
.map_err(|e| (dest.clone(), e));
Ok(true)
}
Ok(dest.clone())
}
/// Look for read receipts in this room
fn select_edus_receipts(
&self, room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in self
.services
.read_receipt
.readreceipts_since(room_id, since)
{
let (user_id, count, read_receipt) = r?;
*max_edu_count = cmp::max(count, *max_edu_count);
#[tracing::instrument(skip(dest, events), name = "")]
async fn send_events_dest_normal(
dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
.count(),
);
if !user_is_local(&user_id) {
continue;
}
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => pdu_jsons.push(convert_to_outgoing_federation_event(
// TODO: check room version and remove event_id if needed
services()
.rooms
.timeline
.get_pdu_json_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
error!(?dest, ?server, ?pdu_id, "event not found");
(
dest.clone(),
Error::bad_database("[Normal] Event in servernameevent_data not found in db."),
)
})?,
)),
SendingEvent::Edu(edu) => {
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
}
},
SendingEvent::Flush => {
// flush only; no new content
let event = serde_json::from_str(read_receipt.json().get())
.map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?;
let federation_event = if let AnySyncEphemeralRoomEvent::Receipt(r) = event {
let mut read = BTreeMap::new();
let (event_id, mut receipt) = r
.content
.0
.into_iter()
.next()
.expect("we only use one event per read receipt");
let receipt = receipt
.remove(&ReceiptType::Read)
.expect("our read receipts always set this")
.remove(&user_id)
.expect("our read receipts always have the user here");
read.insert(
user_id,
ReceiptData {
data: receipt.clone(),
event_ids: vec![event_id.clone()],
},
);
let receipt_map = ReceiptMap {
read,
};
let mut receipts = BTreeMap::new();
receipts.insert(room_id.to_owned(), receipt_map);
Edu::Receipt(ReceiptContent {
receipts,
})
} else {
Error::bad_database("Invalid event type in read_receipts");
continue;
};
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
if events.len() >= SELECT_EDU_LIMIT {
return Ok(false);
}
}
Ok(true)
}
async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
//debug_assert!(!events.is_empty(), "sending empty transaction");
match dest {
Destination::Normal(ref server) => self.send_events_dest_normal(&dest, server, events).await,
Destination::Appservice(ref id) => self.send_events_dest_appservice(&dest, id, events).await,
Destination::Push(ref userid, ref pushkey) => {
self.send_events_dest_push(&dest, userid, pushkey, events)
.await
},
}
}
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
send::send(
&services().client.sender,
server,
send_transaction_message::v1::Request {
origin: services().server.config.server_name.clone(),
#[tracing::instrument(skip(self, dest, events), name = "appservice")]
async fn send_events_dest_appservice(
&self, dest: &Destination, id: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdu_jsons.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
)
})?
.to_room_event(),
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Appservices don't need EDUs (?) and flush only;
// no new content
},
}
}
//debug_assert!(!pdu_jsons.is_empty(), "sending empty transaction");
let client = &self.services.client.appservice;
match appservice::send_request(
client,
self.services
.appservice
.get_registration(id)
.await
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Appservice] Could not load registration from db."),
)
})?,
ruma::api::appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
},
)
.await
{
Ok(_) => Ok(dest.clone()),
Err(e) => Err((dest.clone(), e)),
}
}
#[tracing::instrument(skip(self, dest, events), name = "push")]
async fn send_events_dest_push(
&self, dest: &Destination, userid: &OwnedUserId, pushkey: &str, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdus = Vec::new();
for event in &events {
match event {
SendingEvent::Pdu(pdu_id) => {
pdus.push(
self.services
.timeline
.get_pdu_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
(
dest.clone(),
Error::bad_database("[Push] Event in servernameevent_data not found in db."),
)
})?,
);
},
SendingEvent::Edu(_) | SendingEvent::Flush => {
// Push gateways don't need EDUs (?) and flush only;
// no new content
},
}
}
for pdu in pdus {
// Redacted events are not notification targets (we don't send push for them)
if let Some(unsigned) = &pdu.unsigned {
if let Ok(unsigned) = serde_json::from_str::<serde_json::Value>(unsigned.get()) {
if unsigned.get("redacted_because").is_some() {
continue;
}
}
}
let Some(pusher) = self
.services
.pusher
.get_pusher(userid, pushkey)
.map_err(|e| (dest.clone(), e))?
else {
continue;
};
let rules_for_user = self
.services
.account_data
.get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into())
.unwrap_or_default()
.and_then(|event| serde_json::from_str::<PushRulesEvent>(event.get()).ok())
.map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global);
let unread: UInt = self
.services
.user
.notification_count(userid, &pdu.room_id)
.map_err(|e| (dest.clone(), e))?
.try_into()
.expect("notification count can't go that high");
let _response = self
.services
.pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
.await
.map(|_response| dest.clone())
.map_err(|e| (dest.clone(), e));
}
Ok(dest.clone())
}
#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")]
async fn send_events_dest_normal(
&self, dest: &Destination, server: &OwnedServerName, events: Vec<SendingEvent>,
) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Pdu(_)))
.count(),
);
let mut edu_jsons = Vec::with_capacity(
events
.iter()
.filter(|event| matches!(event, SendingEvent::Edu(_)))
.count(),
);
for event in &events {
match event {
// TODO: check room version and remove event_id if needed
SendingEvent::Pdu(pdu_id) => pdu_jsons.push(
self.convert_to_outgoing_federation_event(
self.services
.timeline
.get_pdu_json_from_id(pdu_id)
.map_err(|e| (dest.clone(), e))?
.ok_or_else(|| {
error!(?dest, ?server, ?pdu_id, "event not found");
(
dest.clone(),
Error::bad_database("[Normal] Event in servernameevent_data not found in db."),
)
})?,
),
),
SendingEvent::Edu(edu) => {
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
}
},
SendingEvent::Flush => {}, // flush only; no new content
}
}
//debug_assert!(pdu_jsons.len() + edu_jsons.len() > 0, "sending empty
// transaction");
let transaction_id = &*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
));
let request = send_transaction_message::v1::Request {
origin: self.server.config.server_name.clone(),
pdus: pdu_jsons,
edus: edu_jsons,
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash(
&events
transaction_id: transaction_id.into(),
};
let client = &self.services.client.sender;
self.send(client, server, request)
.await
.inspect(|response| {
response
.pdus
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
)))
.into(),
},
)
.await
.map(|response| {
for pdu in response.pdus {
if pdu.1.is_err() {
warn!("error for {} from remote: {:?}", pdu.0, pdu.1);
.filter(|(_, res)| res.is_err())
.for_each(|(pdu_id, res)| warn!("error for {pdu_id} from remote: {res:?}"));
})
.map(|_| dest.clone())
.map_err(|e| (dest.clone(), e))
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
pub fn convert_to_outgoing_federation_event(&self, mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
{
unsigned.remove("transaction_id");
}
// room v3 and above removed the "event_id" field from remote PDU format
if let Some(room_id) = pdu_json
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match self.services.state.get_room_version(&room_id) {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),
},
Err(_) => _ = pdu_json.remove("event_id"),
}
} else {
pdu_json.remove("event_id");
}
dest.clone()
})
.map_err(|e| (dest.clone(), e))
}
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
#[tracing::instrument]
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
if let Some(unsigned) = pdu_json
.get_mut("unsigned")
.and_then(|val| val.as_object_mut())
{
unsigned.remove("transaction_id");
// TODO: another option would be to convert it to a canonical string to validate
// size and return a Result<Raw<...>>
// serde_json::from_str::<Raw<_>>(
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
// valid serde_json::Value"), )
// .expect("Raw::from_value always works")
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
}
// room v3 and above removed the "event_id" field from remote PDU format
if let Some(room_id) = pdu_json
.get("room_id")
.and_then(|val| RoomId::parse(val.as_str()?).ok())
{
match services().rooms.state.get_room_version(&room_id) {
Ok(room_version_id) => match room_version_id {
RoomVersionId::V1 | RoomVersionId::V2 => {},
_ => _ = pdu_json.remove("event_id"),
},
Err(_) => _ = pdu_json.remove("event_id"),
}
} else {
pdu_json.remove("event_id");
}
// TODO: another option would be to convert it to a canonical string to validate
// size and return a Result<Raw<...>>
// serde_json::from_str::<Raw<_>>(
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
// valid serde_json::Value"), )
// .expect("Raw::from_value always works")
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
}

View file

@ -3,7 +3,7 @@ use std::{
collections::BTreeMap,
fmt::Write,
ops::Deref,
sync::{Arc, OnceLock},
sync::{Arc, OnceLock, RwLock},
};
use async_trait::async_trait;
@ -50,20 +50,20 @@ pub(crate) struct Args<'a> {
}
/// Dep is a reference to a service used within another service.
/// Circular-dependencies between services require this indirection to allow the
/// referenced service construction after the referencing service.
/// Circular-dependencies between services require this indirection.
pub(crate) struct Dep<T> {
dep: OnceLock<Arc<T>>,
service: Arc<Map>,
name: &'static str,
}
pub(crate) type Map = BTreeMap<String, MapVal>;
pub(crate) type Map = RwLock<BTreeMap<String, MapVal>>;
pub(crate) type MapVal = (Arc<dyn Service>, Arc<dyn Any + Send + Sync>);
impl<T: Any + Send + Sync> Deref for Dep<T> {
impl<T: Send + Sync + 'static> Deref for Dep<T> {
type Target = Arc<T>;
/// Dereference a dependency. The dependency must be ready or panics.
fn deref(&self) -> &Self::Target {
self.dep
.get_or_init(|| require::<T>(&self.service, self.name))
@ -71,39 +71,61 @@ impl<T: Any + Send + Sync> Deref for Dep<T> {
}
impl Args<'_> {
pub(crate) fn depend_service<T: Any + Send + Sync>(&self, name: &'static str) -> Dep<T> {
/// Create a lazy-reference to a service when constructing another Service.
pub(crate) fn depend<T: Send + Sync + 'static>(&self, name: &'static str) -> Dep<T> {
Dep::<T> {
dep: OnceLock::new(),
service: self.service.clone(),
name,
}
}
/// Create a reference immediately to a service when constructing another
/// Service. The other service must be constructed.
pub(crate) fn require<T: Send + Sync + 'static>(&self, name: &str) -> Arc<T> { require::<T>(self.service, name) }
}
pub(crate) fn require<T: Any + Send + Sync>(map: &Map, name: &str) -> Arc<T> {
/// Reference a Service by name. Panics if the Service does not exist or was
/// incorrectly cast.
pub(crate) fn require<T: Send + Sync + 'static>(map: &Map, name: &str) -> Arc<T> {
try_get::<T>(map, name)
.inspect_err(inspect_log)
.expect("Failure to reference service required by another service.")
}
pub(crate) fn try_get<T: Any + Send + Sync>(map: &Map, name: &str) -> Result<Arc<T>> {
map.get(name).map_or_else(
|| Err!("Service {name:?} does not exist or has not been built yet."),
|(_, s)| {
/// Reference a Service by name. Returns Err if the Service does not exist or
/// was incorrectly cast.
pub(crate) fn try_get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Result<Arc<T>> {
map.read()
.expect("locked for reading")
.get(name)
.map_or_else(
|| Err!("Service {name:?} does not exist or has not been built yet."),
|(_, s)| {
s.clone()
.downcast::<T>()
.map_err(|_| err!("Service {name:?} must be correctly downcast."))
},
)
}
/// Reference a Service by name. Returns None if the Service does not exist, but
/// panics if incorrectly cast.
///
/// # Panics
/// Incorrect type is not a silent failure (None) as the type never has a reason
/// to be incorrect.
pub(crate) fn get<T: Send + Sync + 'static>(map: &Map, name: &str) -> Option<Arc<T>> {
map.read()
.expect("locked for reading")
.get(name)
.map(|(_, s)| {
s.clone()
.downcast::<T>()
.map_err(|_| err!("Service {name:?} must be correctly downcast."))
},
)
}
pub(crate) fn get<T: Any + Send + Sync>(map: &Map, name: &str) -> Option<Arc<T>> {
map.get(name).map(|(_, s)| {
s.clone()
.downcast::<T>()
.expect("Service must be correctly downcast.")
})
.expect("Service must be correctly downcast.")
})
}
/// Utility for service implementations; see Service::name() in the trait.
#[inline]
pub(crate) fn make_name(module_path: &str) -> &str { split_once_infallible(module_path, "::").1 }

View file

@ -1,11 +1,16 @@
use std::{any::Any, collections::BTreeMap, fmt::Write, sync::Arc};
use std::{
any::Any,
collections::BTreeMap,
fmt::Write,
sync::{Arc, RwLock},
};
use conduit::{debug, debug_info, info, trace, Result, Server};
use database::Database;
use tokio::sync::Mutex;
use crate::{
account_data, admin, appservice, client, globals, key_backups,
account_data, admin, appservice, client, emergency, globals, key_backups,
manager::Manager,
media, presence, pusher, resolver, rooms, sending, service,
service::{Args, Map, Service},
@ -13,22 +18,23 @@ use crate::{
};
pub struct Services {
pub resolver: Arc<resolver::Service>,
pub client: Arc<client::Service>,
pub globals: Arc<globals::Service>,
pub rooms: rooms::Service,
pub appservice: Arc<appservice::Service>,
pub pusher: Arc<pusher::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub users: Arc<users::Service>,
pub account_data: Arc<account_data::Service>,
pub presence: Arc<presence::Service>,
pub admin: Arc<admin::Service>,
pub appservice: Arc<appservice::Service>,
pub client: Arc<client::Service>,
pub emergency: Arc<emergency::Service>,
pub globals: Arc<globals::Service>,
pub key_backups: Arc<key_backups::Service>,
pub media: Arc<media::Service>,
pub presence: Arc<presence::Service>,
pub pusher: Arc<pusher::Service>,
pub resolver: Arc<resolver::Service>,
pub rooms: rooms::Service,
pub sending: Arc<sending::Service>,
pub transaction_ids: Arc<transaction_ids::Service>,
pub uiaa: Arc<uiaa::Service>,
pub updates: Arc<updates::Service>,
pub users: Arc<users::Service>,
manager: Mutex<Option<Arc<Manager>>>,
pub(crate) service: Arc<Map>,
@ -36,37 +42,34 @@ pub struct Services {
pub db: Arc<Database>,
}
macro_rules! build_service {
($map:ident, $server:ident, $db:ident, $tyname:ty) => {{
let built = <$tyname>::build(Args {
server: &$server,
db: &$db,
service: &$map,
})?;
Arc::get_mut(&mut $map)
.expect("must have mutable reference to services collection")
.insert(built.name().to_owned(), (built.clone(), built.clone()));
trace!("built service #{}: {:?}", $map.len(), built.name());
built
}};
}
impl Services {
#[allow(clippy::cognitive_complexity)]
pub fn build(server: Arc<Server>, db: Arc<Database>) -> Result<Self> {
let mut service: Arc<Map> = Arc::new(BTreeMap::new());
let service: Arc<Map> = Arc::new(RwLock::new(BTreeMap::new()));
macro_rules! build {
($srv:ty) => {
build_service!(service, server, db, $srv)
};
($tyname:ty) => {{
let built = <$tyname>::build(Args {
db: &db,
server: &server,
service: &service,
})?;
add_service(&service, built.clone(), built.clone());
built
}};
}
Ok(Self {
globals: build!(globals::Service),
account_data: build!(account_data::Service),
admin: build!(admin::Service),
appservice: build!(appservice::Service),
resolver: build!(resolver::Service),
client: build!(client::Service),
emergency: build!(emergency::Service),
globals: build!(globals::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
presence: build!(presence::Service),
pusher: build!(pusher::Service),
rooms: rooms::Service {
alias: build!(rooms::alias::Service),
auth_chain: build!(rooms::auth_chain::Service),
@ -79,28 +82,22 @@ impl Services {
read_receipt: build!(rooms::read_receipt::Service),
search: build!(rooms::search::Service),
short: build!(rooms::short::Service),
spaces: build!(rooms::spaces::Service),
state: build!(rooms::state::Service),
state_accessor: build!(rooms::state_accessor::Service),
state_cache: build!(rooms::state_cache::Service),
state_compressor: build!(rooms::state_compressor::Service),
timeline: build!(rooms::timeline::Service),
threads: build!(rooms::threads::Service),
timeline: build!(rooms::timeline::Service),
typing: build!(rooms::typing::Service),
spaces: build!(rooms::spaces::Service),
user: build!(rooms::user::Service),
},
appservice: build!(appservice::Service),
pusher: build!(pusher::Service),
sending: build!(sending::Service),
transaction_ids: build!(transaction_ids::Service),
uiaa: build!(uiaa::Service),
users: build!(users::Service),
account_data: build!(account_data::Service),
presence: build!(presence::Service),
admin: build!(admin::Service),
key_backups: build!(key_backups::Service),
media: build!(media::Service),
sending: build!(sending::Service),
updates: build!(updates::Service),
users: build!(users::Service),
manager: Mutex::new(None),
service,
server,
@ -111,7 +108,7 @@ impl Services {
pub(super) async fn start(&self) -> Result<()> {
debug_info!("Starting services...");
globals::migrations::migrations(&self.db, &self.server.config).await?;
globals::migrations::migrations(self).await?;
self.manager
.lock()
.await
@ -144,7 +141,7 @@ impl Services {
}
pub async fn clear_cache(&self) {
for (service, ..) in self.service.values() {
for (service, ..) in self.service.read().expect("locked for reading").values() {
service.clear_cache();
}
@ -159,7 +156,7 @@ impl Services {
pub async fn memory_usage(&self) -> Result<String> {
let mut out = String::new();
for (service, ..) in self.service.values() {
for (service, ..) in self.service.read().expect("locked for reading").values() {
service.memory_usage(&mut out)?;
}
@ -179,23 +176,26 @@ impl Services {
fn interrupt(&self) {
debug!("Interrupting services...");
for (name, (service, ..)) in self.service.iter() {
for (name, (service, ..)) in self.service.read().expect("locked for reading").iter() {
trace!("Interrupting {name}");
service.interrupt();
}
}
pub fn try_get<T>(&self, name: &str) -> Result<Arc<T>>
where
T: Any + Send + Sync,
{
pub fn try_get<T: Send + Sync + 'static>(&self, name: &str) -> Result<Arc<T>> {
service::try_get::<T>(&self.service, name)
}
pub fn get<T>(&self, name: &str) -> Option<Arc<T>>
where
T: Any + Send + Sync,
{
service::get::<T>(&self.service, name)
}
pub fn get<T: Send + Sync + 'static>(&self, name: &str) -> Option<Arc<T>> { service::get::<T>(&self.service, name) }
}
fn add_service(map: &Arc<Map>, s: Arc<dyn Service>, a: Arc<dyn Any + Send + Sync>) {
let name = s.name();
let len = map.read().expect("locked for reading").len();
trace!("built service #{len}: {name:?}");
map.write()
.expect("locked for writing")
.insert(name.to_owned(), (s, a));
}

View file

@ -2,7 +2,7 @@ mod data;
use std::sync::Arc;
use conduit::{utils, utils::hash, Error, Result};
use conduit::{error, utils, utils::hash, Error, Result, Server};
use data::Data;
use ruma::{
api::client::{
@ -11,19 +11,30 @@ use ruma::{
},
CanonicalJsonValue, DeviceId, UserId,
};
use tracing::error;
use crate::services;
use crate::{globals, users, Dep};
pub const SESSION_ID_LENGTH: usize = 32;
pub struct Service {
server: Arc<Server>,
services: Services,
pub db: Data,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
server: args.server.clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
db: Data::new(args.db),
}))
}
@ -87,11 +98,11 @@ impl Service {
return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized."));
};
let user_id = UserId::parse_with_server_name(username.clone(), services().globals.server_name())
let user_id = UserId::parse_with_server_name(username.clone(), self.services.globals.server_name())
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?;
// Check if password is correct
if let Some(hash) = services().users.password_hash(&user_id)? {
if let Some(hash) = self.services.users.password_hash(&user_id)? {
let hash_matches = hash::verify_password(password, &hash).is_ok();
if !hash_matches {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {
@ -106,7 +117,7 @@ impl Service {
uiaainfo.completed.push(AuthType::Password);
},
AuthData::RegistrationToken(t) => {
if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() {
if Some(t.token.trim()) == self.server.config.registration_token.as_deref() {
uiaainfo.completed.push(AuthType::RegistrationToken);
} else {
uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody {

View file

@ -7,14 +7,20 @@ use ruma::events::room::message::RoomMessageEventContent;
use serde::Deserialize;
use tokio::{sync::Notify, time::interval};
use crate::services;
use crate::{admin, client, Dep};
pub struct Service {
services: Services,
db: Arc<Map>,
interrupt: Notify,
interval: Duration,
}
struct Services {
admin: Dep<admin::Service>,
client: Dep<client::Service>,
}
#[derive(Deserialize)]
struct CheckForUpdatesResponse {
updates: Vec<CheckForUpdatesResponseEntry>,
@ -35,6 +41,10 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u";
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
admin: args.depend::<admin::Service>("admin"),
client: args.depend::<client::Service>("client"),
},
db: args.db["global"].clone(),
interrupt: Notify::new(),
interval: Duration::from_secs(CHECK_FOR_UPDATES_INTERVAL),
@ -63,7 +73,8 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip_all)]
async fn handle_updates(&self) -> Result<()> {
let response = services()
let response = self
.services
.client
.default
.get(CHECK_FOR_UPDATES_URL)
@ -78,7 +89,7 @@ impl Service {
last_update_id = last_update_id.max(update.id);
if update.id > self.last_check_for_updates_id()? {
info!("{:#}", update.message);
services()
self.services
.admin
.send_message(RoomMessageEventContent::text_markdown(format!(
"### the following is a message from the conduwuit puppy\n\nit was sent on `{}`:\n\n@room: {}",

View file

@ -1,7 +1,7 @@
use std::{collections::BTreeMap, mem::size_of, sync::Arc};
use conduit::{debug_info, err, utils, warn, Err, Error, Result};
use database::{Database, Map};
use conduit::{debug_info, err, utils, warn, Err, Error, Result, Server};
use database::Map;
use ruma::{
api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
@ -11,52 +11,65 @@ use ruma::{
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use crate::{services, users::clean_signatures};
use crate::{globals, rooms, users::clean_signatures, Dep};
pub struct Data {
userid_password: Arc<Map>,
keychangeid_userid: Arc<Map>,
keyid_key: Arc<Map>,
onetimekeyid_onetimekeys: Arc<Map>,
openidtoken_expiresatuserid: Arc<Map>,
todeviceid_events: Arc<Map>,
token_userdeviceid: Arc<Map>,
userid_displayname: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
userdeviceid_token: Arc<Map>,
userfilterid_filter: Arc<Map>,
userid_avatarurl: Arc<Map>,
userid_blurhash: Arc<Map>,
userid_devicelistversion: Arc<Map>,
userdeviceid_token: Arc<Map>,
userdeviceid_metadata: Arc<Map>,
onetimekeyid_onetimekeys: Arc<Map>,
userid_displayname: Arc<Map>,
userid_lastonetimekeyupdate: Arc<Map>,
keyid_key: Arc<Map>,
userid_masterkeyid: Arc<Map>,
userid_password: Arc<Map>,
userid_selfsigningkeyid: Arc<Map>,
userid_usersigningkeyid: Arc<Map>,
openidtoken_expiresatuserid: Arc<Map>,
keychangeid_userid: Arc<Map>,
todeviceid_events: Arc<Map>,
userfilterid_filter: Arc<Map>,
_db: Arc<Database>,
services: Services,
}
struct Services {
server: Arc<Server>,
globals: Dep<globals::Service>,
state_cache: Dep<rooms::state_cache::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
}
impl Data {
pub(super) fn new(db: Arc<Database>) -> Self {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
userid_password: db["userid_password"].clone(),
keychangeid_userid: db["keychangeid_userid"].clone(),
keyid_key: db["keyid_key"].clone(),
onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(),
openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(),
todeviceid_events: db["todeviceid_events"].clone(),
token_userdeviceid: db["token_userdeviceid"].clone(),
userid_displayname: db["userid_displayname"].clone(),
userdeviceid_metadata: db["userdeviceid_metadata"].clone(),
userdeviceid_token: db["userdeviceid_token"].clone(),
userfilterid_filter: db["userfilterid_filter"].clone(),
userid_avatarurl: db["userid_avatarurl"].clone(),
userid_blurhash: db["userid_blurhash"].clone(),
userid_devicelistversion: db["userid_devicelistversion"].clone(),
userdeviceid_token: db["userdeviceid_token"].clone(),
userdeviceid_metadata: db["userdeviceid_metadata"].clone(),
onetimekeyid_onetimekeys: db["onetimekeyid_onetimekeys"].clone(),
userid_displayname: db["userid_displayname"].clone(),
userid_lastonetimekeyupdate: db["userid_lastonetimekeyupdate"].clone(),
keyid_key: db["keyid_key"].clone(),
userid_masterkeyid: db["userid_masterkeyid"].clone(),
userid_password: db["userid_password"].clone(),
userid_selfsigningkeyid: db["userid_selfsigningkeyid"].clone(),
userid_usersigningkeyid: db["userid_usersigningkeyid"].clone(),
openidtoken_expiresatuserid: db["openidtoken_expiresatuserid"].clone(),
keychangeid_userid: db["keychangeid_userid"].clone(),
todeviceid_events: db["todeviceid_events"].clone(),
userfilterid_filter: db["userfilterid_filter"].clone(),
_db: db,
services: Services {
server: args.server.clone(),
globals: args.depend::<globals::Service>("globals"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
},
}
}
@ -377,7 +390,7 @@ impl Data {
)?;
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?;
Ok(())
}
@ -403,7 +416,7 @@ impl Data {
prefix.push(b':');
self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
.insert(user_id.as_bytes(), &self.services.globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys
.scan_prefix(prefix)
@ -631,16 +644,16 @@ impl Data {
}
pub(super) fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
let count = services().globals.next_count()?.to_be_bytes();
for room_id in services()
.rooms
let count = self.services.globals.next_count()?.to_be_bytes();
for room_id in self
.services
.state_cache
.rooms_joined(user_id)
.filter_map(Result::ok)
{
// Don't send key updates to unencrypted rooms
if services()
.rooms
if self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none()
@ -750,7 +763,7 @@ impl Data {
key.push(0xFF);
key.extend_from_slice(target_device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
key.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into());
@ -916,7 +929,7 @@ impl Data {
pub(super) fn create_openid_token(&self, user_id: &UserId, token: &str) -> Result<u64> {
use std::num::Saturating as Sat;
let expires_in = services().globals.config.openid_token_ttl;
let expires_in = self.services.server.config.openid_token_ttl;
let expires_at = Sat(utils::millis_since_unix_epoch()) + Sat(expires_in) * Sat(1000);
let mut value = expires_at.0.to_be_bytes().to_vec();

View file

@ -7,7 +7,6 @@ use std::{
};
use conduit::{Error, Result};
use data::Data;
use ruma::{
api::client::{
device::Device,
@ -24,7 +23,8 @@ use ruma::{
UInt, UserId,
};
use crate::services;
use self::data::Data;
use crate::{admin, rooms, Dep};
pub struct SlidingSyncCache {
lists: BTreeMap<String, SyncRequestList>,
@ -36,14 +36,24 @@ pub struct SlidingSyncCache {
type DbConnections = Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>;
pub struct Service {
services: Services,
pub db: Data,
pub connections: DbConnections,
}
struct Services {
admin: Dep<admin::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db.clone()),
services: Services {
admin: args.depend::<admin::Service>("admin"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
db: Data::new(&args),
connections: StdMutex::new(BTreeMap::new()),
}))
}
@ -247,11 +257,8 @@ impl Service {
/// Check if a user is an admin
pub fn is_admin(&self, user_id: &UserId) -> Result<bool> {
if let Some(admin_room_id) = services().admin.get_admin_room()? {
services()
.rooms
.state_cache
.is_joined(user_id, &admin_room_id)
if let Some(admin_room_id) = self.services.admin.get_admin_room()? {
self.services.state_cache.is_joined(user_id, &admin_room_id)
} else {
Ok(false)
}